This is an automated email from the ASF dual-hosted git repository.

dinglei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/rocketmq-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 7ae83c4  fix: fix producer send msg timeout option does not take effect (#1109)
7ae83c4 is described below

commit 7ae83c49351a96d78ebccec5dd545f0c05e3d514
Author: WeizhongTu <weizhong....@alibaba-inc.com>
AuthorDate: Tue Nov 28 10:35:34 2023 +0800

    fix: fix producer send msg timeout option does not take effect (#1109)
---
 consumer/option.go               |  9 +++++++++
 consumer/option_test.go          | 15 +++++++++++++++
 internal/remote/remote_client.go | 26 +++++++++++++++-----------
 producer/option.go               |  9 +++++++++
 producer/option_test.go          | 15 +++++++++++++++
 producer/producer.go             | 11 ++++++-----
 6 files changed, 69 insertions(+), 16 deletions(-)

diff --git a/consumer/option.go b/consumer/option.go
index 24acf7c..2e08163 100644
--- a/consumer/option.go
+++ b/consumer/option.go
@@ -382,6 +382,15 @@ func WithLimiter(limiter Limiter) Option {
        }
 }
 
+// WithRemotingTimeout set remote client timeout options
+func WithRemotingTimeout(connectionTimeout, readTimeout, writeTimeout 
time.Duration) Option {
+       return func(opts *consumerOptions) {
+               opts.ClientOptions.RemotingClientConfig.ConnectionTimeout = 
connectionTimeout
+               opts.ClientOptions.RemotingClientConfig.ReadTimeout = 
readTimeout
+               opts.ClientOptions.RemotingClientConfig.WriteTimeout = 
writeTimeout
+       }
+}
+
 func WithTls(useTls bool) Option {
        return func(opts *consumerOptions) {
                opts.ClientOptions.RemotingClientConfig.UseTls = useTls
diff --git a/consumer/option_test.go b/consumer/option_test.go
index ab99b63..4db5b93 100644
--- a/consumer/option_test.go
+++ b/consumer/option_test.go
@@ -3,6 +3,7 @@ package consumer
 import (
        "reflect"
        "testing"
+       "time"
 )
 
 func getFieldString(obj interface{}, field string) string {
@@ -12,6 +13,20 @@ func getFieldString(obj interface{}, field string) string {
        }).String()
 }
 
+func TestWithRemotingTimeout(t *testing.T) {
+       opt := defaultPushConsumerOptions()
+       WithRemotingTimeout(3*time.Second, 4*time.Second, 5*time.Second)(&opt)
+       if timeout := opt.RemotingClientConfig.ConnectionTimeout; timeout != 
3*time.Second {
+               t.Errorf("consumer option WithRemotingTimeout 
connectionTimeout. want:%s, got=%s", 3*time.Second, timeout)
+       }
+       if timeout := opt.RemotingClientConfig.ReadTimeout; timeout != 
4*time.Second {
+               t.Errorf("consumer option WithRemotingTimeout readTimeout. 
want:%s, got=%s", 4*time.Second, timeout)
+       }
+       if timeout := opt.RemotingClientConfig.WriteTimeout; timeout != 
5*time.Second {
+               t.Errorf("consumer option WithRemotingTimeout writeTimeout. 
want:%s, got=%s", 5*time.Second, timeout)
+       }
+}
+
 func TestWithUnitName(t *testing.T) {
        opt := defaultPushConsumerOptions()
        unitName := "unsh"
diff --git a/internal/remote/remote_client.go b/internal/remote/remote_client.go
index 45dfbbf..36fbea7 100644
--- a/internal/remote/remote_client.go
+++ b/internal/remote/remote_client.go
@@ -102,7 +102,7 @@ func (c *remotingClient) InvokeSync(ctx context.Context, 
addr string, request *R
        c.responseTable.Store(resp.Opaque, resp)
        defer c.responseTable.Delete(request.Opaque)
 
-       err = c.sendRequest(conn, request)
+       err = c.sendRequest(ctx, conn, request)
        if err != nil {
                return nil, err
        }
@@ -120,7 +120,7 @@ func (c *remotingClient) InvokeAsync(ctx context.Context, 
addr string, request *
        resp := NewResponseFuture(ctx, request.Opaque, callback)
        c.responseTable.Store(resp.Opaque, resp)
 
-       err = c.sendRequest(conn, request)
+       err = c.sendRequest(ctx, conn, request)
        if err != nil {
                c.responseTable.Delete(request.Opaque)
                return err
@@ -146,11 +146,11 @@ func (c *remotingClient) InvokeOneWay(ctx 
context.Context, addr string, request
        if err != nil {
                return err
        }
-       return c.sendRequest(conn, request)
+       return c.sendRequest(ctx, conn, request)
 }
 
 func (c *remotingClient) connect(ctx context.Context, addr string) 
(*tcpConnWrapper, error) {
-       //it needs additional locker.
+       // it needs additional locker.
        c.connectionLocker.Lock()
        defer c.connectionLocker.Unlock()
        conn, ok := c.connectionTable.Load(addr)
@@ -246,7 +246,7 @@ func (c *remotingClient) processCMD(cmd *RemotingCommand, r 
*tcpConnWrapper) {
                                if res != nil {
                                        res.Opaque = cmd.Opaque
                                        res.Flag |= 1 << 0
-                                       err := c.sendRequest(r, res)
+                                       err := 
c.sendRequest(context.Background(), r, res)
                                        if err != nil {
                                                rlog.Warning("send response to 
broker error", map[string]interface{}{
                                                        
rlog.LogKeyUnderlayError: err,
@@ -297,23 +297,27 @@ func (c *remotingClient) createScanner(r io.Reader) 
*bufio.Scanner {
        return scanner
 }
 
-func (c *remotingClient) sendRequest(conn *tcpConnWrapper, request 
*RemotingCommand) error {
+func (c *remotingClient) sendRequest(ctx context.Context, conn 
*tcpConnWrapper, request *RemotingCommand) error {
        var err error
        if c.interceptor != nil {
-               err = c.interceptor(context.Background(), request, nil, 
func(ctx context.Context, req, reply interface{}) error {
-                       return c.doRequest(conn, request)
+               err = c.interceptor(ctx, request, nil, func(ctx 
context.Context, req, reply interface{}) error {
+                       return c.doRequest(ctx, conn, request)
                })
        } else {
-               err = c.doRequest(conn, request)
+               err = c.doRequest(ctx, conn, request)
        }
        return err
 }
 
-func (c *remotingClient) doRequest(conn *tcpConnWrapper, request 
*RemotingCommand) error {
+func (c *remotingClient) doRequest(ctx context.Context, conn *tcpConnWrapper, 
request *RemotingCommand) error {
        conn.Lock()
        defer conn.Unlock()
 
-       err := conn.Conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout))
+       deadline, ok := ctx.Deadline()
+       if !ok {
+               deadline = time.Now().Add(c.config.WriteTimeout)
+       }
+       err := conn.Conn.SetWriteDeadline(deadline)
        if err != nil {
                rlog.Error("conn error, close connection", 
map[string]interface{}{
                        rlog.LogKeyUnderlayError: err,
diff --git a/producer/option.go b/producer/option.go
index 6e43cc2..72af3c6 100644
--- a/producer/option.go
+++ b/producer/option.go
@@ -179,6 +179,15 @@ func WithCompressLevel(level int) Option {
        }
 }
 
+// WithRemotingTimeout set remote client timeout options
+func WithRemotingTimeout(connectionTimeout, readTimeout, writeTimeout 
time.Duration) Option {
+       return func(opts *producerOptions) {
+               opts.ClientOptions.RemotingClientConfig.ConnectionTimeout = 
connectionTimeout
+               opts.ClientOptions.RemotingClientConfig.ReadTimeout = 
readTimeout
+               opts.ClientOptions.RemotingClientConfig.WriteTimeout = 
writeTimeout
+       }
+}
+
 func WithTls(useTls bool) Option {
        return func(opts *producerOptions) {
                opts.ClientOptions.RemotingClientConfig.UseTls = useTls
diff --git a/producer/option_test.go b/producer/option_test.go
index 723da03..9b6ee13 100644
--- a/producer/option_test.go
+++ b/producer/option_test.go
@@ -3,6 +3,7 @@ package producer
 import (
        "reflect"
        "testing"
+       "time"
 )
 
 func getFieldString(obj interface{}, field string) string {
@@ -12,6 +13,20 @@ func getFieldString(obj interface{}, field string) string {
        }).String()
 }
 
+func TestWithRemotingTimeout(t *testing.T) {
+       opt := defaultProducerOptions()
+       WithRemotingTimeout(3*time.Second, 4*time.Second, 5*time.Second)(&opt)
+       if timeout := opt.RemotingClientConfig.ConnectionTimeout; timeout != 
3*time.Second {
+               t.Errorf("consumer option WithRemotingTimeout 
connectionTimeout. want:%s, got=%s", 3*time.Second, timeout)
+       }
+       if timeout := opt.RemotingClientConfig.ReadTimeout; timeout != 
4*time.Second {
+               t.Errorf("consumer option WithRemotingTimeout readTimeout. 
want:%s, got=%s", 4*time.Second, timeout)
+       }
+       if timeout := opt.RemotingClientConfig.WriteTimeout; timeout != 
5*time.Second {
+               t.Errorf("consumer option WithRemotingTimeout writeTimeout. 
want:%s, got=%s", 5*time.Second, timeout)
+       }
+}
+
 func TestWithUnitName(t *testing.T) {
        opt := defaultProducerOptions()
        unitName := "unsh"
diff --git a/producer/producer.go b/producer/producer.go
index f823884..70e8d01 100644
--- a/producer/producer.go
+++ b/producer/producer.go
@@ -26,14 +26,15 @@ import (
        "sync/atomic"
        "time"
 
+       "github.com/google/uuid"
+       "github.com/pkg/errors"
+
        errors2 "github.com/apache/rocketmq-client-go/v2/errors"
        "github.com/apache/rocketmq-client-go/v2/internal"
        "github.com/apache/rocketmq-client-go/v2/internal/remote"
        "github.com/apache/rocketmq-client-go/v2/internal/utils"
        "github.com/apache/rocketmq-client-go/v2/primitive"
        "github.com/apache/rocketmq-client-go/v2/rlog"
-       "github.com/google/uuid"
-       "github.com/pkg/errors"
 )
 
 type defaultProducer struct {
@@ -355,7 +356,7 @@ func (p *defaultProducer) sendSync(ctx context.Context, msg 
*primitive.Message,
                        producerCtx.MQ = *mq
                }
 
-               res, _err := p.client.InvokeSync(ctx, addr, 
p.buildSendRequest(mq, msg), 3*time.Second)
+               res, _err := p.client.InvokeSync(ctx, addr, 
p.buildSendRequest(mq, msg), p.options.SendMsgTimeout)
                if _err != nil {
                        err = _err
                        continue
@@ -400,7 +401,7 @@ func (p *defaultProducer) sendAsync(ctx context.Context, 
msg *primitive.Message,
                return errors.Errorf("topic=%s route info not found", mq.Topic)
        }
 
-       ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
+       ctx, cancel := context.WithTimeout(ctx, p.options.SendMsgTimeout)
        err := p.client.InvokeAsync(ctx, addr, p.buildSendRequest(mq, msg), 
func(command *remote.RemotingCommand, err error) {
                cancel()
                if err != nil {
@@ -465,7 +466,7 @@ func (p *defaultProducer) sendOneWay(ctx context.Context, 
msg *primitive.Message
                        return fmt.Errorf("topic=%s route info not found", 
mq.Topic)
                }
 
-               _err := p.client.InvokeOneWay(ctx, addr, p.buildSendRequest(mq, 
msg), 3*time.Second)
+               _err := p.client.InvokeOneWay(ctx, addr, p.buildSendRequest(mq, 
msg), p.options.SendMsgTimeout)
                if _err != nil {
                        err = _err
                        continue

Reply via email to