This is an automated email from the ASF dual-hosted git repository.

dinglei pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/rocketmq-client-go.git


The following commit(s) were added to refs/heads/master by this push:
     new 6c97b55  [ISSUE #787]Refactor the client instance struct,converge the namesrv module (#788)
6c97b55 is described below

commit 6c97b55c8ca21a0dc5e1f25a8ce949cb99c213f3
Author: guyinyou <36399867+guyin...@users.noreply.github.com>
AuthorDate: Tue Mar 15 15:28:39 2022 +0800

    [ISSUE #787]Refactor the client instance struct,converge the namesrv module (#788)
---
 admin/admin.go            | 30 ++++++++++++---------
 consumer/consumer.go      | 36 ++++++++++++--------------
 consumer/consumer_test.go |  5 ++--
 consumer/interceptor.go   |  8 +++++-
 consumer/pull_consumer.go | 10 ++++---
 consumer/push_consumer.go | 18 +++++--------
 internal/client.go        | 66 ++++++++++++++++++++++++++++++++++-------------
 internal/mock_client.go   | 15 +++++++++--
 internal/mock_namesrv.go  |  4 +++
 internal/namesrv.go       |  2 ++
 internal/route.go         |  2 +-
 internal/trace.go         |  4 +++
 producer/interceptor.go   |  8 +++++-
 producer/option.go        |  2 +-
 producer/producer.go      | 30 ++++++++++++---------
 producer/producer_test.go | 24 ++++++++++++-----
 rlog/log.go               |  1 +
 17 files changed, 174 insertions(+), 91 deletions(-)

diff --git a/admin/admin.go b/admin/admin.go
index 1957a06..06908f4 100644
--- a/admin/admin.go
+++ b/admin/admin.go
@@ -19,6 +19,7 @@ package admin
 
 import (
        "context"
+       "fmt"
        "sync"
        "time"
 
@@ -61,8 +62,7 @@ func WithResolver(resolver primitive.NsResolver) AdminOption {
 }
 
 type admin struct {
-       cli     internal.RMQClient
-       namesrv internal.Namesrvs
+       cli internal.RMQClient
 
        opts *adminOptions
 
@@ -75,17 +75,21 @@ func NewAdmin(opts ...AdminOption) (Admin, error) {
        for _, opt := range opts {
                opt(defaultOpts)
        }
-
-       cli := internal.GetOrNewRocketMQClient(defaultOpts.ClientOptions, nil)
        namesrv, err := internal.NewNamesrv(defaultOpts.Resolver)
+       defaultOpts.Namesrv = namesrv
        if err != nil {
                return nil, err
        }
+
+       cli := internal.GetOrNewRocketMQClient(defaultOpts.ClientOptions, nil)
+       if cli == nil {
+               return nil, fmt.Errorf("GetOrNewRocketMQClient faild")
+       }
+       defaultOpts.Namesrv = cli.GetNameSrv()
        //log.Printf("Client: %#v", namesrv.srvs)
        return &admin{
-               cli:     cli,
-               namesrv: namesrv,
-               opts:    defaultOpts,
+               cli:  cli,
+               opts: defaultOpts,
        }, nil
 }
 
@@ -153,8 +157,8 @@ func (a *admin) DeleteTopic(ctx context.Context, opts 
...OptionDelete) error {
        }
        //delete topic in broker
        if cfg.BrokerAddr == "" {
-               a.namesrv.UpdateTopicRouteInfo(cfg.Topic)
-               cfg.BrokerAddr = a.namesrv.FindBrokerAddrByTopic(cfg.Topic)
+               a.cli.GetNameSrv().UpdateTopicRouteInfo(cfg.Topic)
+               cfg.BrokerAddr = 
a.cli.GetNameSrv().FindBrokerAddrByTopic(cfg.Topic)
        }
 
        if _, err := a.deleteTopicInBroker(ctx, cfg.Topic, cfg.BrokerAddr); err 
!= nil {
@@ -168,14 +172,16 @@ func (a *admin) DeleteTopic(ctx context.Context, opts 
...OptionDelete) error {
 
        //delete topic in nameserver
        if len(cfg.NameSrvAddr) == 0 {
-               _, _, err := a.namesrv.UpdateTopicRouteInfo(cfg.Topic)
+               a.cli.GetNameSrv().UpdateTopicRouteInfo(cfg.Topic)
+               cfg.NameSrvAddr = a.cli.GetNameSrv().AddrList()
+               _, _, err := a.cli.GetNameSrv().UpdateTopicRouteInfo(cfg.Topic)
                if err != nil {
                        rlog.Error("delete topic in nameserver error", 
map[string]interface{}{
-                               rlog.LogKeyTopic: cfg.Topic,
+                               rlog.LogKeyTopic:         cfg.Topic,
                                rlog.LogKeyUnderlayError: err,
                        })
                }
-               cfg.NameSrvAddr = a.namesrv.AddrList()
+               cfg.NameSrvAddr = a.cli.GetNameSrv().AddrList()
        }
 
        for _, nameSrvAddr := range cfg.NameSrvAddr {
diff --git a/consumer/consumer.go b/consumer/consumer.go
index ad77bc1..8056b22 100644
--- a/consumer/consumer.go
+++ b/consumer/consumer.go
@@ -263,8 +263,6 @@ type defaultConsumer struct {
        // chan for push consumer
        prCh chan PullRequest
 
-       namesrv internal.Namesrvs
-
        pullFromWhichNodeTable sync.Map
 
        stat *StatsManager
@@ -280,7 +278,7 @@ func (dc *defaultConsumer) start() error {
 
        if dc.model == Clustering {
                dc.option.ChangeInstanceNameToPID()
-               dc.storage = NewRemoteOffsetStore(dc.consumerGroup, dc.client, 
dc.namesrv)
+               dc.storage = NewRemoteOffsetStore(dc.consumerGroup, dc.client, 
dc.client.GetNameSrv())
        } else {
                dc.storage = NewLocalFileOffsetStore(dc.consumerGroup, 
dc.client.ClientID())
        }
@@ -448,7 +446,7 @@ type lockBatchRequestBody struct {
 }
 
 func (dc *defaultConsumer) lock(mq *primitive.MessageQueue) bool {
-       brokerResult := dc.namesrv.FindBrokerAddressInSubscribe(mq.BrokerName, 
internal.MasterId, true)
+       brokerResult := 
dc.client.GetNameSrv().FindBrokerAddressInSubscribe(mq.BrokerName, 
internal.MasterId, true)
 
        if brokerResult == nil {
                return false
@@ -488,7 +486,7 @@ func (dc *defaultConsumer) lock(mq *primitive.MessageQueue) 
bool {
 }
 
 func (dc *defaultConsumer) unlock(mq *primitive.MessageQueue, oneway bool) {
-       brokerResult := dc.namesrv.FindBrokerAddressInSubscribe(mq.BrokerName, 
internal.MasterId, true)
+       brokerResult := 
dc.client.GetNameSrv().FindBrokerAddressInSubscribe(mq.BrokerName, 
internal.MasterId, true)
 
        if brokerResult == nil {
                return
@@ -513,7 +511,7 @@ func (dc *defaultConsumer) lockAll() {
                if len(mqs) == 0 {
                        continue
                }
-               brokerResult := dc.namesrv.FindBrokerAddressInSubscribe(broker, 
internal.MasterId, true)
+               brokerResult := 
dc.client.GetNameSrv().FindBrokerAddressInSubscribe(broker, internal.MasterId, 
true)
                if brokerResult == nil {
                        continue
                }
@@ -559,7 +557,7 @@ func (dc *defaultConsumer) unlockAll(oneway bool) {
                if len(mqs) == 0 {
                        continue
                }
-               brokerResult := dc.namesrv.FindBrokerAddressInSubscribe(broker, 
internal.MasterId, true)
+               brokerResult := 
dc.client.GetNameSrv().FindBrokerAddressInSubscribe(broker, internal.MasterId, 
true)
                if brokerResult == nil {
                        continue
                }
@@ -892,10 +890,10 @@ func (dc *defaultConsumer) processPullResult(mq 
*primitive.MessageQueue, result
 }
 
 func (dc *defaultConsumer) findConsumerList(topic string) []string {
-       brokerAddr := dc.namesrv.FindBrokerAddrByTopic(topic)
+       brokerAddr := dc.client.GetNameSrv().FindBrokerAddrByTopic(topic)
        if brokerAddr == "" {
-               dc.namesrv.UpdateTopicRouteInfo(topic)
-               brokerAddr = dc.namesrv.FindBrokerAddrByTopic(topic)
+               dc.client.GetNameSrv().UpdateTopicRouteInfo(topic)
+               brokerAddr = dc.client.GetNameSrv().FindBrokerAddrByTopic(topic)
        }
 
        if brokerAddr != "" {
@@ -929,10 +927,10 @@ func (dc *defaultConsumer) sendBack(msg 
*primitive.MessageExt, level int) error
 
 // QueryMaxOffset with specific queueId and topic
 func (dc *defaultConsumer) queryMaxOffset(mq *primitive.MessageQueue) (int64, 
error) {
-       brokerAddr := dc.namesrv.FindBrokerAddrByName(mq.BrokerName)
+       brokerAddr := dc.client.GetNameSrv().FindBrokerAddrByName(mq.BrokerName)
        if brokerAddr == "" {
-               dc.namesrv.UpdateTopicRouteInfo(mq.Topic)
-               brokerAddr = dc.namesrv.FindBrokerAddrByName(mq.BrokerName)
+               dc.client.GetNameSrv().UpdateTopicRouteInfo(mq.Topic)
+               brokerAddr = 
dc.client.GetNameSrv().FindBrokerAddrByName(mq.BrokerName)
        }
        if brokerAddr == "" {
                return -1, fmt.Errorf("the broker [%s] does not exist", 
mq.BrokerName)
@@ -958,10 +956,10 @@ func (dc *defaultConsumer) queryOffset(mq 
*primitive.MessageQueue) int64 {
 
 // SearchOffsetByTimestamp with specific queueId and topic
 func (dc *defaultConsumer) searchOffsetByTimestamp(mq *primitive.MessageQueue, 
timestamp int64) (int64, error) {
-       brokerAddr := dc.namesrv.FindBrokerAddrByName(mq.BrokerName)
+       brokerAddr := dc.client.GetNameSrv().FindBrokerAddrByName(mq.BrokerName)
        if brokerAddr == "" {
-               dc.namesrv.UpdateTopicRouteInfo(mq.Topic)
-               brokerAddr = dc.namesrv.FindBrokerAddrByName(mq.BrokerName)
+               dc.client.GetNameSrv().UpdateTopicRouteInfo(mq.Topic)
+               brokerAddr = 
dc.client.GetNameSrv().FindBrokerAddrByName(mq.BrokerName)
        }
        if brokerAddr == "" {
                return -1, fmt.Errorf("the broker [%s] does not exist", 
mq.BrokerName)
@@ -1044,12 +1042,12 @@ func clearCommitOffsetFlag(sysFlag int32) int32 {
 }
 
 func (dc *defaultConsumer) tryFindBroker(mq *primitive.MessageQueue) 
*internal.FindBrokerResult {
-       result := dc.namesrv.FindBrokerAddressInSubscribe(mq.BrokerName, 
dc.recalculatePullFromWhichNode(mq), false)
+       result := 
dc.client.GetNameSrv().FindBrokerAddressInSubscribe(mq.BrokerName, 
dc.recalculatePullFromWhichNode(mq), false)
        if result != nil {
                return result
        }
-       dc.namesrv.UpdateTopicRouteInfo(mq.Topic)
-       return dc.namesrv.FindBrokerAddressInSubscribe(mq.BrokerName, 
dc.recalculatePullFromWhichNode(mq), false)
+       dc.client.GetNameSrv().UpdateTopicRouteInfo(mq.Topic)
+       return 
dc.client.GetNameSrv().FindBrokerAddressInSubscribe(mq.BrokerName, 
dc.recalculatePullFromWhichNode(mq), false)
 }
 
 func (dc *defaultConsumer) updatePullFromWhichNode(mq *primitive.MessageQueue, 
brokerId int64) {
diff --git a/consumer/consumer_test.go b/consumer/consumer_test.go
index 8b99767..12ccd18 100644
--- a/consumer/consumer_test.go
+++ b/consumer/consumer_test.go
@@ -67,7 +67,6 @@ func TestDoRebalance(t *testing.T) {
                defer ctrl.Finish()
                namesrvCli := internal.NewMockNamesrvs(ctrl)
                
namesrvCli.EXPECT().FindBrokerAddrByTopic(gomock.Any()).Return(broker)
-               dc.namesrv = namesrvCli
 
                rmqCli := internal.NewMockRMQClient(ctrl)
                rmqCli.EXPECT().InvokeSync(gomock.Any(), gomock.Any(), 
gomock.Any(), gomock.Any()).
@@ -75,6 +74,8 @@ func TestDoRebalance(t *testing.T) {
                                Body: []byte("{\"consumerIdList\": [\"a1\", 
\"a2\", \"a3\"] }"),
                        }, nil)
                rmqCli.EXPECT().ClientID().Return(clientID)
+               rmqCli.SetNameSrv(namesrvCli)
+
                dc.client = rmqCli
 
                var wg sync.WaitGroup
@@ -109,10 +110,10 @@ func TestComputePullFromWhere(t *testing.T) {
                }
 
                namesrvCli := internal.NewMockNamesrvs(ctrl)
-               dc.namesrv = namesrvCli
 
                rmqCli := internal.NewMockRMQClient(ctrl)
                dc.client = rmqCli
+               rmqCli.SetNameSrv(namesrvCli)
 
                Convey("get effective offset", func() {
                        offsetStore.EXPECT().read(gomock.Any(), 
gomock.Any()).Return(int64(10))
diff --git a/consumer/interceptor.go b/consumer/interceptor.go
index aababfe..05ff94a 100644
--- a/consumer/interceptor.go
+++ b/consumer/interceptor.go
@@ -19,6 +19,7 @@ package consumer
 
 import (
        "context"
+       "fmt"
        "time"
 
        "github.com/apache/rocketmq-client-go/v2/internal"
@@ -39,9 +40,14 @@ func WithTrace(traceCfg *primitive.TraceConfig) Option {
 
 func newTraceInterceptor(traceCfg *primitive.TraceConfig) 
primitive.Interceptor {
        dispatcher := internal.NewTraceDispatcher(traceCfg)
-       dispatcher.Start()
+       if dispatcher != nil {
+               dispatcher.Start()
+       }
 
        return func(ctx context.Context, req, reply interface{}, next 
primitive.Invoker) error {
+               if dispatcher == nil {
+                       return fmt.Errorf("GetOrNewRocketMQClient faild")
+               }
                consumerCtx, exist := primitive.GetConsumerCtx(ctx)
                if !exist || len(consumerCtx.Msgs) == 0 {
                        return next(ctx, req, reply)
diff --git a/consumer/pull_consumer.go b/consumer/pull_consumer.go
index 81a3ec5..874973b 100644
--- a/consumer/pull_consumer.go
+++ b/consumer/pull_consumer.go
@@ -90,10 +90,12 @@ func NewPullConsumer(options ...Option) 
(*defaultPullConsumer, error) {
                prCh:          make(chan PullRequest, 4),
                model:         defaultOpts.ConsumerModel,
                option:        defaultOpts,
-
-               namesrv: srvs,
        }
-       dc.option.ClientOptions.Namesrv, err = 
internal.GetNamesrv(dc.client.ClientID())
+       if dc.client == nil {
+               return nil, fmt.Errorf("GetOrNewRocketMQClient faild")
+       }
+       defaultOpts.Namesrv = dc.client.GetNameSrv()
+
        c := &defaultPullConsumer{
                defaultConsumer: dc,
        }
@@ -132,7 +134,7 @@ func (c *defaultPullConsumer) Pull(ctx context.Context, 
topic string, selector M
 }
 
 func (c *defaultPullConsumer) getNextQueueOf(topic string) 
*primitive.MessageQueue {
-       queues, err := 
c.defaultConsumer.namesrv.FetchSubscribeMessageQueues(topic)
+       queues, err := 
c.defaultConsumer.client.GetNameSrv().FetchSubscribeMessageQueues(topic)
        if err != nil && len(queues) > 0 {
                rlog.Error("get next mq error", map[string]interface{}{
                        rlog.LogKeyTopic:         topic,
diff --git a/consumer/push_consumer.go b/consumer/push_consumer.go
index 3e4377b..8642aa4 100644
--- a/consumer/push_consumer.go
+++ b/consumer/push_consumer.go
@@ -100,14 +100,13 @@ func NewPushConsumer(opts ...Option) (*pushConsumer, 
error) {
                consumeOrderly: defaultOpts.ConsumeOrderly,
                fromWhere:      defaultOpts.FromWhere,
                allocate:       defaultOpts.Strategy,
-               namesrv:        srvs,
                option:         defaultOpts,
        }
-       dc.option.ClientOptions.Namesrv, err = 
internal.GetNamesrv(dc.client.ClientID())
-       if err != nil {
-               return nil, err
+       if dc.client == nil {
+               return nil, fmt.Errorf("GetOrNewRocketMQClient faild")
        }
-       dc.namesrv = dc.option.ClientOptions.Namesrv
+       defaultOpts.Namesrv = dc.client.GetNameSrv()
+
        p := &pushConsumer{
                defaultConsumer: dc,
                subscribedTopic: make(map[string]string, 0),
@@ -124,11 +123,6 @@ func NewPushConsumer(opts ...Option) (*pushConsumer, 
error) {
 
        p.interceptor = primitive.ChainInterceptors(p.option.Interceptors...)
 
-       if p.model == Clustering {
-               retryTopic := internal.GetRetryTopic(p.consumerGroup)
-               sub := buildSubscriptionData(retryTopic, MessageSelector{TAG, 
_SubAll})
-               p.subscriptionDataTable.Store(retryTopic, sub)
-       }
        return p, nil
 }
 
@@ -386,7 +380,7 @@ func (pc *pushConsumer) GetConsumerRunningInfo() 
*internal.ConsumerRunningInfo {
        })
 
        nsAddr := ""
-       for _, value := range pc.namesrv.AddrList() {
+       for _, value := range pc.client.GetNameSrv().AddrList() {
                nsAddr += fmt.Sprintf("%s;", value)
        }
        info.Properties[internal.PropNameServerAddr] = nsAddr
@@ -795,7 +789,7 @@ func (pc *pushConsumer) correctTagsOffset(pr *PullRequest) {
 func (pc *pushConsumer) sendMessageBack(brokerName string, msg 
*primitive.MessageExt, delayLevel int) bool {
        var brokerAddr string
        if len(brokerName) != 0 {
-               brokerAddr = 
pc.defaultConsumer.namesrv.FindBrokerAddrByName(brokerName)
+               brokerAddr = 
pc.defaultConsumer.client.GetNameSrv().FindBrokerAddrByName(brokerName)
        } else {
                brokerAddr = msg.StoreHost
        }
diff --git a/internal/client.go b/internal/client.go
index 4d7769e..c7f3e58 100644
--- a/internal/client.go
+++ b/internal/client.go
@@ -24,6 +24,7 @@ import (
        errors2 "github.com/apache/rocketmq-client-go/v2/errors"
        "net"
        "os"
+       "sort"
        "strconv"
        "strings"
        "sync"
@@ -104,7 +105,7 @@ func DefaultClientOptions() ClientOptions {
 type ClientOptions struct {
        GroupName         string
        NameServerAddrs   primitive.NamesrvAddr
-       Namesrv           *namesrvs
+       Namesrv           Namesrvs
        ClientIP          string
        InstanceName      string
        UnitMode          bool
@@ -136,7 +137,7 @@ type RMQClient interface {
 
        ClientID() string
 
-       RegisterProducer(group string, producer InnerProducer)
+       RegisterProducer(group string, producer InnerProducer) error
        UnregisterProducer(group string)
        InvokeSync(ctx context.Context, addr string, request 
*remote.RemotingCommand,
                timeoutMillis time.Duration) (*remote.RemotingCommand, error)
@@ -155,6 +156,8 @@ type RMQClient interface {
        PullMessage(ctx context.Context, brokerAddrs string, request 
*PullMessageRequestHeader) (*primitive.PullResult, error)
        RebalanceImmediately()
        UpdatePublishInfo(topic string, data *TopicRouteData, changed bool)
+
+       GetNameSrv() Namesrvs
 }
 
 var _ RMQClient = new(rmqClient)
@@ -172,25 +175,48 @@ type rmqClient struct {
        hbMutex      sync.Mutex
        close        bool
        rbMutex      sync.Mutex
-       namesrvs     *namesrvs
        done         chan struct{}
        shutdownOnce sync.Once
 }
 
+func (c *rmqClient) GetNameSrv() Namesrvs {
+       return c.option.Namesrv
+}
+
 var clientMap sync.Map
 
 func GetOrNewRocketMQClient(option ClientOptions, callbackCh chan interface{}) 
RMQClient {
        client := &rmqClient{
                option:       option,
                remoteClient: remote.NewRemotingClient(),
-               namesrvs:     option.Namesrv,
                done:         make(chan struct{}),
        }
        actual, loaded := clientMap.LoadOrStore(client.ClientID(), client)
-       client.namesrvs = GetOrSetNamesrv(client.ClientID(), client.namesrvs)
-       client.namesrvs.bundleClient = actual.(*rmqClient)
-       client.option.Namesrv = client.namesrvs
-       if !loaded {
+
+       if loaded {
+               // compare namesrv address
+               client = actual.(*rmqClient)
+               now := option.Namesrv.(*namesrvs).resolver.Resolve()
+               old := client.GetNameSrv().(*namesrvs).resolver.Resolve()
+               if len(now) != len(old) {
+                       rlog.Error("different namesrv option in the same 
instance", map[string]interface{}{
+                               "NewNameSrv":    now,
+                               "BeforeNameSrv": old,
+                       })
+                       return nil
+               }
+               sort.Strings(now)
+               sort.Strings(old)
+               for i := 0; i < len(now); i++ {
+                       if now[i] != old[i] {
+                               rlog.Error("different namesrv option in the 
same instance", map[string]interface{}{
+                                       "NewNameSrv":    now,
+                                       "BeforeNameSrv": old,
+                               })
+                               return nil
+                       }
+               }
+       } else {
                
client.remoteClient.RegisterRequestFunc(ReqNotifyConsumerIdsChanged, func(req 
*remote.RemotingCommand, addr net.Addr) *remote.RemotingCommand {
                        rlog.Info("receive broker's notification to consumer 
group", map[string]interface{}{
                                rlog.LogKeyConsumerGroup: 
req.ExtFields["consumerGroup"],
@@ -306,7 +332,7 @@ func GetOrNewRocketMQClient(option ClientOptions, 
callbackCh chan interface{}) R
                        return nil
                })
        }
-       return actual.(*rmqClient)
+       return client
 }
 
 func (c *rmqClient) Start() {
@@ -318,7 +344,7 @@ func (c *rmqClient) Start() {
                }
                go primitive.WithRecover(func() {
                        op := func() {
-                               c.namesrvs.UpdateNameServerAddress()
+                               c.GetNameSrv().UpdateNameServerAddress()
                        }
                        time.Sleep(10 * time.Second)
                        op()
@@ -364,7 +390,7 @@ func (c *rmqClient) Start() {
 
                go primitive.WithRecover(func() {
                        op := func() {
-                               c.namesrvs.cleanOfflineBroker()
+                               c.GetNameSrv().cleanOfflineBroker()
                                c.SendHeartbeatToAllBrokerWithLock()
                        }
 
@@ -529,7 +555,7 @@ func (c *rmqClient) SendHeartbeatToAllBrokerWithLock() {
                rlog.Info("sending heartbeat, but no producer and no consumer", 
nil)
                return
        }
-       c.namesrvs.brokerAddressesMap.Range(func(key, value interface{}) bool {
+       c.GetNameSrv().(*namesrvs).brokerAddressesMap.Range(func(key, value 
interface{}) bool {
                brokerName := key.(string)
                data := value.(*BrokerData)
                for id, addr := range data.BrokerAddresses {
@@ -559,7 +585,7 @@ func (c *rmqClient) SendHeartbeatToAllBrokerWithLock() {
                        }
                        cancel()
                        if response.Code == ResSuccess {
-                               c.namesrvs.AddBrokerVersion(brokerName, addr, 
int32(response.Version))
+                               
c.GetNameSrv().(*namesrvs).AddBrokerVersion(brokerName, addr, 
int32(response.Version))
                                rlog.Debug("send heart beat to broker success", 
map[string]interface{}{
                                        "brokerName": brokerName,
                                        "brokerId":   id,
@@ -589,7 +615,7 @@ func (c *rmqClient) UpdateTopicRouteInfo() {
                return true
        })
        for topic := range publishTopicSet {
-               data, changed, _ := c.namesrvs.UpdateTopicRouteInfo(topic)
+               data, changed, _ := c.GetNameSrv().UpdateTopicRouteInfo(topic)
                c.UpdatePublishInfo(topic, data, changed)
        }
 
@@ -604,7 +630,7 @@ func (c *rmqClient) UpdateTopicRouteInfo() {
        })
 
        for topic := range subscribedTopicSet {
-               data, changed, _ := c.namesrvs.UpdateTopicRouteInfo(topic)
+               data, changed, _ := c.GetNameSrv().UpdateTopicRouteInfo(topic)
                c.updateSubscribeInfo(topic, data, changed)
        }
 }
@@ -730,8 +756,12 @@ func (c *rmqClient) UnregisterConsumer(group string) {
        c.consumerMap.Delete(group)
 }
 
-func (c *rmqClient) RegisterProducer(group string, producer InnerProducer) {
-       c.producerMap.Store(group, producer)
+func (c *rmqClient) RegisterProducer(group string, producer InnerProducer) 
error {
+       _, loaded := c.producerMap.LoadOrStore(group, producer)
+       if loaded {
+               return fmt.Errorf("the producer group \"%s\" has been created, 
specify another one", c.option.GroupName)
+       }
+       return nil
 }
 
 func (c *rmqClient) UnregisterProducer(group string) {
@@ -760,7 +790,7 @@ func (c *rmqClient) UpdatePublishInfo(topic string, data 
*TopicRouteData, change
                        updated = p.IsPublishTopicNeedUpdate(topic)
                }
                if updated {
-                       publishInfo := c.namesrvs.routeData2PublishInfo(topic, 
data)
+                       publishInfo := 
c.GetNameSrv().(*namesrvs).routeData2PublishInfo(topic, data)
                        publishInfo.HaveTopicRouterInfo = true
                        p.UpdateTopicPublishInfo(topic, publishInfo)
                }
diff --git a/internal/mock_client.go b/internal/mock_client.go
index ab34ac1..c975038 100644
--- a/internal/mock_client.go
+++ b/internal/mock_client.go
@@ -208,6 +208,15 @@ func (mr *MockInnerConsumerMockRecorder) 
GetConsumerRunningInfo() *gomock.Call {
 type MockRMQClient struct {
        ctrl     *gomock.Controller
        recorder *MockRMQClientMockRecorder
+       Namesrv  *MockNamesrvs
+}
+
+func (m *MockRMQClient) GetNameSrv() Namesrvs {
+       return m.Namesrv
+}
+
+func (m *MockRMQClient) SetNameSrv(mockNamesrvs *MockNamesrvs) {
+       m.Namesrv = mockNamesrvs
 }
 
 // MockRMQClientMockRecorder is the mock recorder for MockRMQClient
@@ -260,8 +269,10 @@ func (mr *MockRMQClientMockRecorder) ClientID() 
*gomock.Call {
 }
 
 // RegisterProducer mocks base method
-func (m *MockRMQClient) RegisterProducer(group string, producer InnerProducer) 
{
-       m.ctrl.Call(m, "RegisterProducer", group, producer)
+func (m *MockRMQClient) RegisterProducer(group string, producer InnerProducer) 
error {
+       ret := m.ctrl.Call(m, "RegisterProducer", group, producer)
+       ret0, _ := ret[0].(error)
+       return ret0
 }
 
 // RegisterProducer indicates an expected call of RegisterProducer
diff --git a/internal/mock_namesrv.go b/internal/mock_namesrv.go
index f87d174..7ce6f97 100644
--- a/internal/mock_namesrv.go
+++ b/internal/mock_namesrv.go
@@ -33,6 +33,10 @@ type MockNamesrvs struct {
        recorder *MockNamesrvsMockRecorder
 }
 
+func (m *MockNamesrvs) UpdateTopicRouteInfoWithDefault(topic string, 
defaultTopic string, defaultQueueNum int) (*TopicRouteData, bool, error) {
+       return m.UpdateTopicRouteInfo(topic)
+}
+
 // MockNamesrvsMockRecorder is the mock recorder for MockNamesrvs
 type MockNamesrvsMockRecorder struct {
        mock *MockNamesrvs
diff --git a/internal/namesrv.go b/internal/namesrv.go
index 7776651..96e708a 100644
--- a/internal/namesrv.go
+++ b/internal/namesrv.go
@@ -50,6 +50,8 @@ type Namesrvs interface {
 
        UpdateTopicRouteInfo(topic string) (routeData *TopicRouteData, changed 
bool, err error)
 
+       UpdateTopicRouteInfoWithDefault(topic string, defaultTopic string, 
defaultQueueNum int) (*TopicRouteData, bool, error)
+
        FetchPublishMessageQueues(topic string) ([]*primitive.MessageQueue, 
error)
 
        FindBrokerAddrByTopic(topic string) string
diff --git a/internal/route.go b/internal/route.go
index 9cfa398..54dbbea 100644
--- a/internal/route.go
+++ b/internal/route.go
@@ -165,7 +165,7 @@ func (s *namesrvs) UpdateTopicRouteInfoWithDefault(topic 
string, defaultTopic st
                                        updated = 
p.IsPublishTopicNeedUpdate(topic)
                                }
                                if updated {
-                                       publishInfo := 
s.bundleClient.namesrvs.routeData2PublishInfo(topic, routeData)
+                                       publishInfo := 
s.bundleClient.GetNameSrv().(*namesrvs).routeData2PublishInfo(topic, routeData)
                                        publishInfo.HaveTopicRouterInfo = true
                                        p.UpdateTopicPublishInfo(topic, 
publishInfo)
                                }
diff --git a/internal/trace.go b/internal/trace.go
index cef1634..753a4d1 100644
--- a/internal/trace.go
+++ b/internal/trace.go
@@ -276,6 +276,10 @@ func NewTraceDispatcher(traceCfg *primitive.TraceConfig) 
*traceDispatcher {
        cliOp.Namesrv = srvs
        cliOp.Credentials = traceCfg.Credentials
        cli := GetOrNewRocketMQClient(cliOp, nil)
+       if cli == nil {
+               return nil
+       }
+       cliOp.Namesrv = cli.GetNameSrv()
        return &traceDispatcher{
                ctx:    ctx,
                cancel: cancel,
diff --git a/producer/interceptor.go b/producer/interceptor.go
index 160deac..71eb8e7 100644
--- a/producer/interceptor.go
+++ b/producer/interceptor.go
@@ -22,6 +22,7 @@ package producer
 
 import (
        "context"
+       "fmt"
        "time"
 
        "github.com/apache/rocketmq-client-go/v2/internal"
@@ -42,9 +43,14 @@ func WithTrace(traceCfg *primitive.TraceConfig) Option {
 
 func newTraceInterceptor(traceCfg *primitive.TraceConfig) 
primitive.Interceptor {
        dispatcher := internal.NewTraceDispatcher(traceCfg)
-       dispatcher.Start()
+       if dispatcher != nil {
+               dispatcher.Start()
+       }
 
        return func(ctx context.Context, req, reply interface{}, next 
primitive.Invoker) error {
+               if dispatcher == nil {
+                       return fmt.Errorf("GetOrNewRocketMQClient faild")
+               }
                beginT := time.Now()
                err := next(ctx, req, reply)
 
diff --git a/producer/option.go b/producer/option.go
index 5839402..ae76511 100644
--- a/producer/option.go
+++ b/producer/option.go
@@ -35,7 +35,7 @@ func defaultProducerOptions() producerOptions {
                CompressMsgBodyOverHowmuch: 4096,
                CompressLevel:              5,
        }
-       opts.ClientOptions.GroupName = "DEFAULT_CONSUMER"
+       opts.ClientOptions.GroupName = "DEFAULT_PRODUCER"
        return opts
 }
 
diff --git a/producer/producer.go b/producer/producer.go
index 226eedb..3c875c6 100644
--- a/producer/producer.go
+++ b/producer/producer.go
@@ -67,19 +67,25 @@ func NewDefaultProducer(opts ...Option) (*defaultProducer, 
error) {
                options:    defaultOpts,
        }
        producer.client = 
internal.GetOrNewRocketMQClient(defaultOpts.ClientOptions, producer.callbackCh)
-       producer.options.ClientOptions.Namesrv, err = 
internal.GetNamesrv(producer.client.ClientID())
-       if err != nil {
-               return nil, err
+       if producer.client == nil {
+               return nil, fmt.Errorf("GetOrNewRocketMQClient faild")
        }
+       defaultOpts.Namesrv = producer.client.GetNameSrv()
+
        producer.interceptor = 
primitive.ChainInterceptors(producer.options.Interceptors...)
 
        return producer, nil
 }
 
 func (p *defaultProducer) Start() error {
+       if p == nil || p.client == nil {
+               return fmt.Errorf("client instance is nil, can not start 
producer")
+       }
        atomic.StoreInt32(&p.state, int32(internal.StateRunning))
-
-       p.client.RegisterProducer(p.group, p)
+       err := p.client.RegisterProducer(p.group, p)
+       if err != nil {
+               return err
+       }
        p.client.Start()
        return nil
 }
@@ -195,7 +201,7 @@ func (p *defaultProducer) sendSync(ctx context.Context, msg 
*primitive.Message,
                        continue
                }
 
-               addr := p.options.Namesrv.FindBrokerAddrByName(mq.BrokerName)
+               addr := 
p.client.GetNameSrv().FindBrokerAddrByName(mq.BrokerName)
                if addr == "" {
                        return fmt.Errorf("topic=%s route info not found", 
mq.Topic)
                }
@@ -242,7 +248,7 @@ func (p *defaultProducer) sendAsync(ctx context.Context, 
msg *primitive.Message,
                return errors.Errorf("the topic=%s route info not found", 
msg.Topic)
        }
 
-       addr := p.options.Namesrv.FindBrokerAddrByName(mq.BrokerName)
+       addr := p.client.GetNameSrv().FindBrokerAddrByName(mq.BrokerName)
        if addr == "" {
                return errors.Errorf("topic=%s route info not found", mq.Topic)
        }
@@ -289,7 +295,7 @@ func (p *defaultProducer) sendOneWay(ctx context.Context, 
msg *primitive.Message
                        continue
                }
 
-               addr := p.options.Namesrv.FindBrokerAddrByName(mq.BrokerName)
+               addr := 
p.client.GetNameSrv().FindBrokerAddrByName(mq.BrokerName)
                if addr == "" {
                        return fmt.Errorf("topic=%s route info not found", 
mq.Topic)
                }
@@ -378,7 +384,7 @@ func (p *defaultProducer) selectMessageQueue(msg 
*primitive.Message) *primitive.
 
        v, exist := p.publishInfo.Load(topic)
        if !exist {
-               data, changed, err := 
p.options.Namesrv.UpdateTopicRouteInfo(topic)
+               data, changed, err := 
p.client.GetNameSrv().UpdateTopicRouteInfo(topic)
                if err != nil && primitive.IsRemotingErr(err) {
                        return nil
                }
@@ -387,7 +393,7 @@ func (p *defaultProducer) selectMessageQueue(msg 
*primitive.Message) *primitive.
        }
 
        if !exist {
-               data, changed, _ := 
p.options.Namesrv.UpdateTopicRouteInfoWithDefault(topic, 
p.options.CreateTopicKey, p.options.DefaultTopicQueueNums)
+               data, changed, _ := 
p.client.GetNameSrv().UpdateTopicRouteInfoWithDefault(topic, 
p.options.CreateTopicKey, p.options.DefaultTopicQueueNums)
                p.client.UpdatePublishInfo(topic, data, changed)
                v, exist = p.publishInfo.Load(topic)
        }
@@ -558,8 +564,8 @@ func (tp *transactionProducer) endTransaction(result 
primitive.SendResult, err e
        } else {
                msgID, _ = primitive.UnmarshalMsgID([]byte(result.MsgID))
        }
-       
-       brokerAddr := 
tp.producer.options.Namesrv.FindBrokerAddrByName(result.MessageQueue.BrokerName)
+       // 估计没有反序列化回来
+       brokerAddr := 
tp.producer.client.GetNameSrv().FindBrokerAddrByName(result.MessageQueue.BrokerName)
        requestHeader := &internal.EndTransactionRequestHeader{
                TransactionId:        result.TransactionID,
                CommitLogOffset:      msgID.Offset,
diff --git a/producer/producer_test.go b/producer/producer_test.go
index a7c15c1..e1d72dd 100644
--- a/producer/producer_test.go
+++ b/producer/producer_test.go
@@ -47,7 +47,7 @@ func TestShutdown(t *testing.T) {
        client := internal.NewMockRMQClient(ctrl)
        p.client = client
 
-       client.EXPECT().RegisterProducer(gomock.Any(), gomock.Any()).Return()
+       client.EXPECT().RegisterProducer(gomock.Any(), gomock.Any()).Return(nil)
        client.EXPECT().Start().Return()
        err := p.Start()
        assert.Nil(t, err)
@@ -117,10 +117,13 @@ func TestSync(t *testing.T) {
 
        ctrl := gomock.NewController(t)
        defer ctrl.Finish()
+       namesrvCli := internal.NewMockNamesrvs(ctrl)
        client := internal.NewMockRMQClient(ctrl)
        p.client = client
+       client.SetNameSrv(namesrvCli)
+       namesrvCli.EXPECT().FindBrokerAddrByName(gomock.Any()).Return("a")
 
-       client.EXPECT().RegisterProducer(gomock.Any(), gomock.Any()).Return()
+       client.EXPECT().RegisterProducer(gomock.Any(), gomock.Any()).Return(nil)
        client.EXPECT().Start().Return()
        err := p.Start()
        assert.Nil(t, err)
@@ -168,10 +171,13 @@ func TestASync(t *testing.T) {
 
        ctrl := gomock.NewController(t)
        defer ctrl.Finish()
+       namesrvCli := internal.NewMockNamesrvs(ctrl)
        client := internal.NewMockRMQClient(ctrl)
        p.client = client
+       client.SetNameSrv(namesrvCli)
+       namesrvCli.EXPECT().FindBrokerAddrByName(gomock.Any()).Return("a")
 
-       client.EXPECT().RegisterProducer(gomock.Any(), gomock.Any()).Return()
+       client.EXPECT().RegisterProducer(gomock.Any(), gomock.Any()).Return(nil)
        client.EXPECT().Start().Return()
        err := p.Start()
        assert.Nil(t, err)
@@ -230,10 +236,13 @@ func TestOneway(t *testing.T) {
 
        ctrl := gomock.NewController(t)
        defer ctrl.Finish()
+       namesrvCli := internal.NewMockNamesrvs(ctrl)
        client := internal.NewMockRMQClient(ctrl)
        p.client = client
+       client.SetNameSrv(namesrvCli)
+       namesrvCli.EXPECT().FindBrokerAddrByName(gomock.Any()).Return("a")
 
-       client.EXPECT().RegisterProducer(gomock.Any(), gomock.Any()).Return()
+       client.EXPECT().RegisterProducer(gomock.Any(), gomock.Any()).Return(nil)
        client.EXPECT().Start().Return()
        err := p.Start()
        assert.Nil(t, err)
@@ -268,10 +277,13 @@ func TestSyncWithNamespace(t *testing.T) {
 
        ctrl := gomock.NewController(t)
        defer ctrl.Finish()
+       namesrvCli := internal.NewMockNamesrvs(ctrl)
        client := internal.NewMockRMQClient(ctrl)
        p.client = client
+       client.SetNameSrv(namesrvCli)
+       namesrvCli.EXPECT().FindBrokerAddrByName(gomock.Any()).Return("a")
 
-       client.EXPECT().RegisterProducer(gomock.Any(), gomock.Any()).Return()
+       client.EXPECT().RegisterProducer(gomock.Any(), gomock.Any()).Return(nil)
        client.EXPECT().Start().Return()
        err := p.Start()
        assert.Nil(t, err)
@@ -323,7 +335,7 @@ func TestBatchSendDifferentTopics(t *testing.T) {
        client := internal.NewMockRMQClient(ctrl)
        p.client = client
 
-       client.EXPECT().RegisterProducer(gomock.Any(), gomock.Any()).Return()
+       client.EXPECT().RegisterProducer(gomock.Any(), gomock.Any()).Return(nil)
        client.EXPECT().Start().Return()
        err := p.Start()
        assert.Nil(t, err)
diff --git a/rlog/log.go b/rlog/log.go
index 037cfcf..a179a40 100644
--- a/rlog/log.go
+++ b/rlog/log.go
@@ -25,6 +25,7 @@ import (
 )
 
 const (
+       LogKeyProducerGroup    = "producerGroup"
        LogKeyConsumerGroup    = "consumerGroup"
        LogKeyTopic            = "topic"
        LogKeyMessageQueue     = "MessageQueue"

Reply via email to