This is an automated email from the ASF dual-hosted git repository.

git-hulk pushed a commit to branch unstable
in repository https://gitbox.apache.org/repos/asf/kvrocks-controller.git


The following commit(s) were added to refs/heads/unstable by this push:
     new c89a276  Fix manual failover with CLIENT PAUSE/UNPAUSE (#389)
c89a276 is described below

commit c89a276a1bdb81191de4485f2a6f4a6685e98706
Author: Ruifeng Guo <[email protected]>
AuthorDate: Wed May 20 15:24:16 2026 +0800

    Fix manual failover with CLIENT PAUSE/UNPAUSE (#389)
---
 config/config.go            |   1 +
 consts/errors.go            |   1 +
 controller/cluster.go       |  22 +++---
 server/api/handler.go       |   4 +-
 server/api/shard.go         |  75 +++++++++++++++++++--
 server/route.go             |   3 +-
 store/cluster.go            |  11 ++-
 store/cluster_mock_node.go  |  39 ++++++++++-
 store/cluster_node.go       | 112 ++++++++++++++++++++++++++++++
 store/cluster_shard.go      | 161 +++++++++++++++++++++++++++++++++++++++++---
 store/cluster_shard_test.go |  30 +++++++++
 store/cluster_test.go       |  18 ++---
 12 files changed, 435 insertions(+), 42 deletions(-)

diff --git a/config/config.go b/config/config.go
index d88862a..2581e6d 100644
--- a/config/config.go
+++ b/config/config.go
@@ -47,6 +47,7 @@ type FailOverConfig struct {
        // propagates the updated topology. Requires kvrocks to support node 
status
        // modification (new versions only). Defaults to false for backward 
compatibility.
        EnableSlaveHAUpdate bool `yaml:"enable_slave_ha_update"`
+       WaitForSync         bool `yaml:"wait_for_sync"`
 }
 
 type ControllerConfig struct {
diff --git a/consts/errors.go b/consts/errors.go
index 4d3f1bc..0a9d00a 100644
--- a/consts/errors.go
+++ b/consts/errors.go
@@ -40,4 +40,5 @@ var (
        ErrShardNoMatchNewMaster            = errors.New("no match new master 
in shard")
        ErrCannotOfflineMaster              = errors.New("cannot take master 
node offline, failover first")
        ErrSlotStartAndStopEqual            = errors.New("start and stop of a 
range cannot be equal")
+       ErrSyncTimeout                      = errors.New("replication sync 
timeout")
 )
diff --git a/controller/cluster.go b/controller/cluster.go
index 22242a9..b6f78bc 100755
--- a/controller/cluster.go
+++ b/controller/cluster.go
@@ -41,6 +41,7 @@ type ClusterCheckOptions struct {
        pingInterval        time.Duration
        maxFailureCount     int64
        enableSlaveHAUpdate bool
+       failoverOpts        store.FailoverOptions
 }
 
 type ClusterChecker struct {
@@ -72,6 +73,7 @@ func NewClusterChecker(s store.Store, ns, cluster string) 
*ClusterChecker {
                options: ClusterCheckOptions{
                        pingInterval:    time.Second * 3,
                        maxFailureCount: 5,
+                       failoverOpts:    store.DefaultFailoverOptions(),
                },
                failureCounts: make(map[string]int64),
                syncCh:        make(chan struct{}, 1),
@@ -110,6 +112,11 @@ func (c *ClusterChecker) WithSlaveHAUpdate(enable bool) 
*ClusterChecker {
        return c
 }
 
+func (c *ClusterChecker) WithFailoverOptions(opts store.FailoverOptions) 
*ClusterChecker {
+       c.options.failoverOpts = opts
+       return c
+}
+
 func (c *ClusterChecker) probeNode(ctx context.Context, node store.Node) 
(int64, error) {
        clusterInfo, err := node.GetClusterInfo(ctx)
        if err != nil {
@@ -174,20 +181,19 @@ func (c *ClusterChecker) increaseFailureCount(shardIndex 
int, node store.Node) i
                        log.Error("Failed to get the cluster info", 
zap.Error(err))
                        return count
                }
-               newMasterID, err := cluster.PromoteNewMaster(c.ctx, shardIndex, 
node.ID(), "")
-               if err != nil {
-                       log.Error("Failed to promote the new master", 
zap.Error(err))
+               _, newMaster, promoteErr := cluster.PromoteNewMaster(c.ctx, 
shardIndex, node.ID(), "", c.options.failoverOpts)
+               if promoteErr != nil {
+                       log.Error("Failed to promote the new master", 
zap.Error(promoteErr))
                        return count
                }
-               err = c.clusterStore.UpdateCluster(c.ctx, c.namespace, cluster)
-               if err != nil {
-                       log.Error("Failed to update the cluster", 
zap.Error(err))
+               if updateErr := c.clusterStore.UpdateCluster(c.ctx, 
c.namespace, cluster); updateErr != nil {
+                       log.Error("Failed to persist cluster after promoting 
new master", zap.Error(updateErr))
                        return count
                }
                // the node is normal if it can be elected as the new master,
                // because it requires the node is healthy.
-               c.resetFailureCount(newMasterID)
-               log.With(zap.String("new_master_id", 
newMasterID)).Info("Promote the new master")
+               c.resetFailureCount(newMaster.ID())
+               log.With(zap.String("new_master_id", 
newMaster.ID())).Info("Promote the new master")
        }
        return count
 }
diff --git a/server/api/handler.go b/server/api/handler.go
index d80b3ea..1ac4629 100644
--- a/server/api/handler.go
+++ b/server/api/handler.go
@@ -32,11 +32,11 @@ type Handler struct {
        Raft      *RaftHandler
 }
 
-func NewHandler(s *store.ClusterStore) *Handler {
+func NewHandler(s *store.ClusterStore, waitForSync bool) *Handler {
        return &Handler{
                Namespace: &NamespaceHandler{s: s},
                Cluster:   &ClusterHandler{s: s},
-               Shard:     &ShardHandler{s: s},
+               Shard:     &ShardHandler{s: s, configWaitForSync: waitForSync},
                Node:      &NodeHandler{s: s},
                Raft:      &RaftHandler{},
        }
diff --git a/server/api/shard.go b/server/api/shard.go
index 2706146..502b23a 100644
--- a/server/api/shard.go
+++ b/server/api/shard.go
@@ -23,16 +23,21 @@ import (
        "errors"
        "fmt"
        "strconv"
+       "sync"
+       "time"
 
        "github.com/gin-gonic/gin"
+       "go.uber.org/zap"
 
        "github.com/apache/kvrocks-controller/consts"
+       "github.com/apache/kvrocks-controller/logger"
        "github.com/apache/kvrocks-controller/server/helper"
        "github.com/apache/kvrocks-controller/store"
 )
 
 type ShardHandler struct {
-       s store.Store
+       s                 store.Store
+       configWaitForSync bool
 }
 
 type SlotsRequest struct {
@@ -114,12 +119,21 @@ func (handler *ShardHandler) Remove(c *gin.Context) {
        helper.ResponseNoContent(c)
 }
 
+// FailoverOpts holds optional parameters for manual failover.
+type FailoverOpts struct {
+       WaitForSync    bool `json:"wait_for_sync"`
+       ForceOnTimeout bool `json:"force_on_timeout"`
+       SyncTimeoutMs  int  `json:"sync_timeout_ms"`  // 0 means use default
+       PauseTimeoutMs int  `json:"pause_timeout_ms"` // 0 means use default
+}
+
 func (handler *ShardHandler) Failover(c *gin.Context) {
        ns := c.Param("namespace")
        cluster, _ := c.MustGet(consts.ContextKeyCluster).(*store.Cluster)
 
        var req struct {
-               PreferredNodeID string `json:"preferred_node_id"`
+               PreferredNodeID string        `json:"preferred_node_id"`
+               Options         *FailoverOpts `json:"options"`
        }
        if c.Request.Body != nil {
                if err := c.ShouldBindJSON(&req); err != nil {
@@ -131,16 +145,65 @@ func (handler *ShardHandler) Failover(c *gin.Context) {
                helper.ResponseBadRequest(c, fmt.Errorf("invalid node id: %s", 
req.PreferredNodeID))
                return
        }
-       // We have checked this if statement in middleware.RequiredClusterShard
-       shardIndex, _ := strconv.Atoi(c.Param("shard"))
-       newMasterNodeID, err := cluster.PromoteNewMaster(c, shardIndex, "", 
req.PreferredNodeID)
+
+       opts := store.DefaultFailoverOptions()
+       if handler.configWaitForSync {
+               opts.WaitForSync = true
+       } else if req.Options != nil {
+               opts.WaitForSync = req.Options.WaitForSync
+       }
+       if req.Options != nil {
+               if req.Options.SyncTimeoutMs > 0 {
+                       opts.SyncTimeout = 
time.Duration(req.Options.SyncTimeoutMs) * time.Millisecond
+               }
+               if req.Options.PauseTimeoutMs > 0 {
+                       opts.PauseDuration = 
time.Duration(req.Options.PauseTimeoutMs) * time.Millisecond
+               }
+               opts.ForceOnTimeout = req.Options.ForceOnTimeout
+       }
+
+       shardIndex, err := strconv.Atoi(c.Param("shard"))
+       if err != nil {
+               helper.ResponseBadRequest(c, err)
+               return
+       }
+       oldMaster, newMaster, err := cluster.PromoteNewMaster(c, shardIndex, 
"", req.PreferredNodeID, opts)
        if err != nil {
                helper.ResponseError(c, err)
                return
        }
+
+       unpauseOldMaster := func() {
+               if !opts.WaitForSync {
+                       return
+               }
+               if e := oldMaster.UnpauseClient(c); e != nil {
+                       logger.Get().With(zap.Error(e), zap.String("node", 
oldMaster.Addr())).Error("Failed to unpause old master")
+               }
+       }
+
        if err := handler.s.UpdateCluster(c, ns, cluster); err != nil {
+               unpauseOldMaster()
                helper.ResponseError(c, err)
                return
        }
-       helper.ResponseOK(c, gin.H{"new_master_id": newMasterNodeID})
+
+       var wg sync.WaitGroup
+       wg.Add(2)
+       go func() {
+               defer wg.Done()
+               if e := oldMaster.SyncClusterInfo(c, cluster); e != nil {
+                       logger.Get().With(zap.Error(e), zap.String("node", 
oldMaster.Addr())).Warn("Failed to sync cluster info to old master")
+               }
+       }()
+       go func() {
+               defer wg.Done()
+               if e := newMaster.SyncClusterInfo(c, cluster); e != nil {
+                       logger.Get().With(zap.Error(e), zap.String("node", 
newMaster.Addr())).Warn("Failed to sync cluster info to new master")
+               }
+       }()
+       wg.Wait()
+
+       unpauseOldMaster()
+       helper.ResponseOK(c, gin.H{"new_master_id": newMaster.ID()})
 }
diff --git a/server/route.go b/server/route.go
index b5eb94a..113a1a0 100644
--- a/server/route.go
+++ b/server/route.go
@@ -36,7 +36,8 @@ func (srv *Server) initHandlers() {
                c.Set(consts.ContextKeyStore, srv.store)
                c.Next()
        }, middleware.RedirectIfNotLeader)
-       handler := api.NewHandler(srv.store)
+       waitForSync := srv.config.Controller != nil && 
srv.config.Controller.FailOver != nil && 
srv.config.Controller.FailOver.WaitForSync
+       handler := api.NewHandler(srv.store, waitForSync)
 
        engine.Any("/debug/pprof/*profile", PProf)
        engine.GET("/metrics", gin.WrapH(promhttp.Handler()))
diff --git a/store/cluster.go b/store/cluster.go
index 00c0ee3..b402594 100644
--- a/store/cluster.go
+++ b/store/cluster.go
@@ -132,18 +132,17 @@ func (cluster *Cluster) RemoveNode(shardIndex int, nodeID 
string) error {
 }
 
 func (cluster *Cluster) PromoteNewMaster(ctx context.Context,
-       shardIdx int, masterNodeID, preferredNodeID string,
-) (string, error) {
+       shardIdx int, masterNodeID, preferredNodeID string, opts 
FailoverOptions) (oldMasterNode Node, newMasterNode Node, err error) {
        shard, err := cluster.GetShard(shardIdx)
        if err != nil {
-               return "", err
+               return nil, nil, err
        }
-       newMasterNodeID, err := shard.promoteNewMaster(ctx, masterNodeID, 
preferredNodeID)
+       oldMaster, newMaster, err := shard.promoteNewMaster(ctx, masterNodeID, 
preferredNodeID, opts)
        if err != nil {
-               return "", err
+               return nil, nil, err
        }
        cluster.Shards[shardIdx] = shard
-       return newMasterNodeID, nil
+       return oldMaster, newMaster, nil
 }
 
 func (cluster *Cluster) SyncToNodes(ctx context.Context) error {
diff --git a/store/cluster_mock_node.go b/store/cluster_mock_node.go
index 1f72ae0..200ae93 100644
--- a/store/cluster_mock_node.go
+++ b/store/cluster_mock_node.go
@@ -20,14 +20,20 @@
 
 package store
 
-import "context"
+import (
+       "context"
+       "time"
+)
 
 // ClusterMockNode is a mock implementation of the Node interface,
 // it is used for testing purposes.
 type ClusterMockNode struct {
        *ClusterNode
 
-       Sequence uint64
+       Sequence         uint64
+       MasterReplOffset uint64 // used when simulating master in 
GetReplicationInfo
+       SlaveOffset      uint64 // used when simulating slave offset in 
GetReplicationInfo
+       SlaveAddr        string // when master, slave Addr for matching; empty 
means use mock.Addr()
 }
 
 var _ Node = (*ClusterMockNode)(nil)
@@ -53,3 +59,32 @@ func (mock *ClusterMockNode) SyncClusterInfo(ctx 
context.Context, cluster *Clust
 func (mock *ClusterMockNode) Reset(ctx context.Context) error {
        return nil
 }
+
+func (mock *ClusterMockNode) PauseClient(ctx context.Context, timeout 
time.Duration) error {
+       return nil
+}
+
+func (mock *ClusterMockNode) UnpauseClient(ctx context.Context) error {
+       return nil
+}
+
+func (mock *ClusterMockNode) GetReplicationInfo(ctx context.Context) 
(*ReplicationInfo, error) {
+       if mock.IsMaster() {
+               addr := mock.SlaveAddr
+               if addr == "" {
+                       addr = mock.Addr()
+               }
+               return &ReplicationInfo{
+                       Role:             RoleMaster,
+                       MasterReplOffset: mock.MasterReplOffset,
+                       Slaves: []SlaveReplInfo{
+                               {Addr: addr, Offset: mock.SlaveOffset},
+                       },
+               }, nil
+       }
+       return &ReplicationInfo{
+               Role:             RoleSlave,
+               MasterReplOffset: mock.SlaveOffset,
+               SlaveReplOffset:  mock.SlaveOffset,
+       }, nil
+}
diff --git a/store/cluster_node.go b/store/cluster_node.go
index d0b2523..bce333f 100755
--- a/store/cluster_node.go
+++ b/store/cluster_node.go
@@ -88,6 +88,10 @@ type Node interface {
        CheckClusterMode(ctx context.Context) (int64, error)
        MigrateSlot(ctx context.Context, slot SlotRange, NodeID string) error
 
+       PauseClient(ctx context.Context, timeout time.Duration) error
+       UnpauseClient(ctx context.Context) error
+       GetReplicationInfo(ctx context.Context) (*ReplicationInfo, error)
+
        MarshalJSON() ([]byte, error)
        UnmarshalJSON(data []byte) error
 
@@ -114,6 +118,36 @@ type ClusterNodeInfo struct {
        Role     string `json:"role"`
 }
 
+// ReplicationInfo holds parsed output from INFO replication.
+type ReplicationInfo struct {
+       Role             string
+       MasterReplOffset uint64
+       // SlaveReplOffset is the replica's local applied offset (INFO field 
slave_repl_offset); only set when role is slave.
+       SlaveReplOffset uint64
+       // MasterLinkStatus is the replica's master_link_status (e.g. up/down); 
empty if absent.
+       MasterLinkStatus string
+       Slaves           []SlaveReplInfo
+}
+
+// ReplicaAppliedReplOffset returns the replication offset on a node that 
should be compared against
+// the old master's MasterReplOffset to decide whether the replica has caught 
up. On replicas,
+// Kvrocks/Redis expose slave_repl_offset (preferred); if it is missing, 
MasterReplOffset is used.
+func ReplicaAppliedReplOffset(info *ReplicationInfo) uint64 {
+       if info == nil {
+               return 0
+       }
+       if info.Role == RoleSlave && info.SlaveReplOffset > 0 {
+               return info.SlaveReplOffset
+       }
+       return info.MasterReplOffset
+}
+
+// SlaveReplInfo holds slave replication offset from master's perspective.
+type SlaveReplInfo struct {
+       Addr   string // "ip:port", matches node.Addr()
+       Offset uint64
+}
+
 func NewClusterNode(addr, password string) *ClusterNode {
        return &ClusterNode{
                id:        util.GenerateNodeID(),
@@ -305,6 +339,84 @@ func (n *ClusterNode) MigrateSlot(ctx context.Context, 
slot SlotRange, targetNod
        return n.GetClient().Do(ctx, "CLUSTERX", "MIGRATE", slot.String(), 
targetNodeID).Err()
 }
 
+func (n *ClusterNode) PauseClient(ctx context.Context, timeout time.Duration) 
error {
+       ms := timeout.Milliseconds()
+       if ms <= 0 {
+               ms = 1
+       }
+       return n.GetClient().Do(ctx, "CLIENT", "PAUSE", ms, "WRITE").Err()
+}
+
+func (n *ClusterNode) UnpauseClient(ctx context.Context) error {
+       return n.GetClient().Do(ctx, "CLIENT", "UNPAUSE").Err()
+}
+
+func (n *ClusterNode) GetReplicationInfo(ctx context.Context) 
(*ReplicationInfo, error) {
+       infoStr, err := n.GetClient().Info(ctx, "replication").Result()
+       if err != nil {
+               return nil, err
+       }
+
+       info := &ReplicationInfo{}
+       lines := strings.Split(infoStr, "\r\n")
+       for _, line := range lines {
+               fields := strings.SplitN(line, ":", 2)
+               if len(fields) != 2 {
+                       continue
+               }
+               key := strings.TrimSpace(fields[0])
+               val := strings.TrimSpace(fields[1])
+
+               switch key {
+               case "role":
+                       info.Role = val
+               case "master_repl_offset":
+                       info.MasterReplOffset, err = strconv.ParseUint(val, 10, 
64)
+                       if err != nil {
+                               return nil, err
+                       }
+               case "slave_repl_offset":
+                       info.SlaveReplOffset, err = strconv.ParseUint(val, 10, 
64)
+                       if err != nil {
+                               return nil, err
+                       }
+               case "master_link_status":
+                       info.MasterLinkStatus = val
+               default:
+                       if strings.HasPrefix(key, "slave") {
+                               if slave, ok := parseSlaveReplInfo(val); ok {
+                                       info.Slaves = append(info.Slaves, slave)
+                               }
+                       }
+               }
+       }
+       return info, nil
+}
+
+// parseSlaveReplInfo parses 
"ip=127.0.0.1,port=6380,state=online,offset=N,lag=M" into SlaveReplInfo.
+func parseSlaveReplInfo(val string) (SlaveReplInfo, bool) {
+       var ip, port string
+       var offset uint64
+       for _, part := range strings.Split(val, ",") {
+               kv := strings.SplitN(part, "=", 2)
+               if len(kv) != 2 {
+                       continue
+               }
+               switch strings.TrimSpace(kv[0]) {
+               case "ip":
+                       ip = strings.TrimSpace(kv[1])
+               case "port":
+                       port = strings.TrimSpace(kv[1])
+               case "offset":
+                       offset, _ = strconv.ParseUint(strings.TrimSpace(kv[1]), 
10, 64)
+               }
+       }
+       if ip == "" || port == "" {
+               return SlaveReplInfo{}, false
+       }
+       return SlaveReplInfo{Addr: ip + ":" + port, Offset: offset}, true
+}
+
 func (n *ClusterNode) MarshalJSON() ([]byte, error) {
        return json.Marshal(map[string]interface{}{
                "id":         n.id,
diff --git a/store/cluster_shard.go b/store/cluster_shard.go
index 910943b..ba1c10d 100644
--- a/store/cluster_shard.go
+++ b/store/cluster_shard.go
@@ -25,6 +25,8 @@ import (
        "errors"
        "fmt"
        "strings"
+       "sync"
+       "time"
 
        "go.uber.org/zap"
 
@@ -39,6 +41,32 @@ const (
        NotMigratingInt = -1
 )
 
+
+// FailoverOptions configures manual failover behavior.
+type FailoverOptions struct {
+       WaitForSync    bool          // whether to wait for replication gap to 
reach 0
+       SyncTimeout    time.Duration // max wait time for gap to reach 0
+       PauseDuration  time.Duration // CLIENT PAUSE timeout parameter, must be 
> SyncTimeout
+       ForceOnTimeout bool          // if true, proceed with failover on sync 
timeout
+       PollInterval   time.Duration // interval between INFO replication polls
+       PollTimeout    time.Duration // deadline for one poll cycle (two INFO 
replication RPCs: old master + target replica)
+}
+
+// DefaultFailoverOptions returns default options for manual failover.
+// WaitForSync is disabled by default to maintain compatibility with older 
kvrocks versions
+// that do not support CLIENT PAUSE/UNPAUSE. Set WaitForSync=true via the API 
options field
+// when targeting kvrocks instances that support these commands.
+func DefaultFailoverOptions() FailoverOptions {
+       return FailoverOptions{
+               WaitForSync:    false,
+               SyncTimeout:    100 * time.Millisecond,
+               PauseDuration:  500 * time.Millisecond,
+               ForceOnTimeout: false,
+               PollInterval:   10 * time.Millisecond,
+               PollTimeout:    40 * time.Millisecond,
+       }
+}
+
 type Shard struct {
        Nodes            []Node         `json:"nodes"`
        SlotRanges       []SlotRange    `json:"slot_ranges"`
@@ -220,15 +248,104 @@ func (shard *Shard) getNewMasterNodeIndex(ctx 
context.Context, masterNodeIndex i
        return newMasterNodeIndex
 }
 
-// PromoteNewMaster promotes a new master node in the shard,
-// it will return the new master node ID.
+// waitForReplicationSync polls INFO replication on the old master and on the 
target replica until
+// ReplicaAppliedReplOffset(replica) >= master.MasterReplOffset, so offsets 
come from each process
+// directly instead of the master's slave list (which can lag).
+func (shard *Shard) waitForReplicationSync(ctx context.Context, oldMaster 
Node, targetSlave Node, opts FailoverOptions) error {
+       // Bound the entire sync operation with SyncTimeout. Each poll cycle 
issues two concurrent INFO calls;
+       // PollTimeout is the budget for that cycle (both RPCs share one 
deadline).
+       syncCtx, syncCancel := context.WithTimeout(ctx, opts.SyncTimeout)
+       defer syncCancel()
+
+       ticker := time.NewTicker(opts.PollInterval)
+       defer ticker.Stop()
+
+       targetAddr := targetSlave.Addr()
+       // waitNextTick blocks until the next poll interval or the sync 
deadline is exceeded.
+       // Returns nil to signal the caller should continue, or a non-nil error 
to abort.
+       waitNextTick := func() error {
+               select {
+               case <-syncCtx.Done():
+                       if ctx.Err() != nil {
+                               return ctx.Err()
+                       }
+                       return fmt.Errorf("%w: slave %s did not catch up within 
%v", consts.ErrSyncTimeout, targetAddr, opts.SyncTimeout)
+               case <-ticker.C:
+                       return nil
+               }
+       }
+
+       for {
+               pollCtx, cancel := context.WithTimeout(syncCtx, 
opts.PollTimeout)
+               var masterInfo, slaveInfo *ReplicationInfo
+               var errM, errS error
+               var wg sync.WaitGroup
+               wg.Add(2)
+               go func() {
+                       defer wg.Done()
+                       masterInfo, errM = oldMaster.GetReplicationInfo(pollCtx)
+               }()
+               go func() {
+                       defer wg.Done()
+                       slaveInfo, errS = 
targetSlave.GetReplicationInfo(pollCtx)
+               }()
+               wg.Wait()
+               cancel()
+               if errM != nil || errS != nil {
+                       if errM != nil {
+                               logger.Get().With(
+                                       zap.Error(errM),
+                                       zap.String("master", oldMaster.Addr()),
+                               ).Warn("Failed to get replication info from old 
master, will retry")
+                       }
+                       if errS != nil {
+                               logger.Get().With(
+                                       zap.Error(errS),
+                                       zap.String("slave", targetAddr),
+                               ).Warn("Failed to get replication info from 
target replica, will retry")
+                       }
+                       if err := waitNextTick(); err != nil {
+                               return err
+                       }
+                       continue
+               }
+
+               if masterInfo.Role != RoleMaster {
+                       return fmt.Errorf("node %s is not master (role=%s)", 
oldMaster.Addr(), masterInfo.Role)
+               }
+               if slaveInfo.Role != RoleSlave {
+                       return fmt.Errorf("node %s is not slave (role=%s)", 
targetAddr, slaveInfo.Role)
+               }
+               if slaveInfo.MasterLinkStatus != "" && 
!strings.EqualFold(slaveInfo.MasterLinkStatus, "up") {
+                       return fmt.Errorf("replication link for %s is not up 
(master_link_status=%s)", targetAddr, slaveInfo.MasterLinkStatus)
+               }
+
+               masterOff := masterInfo.MasterReplOffset
+               slaveOff := ReplicaAppliedReplOffset(slaveInfo)
+               if slaveOff >= masterOff {
+                       return nil
+               }
+
+               if err := waitNextTick(); err != nil {
+                       return err
+               }
+       }
+}
+
+// promoteNewMaster promotes a new master node in the shard.
+// It returns oldMasterNode and newMasterNode for the handler to orchestrate
+// UpdateCluster, SyncClusterInfo, and UnpauseClient.
 //
 // The masterNodeID is used to check if the node is the current master node if 
it's not empty.
 // The preferredNodeID is used to specify the preferred node to be promoted as 
the new master node,
 // it will choose the node with the highest sequence number if the 
preferredNodeID is empty.
-func (shard *Shard) promoteNewMaster(ctx context.Context, masterNodeID, 
preferredNodeID string) (string, error) {
+//
+// When WaitForSync is true, it will CLIENT PAUSE the old master, wait for 
replication gap to reach 0,
+// then modify roles. The handler must call UnpauseClient on oldMaster after 
UpdateCluster and push.
+func (shard *Shard) promoteNewMaster(ctx context.Context, masterNodeID, 
preferredNodeID string, opts FailoverOptions) (
+       oldMasterNode Node, newMasterNode Node, err error) {
        if len(shard.Nodes) <= 1 {
-               return "", consts.ErrShardNoReplica
+               return nil, nil, consts.ErrShardNoReplica
        }
 
        oldMasterNodeIndex := -1
@@ -239,19 +356,45 @@ func (shard *Shard) promoteNewMaster(ctx context.Context, 
masterNodeID, preferre
                }
        }
        if oldMasterNodeIndex == -1 {
-               return "", consts.ErrOldMasterNodeNotFound
+               return nil, nil, consts.ErrOldMasterNodeNotFound
        }
        if masterNodeID != "" && shard.Nodes[oldMasterNodeIndex].ID() != 
masterNodeID {
-               return "", consts.ErrNodeIsNotMaster
+               return nil, nil, consts.ErrNodeIsNotMaster
        }
        newMasterNodeIndex := shard.getNewMasterNodeIndex(ctx, 
oldMasterNodeIndex, preferredNodeID)
        if newMasterNodeIndex == -1 {
-               return "", consts.ErrShardNoMatchNewMaster
+               return nil, nil, consts.ErrShardNoMatchNewMaster
+       }
+
+       oldMaster := shard.Nodes[oldMasterNodeIndex]
+       newMaster := shard.Nodes[newMasterNodeIndex]
+
+       if opts.WaitForSync {
+               if opts.PauseDuration <= opts.SyncTimeout {
+                       return nil, nil, fmt.Errorf("PauseDuration (%v) must be 
greater than SyncTimeout (%v)", opts.PauseDuration, opts.SyncTimeout)
+               }
+               if err = oldMaster.PauseClient(ctx, opts.PauseDuration); err != 
nil {
+                       return nil, nil, fmt.Errorf("CLIENT PAUSE failed: %w", 
err)
+               }
+               defer func() {
+                       if err != nil {
+                               _ = oldMaster.UnpauseClient(ctx)
+                       }
+               }()
+
+               syncErr := shard.waitForReplicationSync(ctx, oldMaster, 
newMaster, opts)
+               if syncErr != nil {
+                       if opts.ForceOnTimeout && errors.Is(syncErr, 
consts.ErrSyncTimeout) {
+                               
logger.Get().With(zap.Error(syncErr)).Warn("Replication sync timeout, forcing 
failover")
+                       } else {
+                               return nil, nil, syncErr
+                       }
+               }
        }
+
        shard.Nodes[oldMasterNodeIndex].SetRole(RoleSlave)
        shard.Nodes[newMasterNodeIndex].SetRole(RoleMaster)
-       preferredNewMasterNode := shard.Nodes[newMasterNodeIndex]
-       return preferredNewMasterNode.ID(), nil
+       return oldMaster, newMaster, nil
 }
 
 func (shard *Shard) HasOverlap(slotRange SlotRange) bool {
diff --git a/store/cluster_shard_test.go b/store/cluster_shard_test.go
index 971bde1..97216fc 100644
--- a/store/cluster_shard_test.go
+++ b/store/cluster_shard_test.go
@@ -21,8 +21,10 @@
 package store
 
 import (
+       "context"
        "sort"
        "testing"
+       "time"
 
        "github.com/stretchr/testify/require"
 )
@@ -97,6 +99,34 @@ func TestToSlotsString_WithFailedSlave(t *testing.T) {
        require.Contains(t, result, "slave,fail "+master.ID())
 }
 
+func TestReplicaAppliedReplOffset(t *testing.T) {
+       require.Equal(t, uint64(0), ReplicaAppliedReplOffset(nil))
+       require.Equal(t, uint64(10), 
ReplicaAppliedReplOffset(&ReplicationInfo{Role: RoleMaster, MasterReplOffset: 
10}))
+       require.Equal(t, uint64(20), 
ReplicaAppliedReplOffset(&ReplicationInfo{Role: RoleSlave, MasterReplOffset: 
10, SlaveReplOffset: 20}))
+       require.Equal(t, uint64(10), 
ReplicaAppliedReplOffset(&ReplicationInfo{Role: RoleSlave, MasterReplOffset: 
10}))
+}
+
+func TestShard_waitForReplicationSync(t *testing.T) {
+       shard := NewShard()
+       master := &ClusterMockNode{ClusterNode: 
NewClusterNode("127.0.0.1:6379", "")}
+       master.SetRole(RoleMaster)
+       master.MasterReplOffset = 1000
+
+       slave := &ClusterMockNode{ClusterNode: NewClusterNode("127.0.0.1:6380", 
"")}
+       slave.SetRole(RoleSlave)
+       slave.SlaveOffset = 500
+
+       ctx := context.Background()
+       opts := DefaultFailoverOptions()
+       opts.SyncTimeout = 30 * time.Millisecond
+       err := shard.waitForReplicationSync(ctx, master, slave, opts)
+       require.Error(t, err)
+
+       slave.SlaveOffset = 1000
+       err = shard.waitForReplicationSync(ctx, master, slave, opts)
+       require.NoError(t, err)
+}
+
 func TestToSlotsString_WithOnlineSlave(t *testing.T) {
        shard := NewShard()
        shard.SlotRanges = []SlotRange{{Start: 0, Stop: 100}}
diff --git a/store/cluster_test.go b/store/cluster_test.go
index 31ae905..d9dffdd 100644
--- a/store/cluster_test.go
+++ b/store/cluster_test.go
@@ -86,25 +86,27 @@ func TestCluster_PromoteNewMaster(t *testing.T) {
        }
 
        ctx := context.Background()
-       _, err := cluster.PromoteNewMaster(ctx, -1, node0.ID(), "")
+       opts := FailoverOptions{WaitForSync: false}
+
+       _, _, err := cluster.PromoteNewMaster(ctx, -1, node0.ID(), "", opts)
        require.ErrorIs(t, err, consts.ErrIndexOutOfRange)
-       _, err = cluster.PromoteNewMaster(ctx, 1, node0.ID(), "")
+       _, _, err = cluster.PromoteNewMaster(ctx, 1, node0.ID(), "", opts)
        require.ErrorIs(t, err, consts.ErrIndexOutOfRange)
-       _, err = cluster.PromoteNewMaster(ctx, 0, node0.ID(), "")
+       _, _, err = cluster.PromoteNewMaster(ctx, 0, node0.ID(), "", opts)
        require.ErrorIs(t, err, consts.ErrShardNoReplica)
 
        shard.Nodes = append(shard.Nodes, node1, node2, node3)
-       _, err = cluster.PromoteNewMaster(ctx, 0, node1.ID(), "")
+       _, _, err = cluster.PromoteNewMaster(ctx, 0, node1.ID(), "", opts)
        require.ErrorIs(t, err, consts.ErrNodeIsNotMaster)
 
-       newMasterID, err := cluster.PromoteNewMaster(ctx, 0, node0.ID(), "")
+       _, newMaster, err := cluster.PromoteNewMaster(ctx, 0, node0.ID(), "", 
opts)
        require.NoError(t, err)
-       require.Equal(t, node3.ID(), newMasterID)
+       require.Equal(t, node3.ID(), newMaster.ID())
 
        // test preferredNodeID
-       newMasterID, err = cluster.PromoteNewMaster(ctx, 0, node3.ID(), 
node2.ID())
+       _, newMaster, err = cluster.PromoteNewMaster(ctx, 0, node3.ID(), 
node2.ID(), opts)
        require.NoError(t, err)
-       require.Equal(t, node2.ID(), newMasterID)
+       require.Equal(t, node2.ID(), newMaster.ID())
 }
 
 func TestCluster_SetNodeStatusByID(t *testing.T) {

Reply via email to