This is an automated email from the ASF dual-hosted git repository.
xyz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pulsar-client-go.git
The following commit(s) were added to refs/heads/master by this push:
new bddaeba6 fix: SendAsync doesn't respect context and can't timeout
during reconnection (#1422)
bddaeba6 is described below
commit bddaeba6cae63eef56411d7c64f72999273605a8
Author: Zike Yang <[email protected]>
AuthorDate: Mon Sep 15 17:11:41 2025 +0800
fix: SendAsync doesn't respect context and can't timeout during
reconnection (#1422)
---
pulsar/internal/blocking_queue.go | 41 ++++++++-
pulsar/internal/blocking_queue_test.go | 98 ++++++++++++++++++++
pulsar/producer.go | 9 ++
pulsar/producer_partition.go | 106 +++++++++++++--------
pulsar/producer_test.go | 163 ++++++++++++++++++++++++++++++++-
5 files changed, 373 insertions(+), 44 deletions(-)
diff --git a/pulsar/internal/blocking_queue.go
b/pulsar/internal/blocking_queue.go
index b44ec166..a4b9cd92 100644
--- a/pulsar/internal/blocking_queue.go
+++ b/pulsar/internal/blocking_queue.go
@@ -24,8 +24,12 @@ import (
// BlockingQueue is an interface of a blocking queue
type BlockingQueue interface {
// Put enqueue one item, block if the queue is full
+ // This is currently only used for internal testing
Put(item interface{})
+ // PutUnsafe enqueues one item without acquiring the queue's lock,
blocking if the queue is full. Callers must hold the lock via Lock/Unlock
+ PutUnsafe(item interface{})
+
// Take dequeue one item, block until it's available
Take() interface{}
@@ -46,6 +50,17 @@ type BlockingQueue interface {
// ReadableSlice returns a new view of the readable items in the queue
ReadableSlice() []interface{}
+
+ // IterateUnsafe iterates the items in the queue without acquiring the
queue's lock. Callers must hold the lock via Lock/Unlock
+ IterateUnsafe(func(item interface{}))
+
+ // Lock locks the queue for manual control
+ // Users must call Unlock() after finishing their operations
+ Lock()
+
+ // Unlock unlocks the queue
+ // Must be called after Lock() to release the lock
+ Unlock()
}
type blockingQueue struct {
@@ -60,6 +75,17 @@ type blockingQueue struct {
isNotFull *sync.Cond
}
+func (bq *blockingQueue) IterateUnsafe(f func(item interface{})) {
+ readIdx := bq.headIdx
+ for i := 0; i < bq.size; i++ {
+ f(bq.items[readIdx])
+ readIdx++
+ if readIdx == bq.maxSize {
+ readIdx = 0
+ }
+ }
+}
+
// NewBlockingQueue init block queue and returns a BlockingQueue
func NewBlockingQueue(maxSize int) BlockingQueue {
bq := &blockingQueue{
@@ -76,9 +102,12 @@ func NewBlockingQueue(maxSize int) BlockingQueue {
}
func (bq *blockingQueue) Put(item interface{}) {
- bq.mutex.Lock()
- defer bq.mutex.Unlock()
+ bq.Lock()
+ defer bq.Unlock()
+ bq.PutUnsafe(item)
+}
+func (bq *blockingQueue) PutUnsafe(item interface{}) {
for bq.size == bq.maxSize {
bq.isNotFull.Wait()
}
@@ -192,3 +221,11 @@ func (bq *blockingQueue) ReadableSlice() []interface{} {
return res
}
+
+func (bq *blockingQueue) Lock() {
+ bq.mutex.Lock()
+}
+
+func (bq *blockingQueue) Unlock() {
+ bq.mutex.Unlock()
+}
diff --git a/pulsar/internal/blocking_queue_test.go
b/pulsar/internal/blocking_queue_test.go
index c93b1a6e..8acd282e 100644
--- a/pulsar/internal/blocking_queue_test.go
+++ b/pulsar/internal/blocking_queue_test.go
@@ -148,3 +148,101 @@ func TestBlockingQueue_ReadableSlice(t *testing.T) {
assert.Equal(t, items[1], 3)
assert.Equal(t, items[2], 4)
}
+
+func TestBlockingQueueIterate(t *testing.T) {
+ bq := NewBlockingQueue(5)
+
+ // Add some items
+ bq.PutUnsafe("item1")
+ bq.PutUnsafe("item2")
+ bq.PutUnsafe("item3")
+
+ // Test iteration
+ items := make([]interface{}, 0)
+ bq.IterateUnsafe(func(item interface{}) {
+ items = append(items, item)
+ })
+
+ assert.Equal(t, 3, len(items))
+ assert.Equal(t, "item1", items[0])
+ assert.Equal(t, "item2", items[1])
+ assert.Equal(t, "item3", items[2])
+}
+
+func TestBlockingQueueIteratePartial(t *testing.T) {
+ bq := NewBlockingQueue(5)
+
+ // Add some items
+ bq.PutUnsafe("item1")
+ bq.PutUnsafe("item2")
+ bq.PutUnsafe("item3")
+
+ // Test partial iteration (first 2 items only)
+ items := make([]interface{}, 0)
+ bq.IterateUnsafe(func(item interface{}) {
+ if len(items) < 2 {
+ items = append(items, item)
+ }
+ })
+
+ assert.Equal(t, 2, len(items))
+ assert.Equal(t, "item1", items[0])
+ assert.Equal(t, "item2", items[1])
+}
+
+func TestBlockingQueueIterateCircularBuffer(t *testing.T) {
+ bq := NewBlockingQueue(3)
+
+ // Fill the queue to test circular buffer behavior
+ bq.PutUnsafe("item1")
+ bq.PutUnsafe("item2")
+ bq.PutUnsafe("item3")
+
+ // Remove one item to create space
+ bq.Poll()
+
+ // Add another item to test wrapping
+ bq.PutUnsafe("item4")
+
+ // Test iteration with circular buffer
+ items := make([]interface{}, 0)
+ bq.IterateUnsafe(func(item interface{}) {
+ items = append(items, item)
+ })
+
+ assert.Equal(t, 3, len(items))
+ assert.Equal(t, "item2", items[0])
+ assert.Equal(t, "item3", items[1])
+ assert.Equal(t, "item4", items[2])
+}
+
+func TestBlockingQueueIterateEmpty(t *testing.T) {
+ bq := NewBlockingQueue(5)
+
+ // Test iteration on empty queue
+ items := make([]interface{}, 0)
+ bq.IterateUnsafe(func(item interface{}) {
+ items = append(items, item)
+ })
+
+ assert.Equal(t, 0, len(items))
+}
+
+func TestBlockingQueueManualLock(t *testing.T) {
+ bq := NewBlockingQueue(5)
+
+ // Test manual locking for batch PutUnsafe operations
+ bq.Lock()
+
+ bq.PutUnsafe("item1")
+ bq.PutUnsafe("item2")
+ bq.PutUnsafe("item3")
+
+ // Unlock
+ bq.Unlock()
+
+ // Verify all items were added
+ assert.Equal(t, 3, bq.Size())
+ assert.Equal(t, "item1", bq.Peek())
+ assert.Equal(t, "item3", bq.PeekLast())
+}
diff --git a/pulsar/producer.go b/pulsar/producer.go
index 76ab853a..633e761a 100644
--- a/pulsar/producer.go
+++ b/pulsar/producer.go
@@ -235,6 +235,15 @@ type Producer interface {
// This call is blocked when the `maxPendingMessages` becomes full
(default: 1000)
// The callback will report back the message being published and
// the eventual error in publishing
+ // The context passed in the call is only used for the duration of the
SendAsync call itself
+ // (i.e., to control blocking when the queue is full), and not for the
entire message lifetime.
+ // Once SendAsync returns, the message lifetime is controlled by the
SendTimeout configuration.
+ // Example:
+ // producer.SendAsync(ctx, &pulsar.ProducerMessage{
+ // Payload: myPayload,
+ // }, func(msgID pulsar.MessageID, message *pulsar.ProducerMessage, err
error) {
+ // // handle publish result
+ // })
SendAsync(context.Context, *ProducerMessage, func(MessageID,
*ProducerMessage, error))
// LastSequenceID get the last sequence id that was published by this
producer.
diff --git a/pulsar/producer_partition.go b/pulsar/producer_partition.go
index 9841222f..a6b83a8f 100755
--- a/pulsar/producer_partition.go
+++ b/pulsar/producer_partition.go
@@ -47,6 +47,7 @@ type producerState int32
const (
// producer states
producerInit = iota
+ producerConnecting
producerReady
producerClosing
producerClosed
@@ -218,6 +219,7 @@ func newPartitionProducer(client *client, topic string,
options *ProducerOptions
} else {
p.userProvidedProducerName = false
}
+ p.setProducerState(producerConnecting)
err := p.grabCnx("")
if err != nil {
p.batchFlushTicker.Stop()
@@ -265,6 +267,10 @@ func (p *partitionProducer) lookupTopic(brokerServiceURL
string) (*internal.Look
}
func (p *partitionProducer) grabCnx(assignedBrokerURL string) error {
+ if !p.casProducerState(producerReady, producerConnecting) &&
p.isClosingOrClosed() {
+ // closing or closed
+ return ErrProducerClosed
+ }
lr, err := p.lookupTopic(assignedBrokerURL)
if err != nil {
return err
@@ -385,31 +391,20 @@ func (p *partitionProducer) grabCnx(assignedBrokerURL
string) error {
"epoch": atomic.LoadUint64(&p.epoch),
}).Info("Connected producer")
- pendingItems := p.pendingQueue.ReadableSlice()
- viewSize := len(pendingItems)
- if viewSize > 0 {
- p.log.Infof("Resending %d pending batches", viewSize)
- lastViewItem := pendingItems[viewSize-1].(*pendingItem)
-
- // iterate at most pending items
- for i := 0; i < viewSize; i++ {
- item := p.pendingQueue.Poll()
- if item == nil {
- continue
- }
- pi := item.(*pendingItem)
- // when resending pending batches, we update the sendAt
timestamp to record the metric.
- pi.Lock()
- pi.sentAt = time.Now()
- pi.Unlock()
- pi.buffer.Retain() // Retain for writing to the
connection
- p.pendingQueue.Put(pi)
- p._getConn().WriteData(pi.ctx, pi.buffer)
+ p.pendingQueue.Lock()
+ defer p.pendingQueue.Unlock()
+ p.pendingQueue.IterateUnsafe(func(item any) {
+ pi := item.(*pendingItem)
+ // when resending pending batches, we update the sendAt
timestamp to record the metric.
+ pi.Lock()
+ pi.sentAt = time.Now()
+ pi.Unlock()
+ pi.buffer.Retain() // Retain for writing to the connection
+ p._getConn().WriteData(pi.ctx, pi.buffer)
+ })
- if pi == lastViewItem {
- break
- }
- }
+ if !p.casProducerState(producerConnecting, producerReady) &&
p.isClosingOrClosed() {
+ return ErrProducerClosed
}
return nil
}
@@ -495,9 +490,9 @@ func (p *partitionProducer)
reconnectToBroker(connectionClosed *connectionClosed
return struct{}{}, nil
}
- if p.getProducerState() != producerReady {
+ if p.isClosingOrClosed() {
// Producer is already closing
- p.log.Info("producer state not ready, exit reconnect")
+ p.log.Info("producer state is in closing or closed,
exit reconnect")
return struct{}{}, nil
}
@@ -561,6 +556,18 @@ func (p *partitionProducer)
reconnectToBroker(connectionClosed *connectionClosed
}
func (p *partitionProducer) runEventsLoop() {
+ go func() {
+ for {
+ select {
+ case connectionClosed := <-p.connectClosedCh:
+ p.log.Info("runEventsLoop will reconnect in
producer")
+ p.reconnectToBroker(connectionClosed)
+ case <-p.ctx.Done():
+ p.log.Info("Producer is shutting down. Close
the reconnect event loop")
+ return
+ }
+ }
+ }()
for {
select {
case data, ok := <-p.dataChan:
@@ -581,9 +588,6 @@ func (p *partitionProducer) runEventsLoop() {
p.internalClose(v)
return
}
- case connectionClosed := <-p.connectClosedCh:
- p.log.Info("runEventsLoop will reconnect in producer")
- p.reconnectToBroker(connectionClosed)
case <-p.batchFlushTicker.C:
p.internalFlushCurrentBatch()
}
@@ -902,7 +906,17 @@ func (p *partitionProducer) writeData(buffer
internal.Buffer, sequenceID uint64,
now := time.Now()
ctx, cancel := context.WithCancel(context.Background())
buffer.Retain()
- p.pendingQueue.Put(&pendingItem{
+ p.pendingQueue.Lock()
+ defer p.pendingQueue.Unlock()
+ conn := p._getConn()
+ if p.getProducerState() == producerReady {
+ // The producer is ready, so write the buffer to the connection
directly.
+ // (If the producer were reconnecting instead, we would skip this
write: the buffer only needs to be pushed to the pending queue and will
be resent during reconnection.)
+ conn.WriteData(ctx, buffer)
+ } else {
+ p.log.Debug("Skipping write to connection, producer
state: ", p.getProducerState())
+ }
+ p.pendingQueue.PutUnsafe(&pendingItem{
ctx: ctx,
cancel: cancel,
createdAt: now,
@@ -911,7 +925,6 @@ func (p *partitionProducer) writeData(buffer
internal.Buffer, sequenceID uint64,
sequenceID: sequenceID,
sendRequests: callbacks,
})
- p._getConn().WriteData(ctx, buffer)
}
}
@@ -924,8 +937,7 @@ func (p *partitionProducer) failTimeoutMessages() {
defer t.Stop()
for range t.C {
- state := p.getProducerState()
- if state == producerClosing || state == producerClosed {
+ if p.isClosingOrClosed() {
return
}
@@ -1323,11 +1335,16 @@ func (p *partitionProducer) internalSendAsync(
return
}
- if p.getProducerState() != producerReady {
+ if p.isClosingOrClosed() {
sr.done(nil, ErrProducerClosed)
return
}
+ if err := ctx.Err(); err != nil {
+ sr.done(nil, ctx.Err())
+ return
+ }
+
p.options.Interceptors.BeforeSend(p, msg)
if err := p.updateSchema(sr); err != nil {
@@ -1356,7 +1373,13 @@ func (p *partitionProducer) internalSendAsync(
return
}
- p.dataChan <- sr
+ select {
+ case <-ctx.Done():
+ sr.done(nil, ctx.Err())
+ return
+ case p.dataChan <- sr:
+ return
+ }
}
func (p *partitionProducer) ReceivedSendReceipt(response
*pb.CommandSendReceipt) {
@@ -1441,7 +1464,7 @@ func (p *partitionProducer) internalClose(req
*closeProducer) {
}
func (p *partitionProducer) doClose(reason error) {
- if !p.casProducerState(producerReady, producerClosing) {
+ if !p.casProducerState(producerReady, producerClosing) &&
!p.casProducerState(producerConnecting, producerClosing) {
return
}
@@ -1522,7 +1545,7 @@ func (p *partitionProducer) Flush() error {
}
func (p *partitionProducer) FlushWithCtx(ctx context.Context) error {
- if p.getProducerState() != producerReady {
+ if p.isClosingOrClosed() {
return ErrProducerClosed
}
@@ -1549,6 +1572,11 @@ func (p *partitionProducer) getProducerState()
producerState {
return producerState(p.state.Load())
}
+func (p *partitionProducer) isClosingOrClosed() bool {
+ state := p.getProducerState()
+ return state == producerClosing || state == producerClosed
+}
+
func (p *partitionProducer) setProducerState(state producerState) {
p.state.Swap(int32(state))
}
@@ -1560,8 +1588,8 @@ func (p *partitionProducer) casProducerState(oldState,
newState producerState) b
}
func (p *partitionProducer) Close() {
- if p.getProducerState() != producerReady {
- // Producer is closing
+ if p.isClosingOrClosed() {
+ // Producer is closing or closed
return
}
diff --git a/pulsar/producer_test.go b/pulsar/producer_test.go
index 02b4b677..00876070 100644
--- a/pulsar/producer_test.go
+++ b/pulsar/producer_test.go
@@ -58,7 +58,8 @@ func TestInvalidURL(t *testing.T) {
func TestProducerConnectError(t *testing.T) {
client, err := NewClient(ClientOptions{
- URL: "pulsar://invalid-hostname:6650",
+ URL: "pulsar://invalid-hostname:6650",
+ OperationTimeout: 3 * time.Second,
})
assert.Nil(t, err)
@@ -348,7 +349,7 @@ func TestFlushInProducer(t *testing.T) {
assert.NoError(t, err)
defer client.Close()
- topicName := "test-flush-in-producer"
+ topicName := newTopicName()
subName := "subscription-name"
numOfMessages := 10
ctx := context.Background()
@@ -2334,7 +2335,7 @@ func TestMemLimitContextCancel(t *testing.T) {
Payload: make([]byte, 1024),
}, func(_ MessageID, _ *ProducerMessage, e error) {
assert.Error(t, e)
- assert.ErrorContains(t, e, getResultStr(TimeoutError))
+ assert.ErrorContains(t, e, "context canceled")
wg.Done()
})
}()
@@ -2353,6 +2354,32 @@ func TestMemLimitContextCancel(t *testing.T) {
assert.NoError(t, err)
}
+func TestSendAsyncWithContextCancel(t *testing.T) {
+
+ c, err := NewClient(ClientOptions{
+ URL: serviceURL,
+ MemoryLimitBytes: 100 * 1024,
+ })
+ assert.NoError(t, err)
+ defer c.Close()
+
+ topicName := newTopicName()
+ producer, _ := c.CreateProducer(ProducerOptions{
+ Topic: topicName,
+ })
+
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ var callbackErr error
+ producer.SendAsync(ctx, &ProducerMessage{
+ Payload: make([]byte, 1024),
+ }, func(_ MessageID, _ *ProducerMessage, e error) {
+ callbackErr = e
+ })
+
+ require.Error(t, callbackErr)
+}
+
func TestBatchSendMessagesWithMetadata(t *testing.T) {
testSendMessagesWithMetadata(t, false)
}
@@ -2500,6 +2527,15 @@ func (pqw *pendingQueueWrapper) Put(item interface{}) {
pqw.pendingQueue.Put(item)
}
+func (pqw *pendingQueueWrapper) PutUnsafe(item interface{}) {
+ pi := item.(*pendingItem)
+ writerIdx := pi.buffer.WriterIndex()
+ buf := internal.NewBuffer(int(writerIdx))
+ buf.Write(pi.buffer.Get(0, writerIdx))
+ *pqw.writtenBuffers = append(*pqw.writtenBuffers, buf)
+ pqw.pendingQueue.PutUnsafe(item)
+}
+
func (pqw *pendingQueueWrapper) Take() interface{} {
return pqw.pendingQueue.Take()
}
@@ -2528,6 +2564,18 @@ func (pqw *pendingQueueWrapper) ReadableSlice()
[]interface{} {
return pqw.pendingQueue.ReadableSlice()
}
+func (pqw *pendingQueueWrapper) IterateUnsafe(f func(item interface{})) {
+ pqw.pendingQueue.IterateUnsafe(f)
+}
+
+func (pqw *pendingQueueWrapper) Lock() {
+ pqw.pendingQueue.Lock()
+}
+
+func (pqw *pendingQueueWrapper) Unlock() {
+ pqw.pendingQueue.Unlock()
+}
+
func TestDisableReplication(t *testing.T) {
client, err := NewClient(ClientOptions{
URL: serviceURL,
@@ -2741,3 +2789,112 @@ func TestSendBufferRetainWhenConnectionStuck(t
*testing.T) {
b := conn.buffers[0]
assert.Equal(t, int64(1), b.RefCnt(), "Expected buffer to have a
reference count of 1 after sending")
}
+
+func TestSendAsyncCouldTimeoutWhileReconnecting(t *testing.T) {
+ testSendAsyncCouldTimeoutWhileReconnecting(t, false)
+ testSendAsyncCouldTimeoutWhileReconnecting(t, true)
+}
+
+func testSendAsyncCouldTimeoutWhileReconnecting(t *testing.T,
isDisableBatching bool) {
+ t.Helper()
+
+ req := testcontainers.ContainerRequest{
+ Image: getPulsarTestImage(),
+ ExposedPorts: []string{"6650/tcp", "8080/tcp"},
+ WaitingFor: wait.ForExposedPort(),
+ Cmd: []string{"bin/pulsar", "standalone", "-nfw"},
+ }
+ c, err := testcontainers.GenericContainer(context.Background(),
testcontainers.GenericContainerRequest{
+ ContainerRequest: req,
+ Started: true,
+ })
+ require.NoError(t, err, "Failed to start the pulsar container")
+ defer func() {
+ err := c.Terminate(context.Background())
+ if err != nil {
+ t.Fatal("Failed to terminate the pulsar container", err)
+ }
+ }()
+
+ endpoint, err := c.PortEndpoint(context.Background(), "6650", "pulsar")
+ require.NoError(t, err, "Failed to get the pulsar endpoint")
+
+ client, err := NewClient(ClientOptions{
+ URL: endpoint,
+ ConnectionTimeout: 5 * time.Second,
+ OperationTimeout: 5 * time.Second,
+ })
+ require.NoError(t, err)
+ defer client.Close()
+
+ var testProducer Producer
+ require.Eventually(t, func() bool {
+ testProducer, err = client.CreateProducer(ProducerOptions{
+ Topic: newTopicName(),
+ Schema: NewBytesSchema(nil),
+ SendTimeout: 3 * time.Second,
+ DisableBatching: isDisableBatching,
+ BatchingMaxMessages: 5,
+ MaxPendingMessages: 10,
+ })
+ return err == nil
+ }, 30*time.Second, 1*time.Second)
+
+ numMessages := 10
+ // Send 10 messages synchronously
+ for i := 0; i < numMessages; i++ {
+ send, err := testProducer.Send(context.Background(),
&ProducerMessage{Payload: []byte("test")})
+ require.NoError(t, err)
+ require.NotNil(t, send)
+ }
+
+ // stop pulsar server
+ timeout := 10 * time.Second
+ err = c.Stop(context.Background(), &timeout)
+ require.NoError(t, err)
+
+ // Test the SendAsync could be timeout if the producer is reconnecting
+
+ finalErr := make(chan error, 1)
+ testProducer.SendAsync(context.Background(), &ProducerMessage{
+ Payload: []byte("test"),
+ }, func(_ MessageID, _ *ProducerMessage, err error) {
+ finalErr <- err
+ })
+ select {
+ case <-time.After(10 * time.Second):
+ t.Fatal("test timeout")
+ case err = <-finalErr:
+ // should get a timeout error
+ require.ErrorIs(t, err, ErrSendTimeout)
+ }
+ close(finalErr)
+
+ // Test that the SendAsync could be timeout if the pending queue is full
+
+ go func() {
+ // Send 10 messages asynchronously to make the pending queue
full
+ for i := 0; i < numMessages; i++ {
+ testProducer.SendAsync(context.Background(),
&ProducerMessage{
+ Payload: []byte("test"),
+ }, func(_ MessageID, _ *ProducerMessage, _ error) {
+ })
+ }
+ }()
+
+ time.Sleep(3 * time.Second)
+ finalErr = make(chan error, 1)
+ testProducer.SendAsync(context.Background(), &ProducerMessage{
+ Payload: []byte("test"),
+ }, func(_ MessageID, _ *ProducerMessage, err error) {
+ finalErr <- err
+ })
+ select {
+ case <-time.After(10 * time.Second):
+ t.Fatal("test timeout")
+ case err = <-finalErr:
+ // should get a timeout error
+ require.ErrorIs(t, err, ErrSendTimeout)
+ }
+ close(finalErr)
+}