This is an automated email from the ASF dual-hosted git repository.

lizhanhui pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/rocketmq-clients.git


The following commit(s) were added to refs/heads/master by this push:
     new 2eedee02 [rust] support handling settings command in Producer (#682)
2eedee02 is described below

commit 2eedee0298de62231ffd20e2181734169121734f
Author: Qiping Luo <qiping...@tencent.com>
AuthorDate: Wed Feb 21 22:48:50 2024 +0800

    [rust] support handling settings command in Producer (#682)
    
    * [rust]support handling settings command in Producer
    
    * [rust]propagate error to upper layer when processing transaction commands
    
    * chore: simplify error propagation
    
    Signed-off-by: Li Zhanhui <lizhan...@gmail.com>
    
    ---------
    
    Signed-off-by: Li Zhanhui <lizhan...@gmail.com>
    Co-authored-by: Li Zhanhui <lizhan...@gmail.com>
---
 rust/Cargo.toml             |   2 +-
 rust/src/client.rs          | 153 ++++---------------------
 rust/src/producer.rs        | 273 ++++++++++++++++++++++++++++++++++++++------
 rust/src/simple_consumer.rs |  31 ++++-
 4 files changed, 289 insertions(+), 170 deletions(-)

diff --git a/rust/Cargo.toml b/rust/Cargo.toml
index da47ed29..bc888279 100644
--- a/rust/Cargo.toml
+++ b/rust/Cargo.toml
@@ -58,7 +58,7 @@ minitrace = "0.4"
 byteorder = "1"
 mac_address = "1.1.4"
 hex = "0.4.3"
-time = "0.3"
+time = { version = "0.3", features = ["local-offset"] }
 once_cell = "1.18.0"
 
 mockall = "0.11.4"
diff --git a/rust/src/client.rs b/rust/src/client.rs
index 91bd3692..884f98b0 100644
--- a/rust/src/client.rs
+++ b/rust/src/client.rs
@@ -31,16 +31,14 @@ use tokio::sync::{mpsc, oneshot};
 use crate::conf::ClientOption;
 use crate::error::{ClientError, ErrorKind};
 use crate::model::common::{ClientType, Endpoints, Route, RouteStatus, SendReceipt};
-use crate::model::message::{AckMessageEntry, MessageView};
-use crate::model::transaction::{TransactionChecker, TransactionResolution};
+use crate::model::message::AckMessageEntry;
 use crate::pb;
 use crate::pb::receive_message_response::Content;
-use crate::pb::telemetry_command::Command::{RecoverOrphanedTransactionCommand, Settings};
 use crate::pb::{
     AckMessageRequest, AckMessageResultEntry, ChangeInvisibleDurationRequest, Code,
-    EndTransactionRequest, FilterExpression, HeartbeatRequest, HeartbeatResponse, Message,
-    MessageQueue, NotifyClientTerminationRequest, QueryRouteRequest, ReceiveMessageRequest,
-    Resource, SendMessageRequest, Status, TelemetryCommand, TransactionSource,
+    FilterExpression, HeartbeatRequest, HeartbeatResponse, Message, MessageQueue,
+    NotifyClientTerminationRequest, QueryRouteRequest, ReceiveMessageRequest, Resource,
+    SendMessageRequest, Status, TelemetryCommand,
 };
 #[double]
 use crate::session::SessionManager;
@@ -54,7 +52,6 @@ pub(crate) struct Client {
     id: String,
     access_endpoints: Endpoints,
     settings: TelemetryCommand,
-    transaction_checker: Option<Box<TransactionChecker>>,
     telemetry_command_tx: Option<mpsc::Sender<pb::telemetry_command::Command>>,
     shutdown_tx: Option<oneshot::Sender<()>>,
 }
@@ -70,8 +67,6 @@ const OPERATION_HEARTBEAT: &str = "client.heartbeat";
 const OPERATION_SEND_MESSAGE: &str = "client.send_message";
 const OPERATION_RECEIVE_MESSAGE: &str = "client.receive_message";
 const OPERATION_ACK_MESSAGE: &str = "client.ack_message";
-const OPERATION_END_TRANSACTION: &str = "client.end_transaction";
-const OPERATION_HANDLE_TELEMETRY_COMMAND: &str = "client.handle_telemetry_command";
 
 impl Debug for Client {
     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
@@ -102,28 +97,23 @@ impl Client {
             id,
             access_endpoints: endpoints,
             settings,
-            transaction_checker: None,
             telemetry_command_tx: None,
             shutdown_tx: None,
         })
     }
 
-    pub(crate) fn is_started(&self) -> bool {
-        self.shutdown_tx.is_some()
+    pub(crate) fn get_endpoints(&self) -> Endpoints {
+        self.access_endpoints.clone()
     }
 
-    pub(crate) fn has_transaction_checker(&self) -> bool {
-        self.transaction_checker.is_some()
-    }
-
-    pub(crate) fn set_transaction_checker(&mut self, transaction_checker: Box<TransactionChecker>) {
-        if self.is_started() {
-            panic!("client {} is started, can not be modified", self.id)
-        }
-        self.transaction_checker = Some(transaction_checker);
+    pub(crate) fn is_started(&self) -> bool {
+        self.shutdown_tx.is_some()
     }
 
-    pub(crate) async fn start(&mut self) -> Result<(), ClientError> {
+    pub(crate) async fn start(
+        &mut self,
+        telemetry_command_tx: mpsc::Sender<pb::telemetry_command::Command>,
+    ) -> Result<(), ClientError> {
         let logger = self.logger.clone();
         let session_manager = self.session_manager.clone();
 
@@ -134,19 +124,12 @@ impl Client {
         let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
         self.shutdown_tx = Some(shutdown_tx);
 
-        // send heartbeat and handle telemetry command
-        let (telemetry_command_tx, mut telemetry_command_rx) = mpsc::channel(16);
         self.telemetry_command_tx = Some(telemetry_command_tx);
+
         let rpc_client = self
             .get_session()
             .await
             .map_err(|error| error.with_operation(OPERATION_CLIENT_START))?;
-        let endpoints = self.access_endpoints.clone();
-        let transaction_checker = self.transaction_checker.take();
-        // give a placeholder
-        if transaction_checker.is_some() {
-            self.transaction_checker = Some(Box::new(|_, _| TransactionResolution::UNKNOWN));
-        }
 
         tokio::spawn(async move {
             rpc_client.is_started();
@@ -188,24 +171,13 @@ impl Client {
                             debug!(logger,"send heartbeat to server success, peer={}",peer);
                         }
                     },
-                    command = telemetry_command_rx.recv() => {
-                        if let Some(command) = command {
-                            let result = Self::handle_telemetry_command(rpc_client.shadow_session(), &transaction_checker, endpoints.clone(), command).await;
-                            if let Err(error) = result {
-                                error!(logger, "handle telemetry command failed: {:?}", error);
-                            }
-                        }
-                    },
                     _ = &mut shutdown_rx => {
-                        info!(logger, "receive shutdown signal, stop heartbeat task and telemetry command handler");
+                        info!(logger, "receive shutdown signal, stop heartbeat task.");
                         break;
                     }
                 }
             }
-            info!(
-                logger,
-                "heartbeat task and telemetry command handler are stopped"
-            );
+            info!(logger, "heartbeat task is stopped");
         });
         Ok(())
     }
@@ -239,58 +211,6 @@ impl Client {
         Ok(())
     }
 
-    async fn handle_telemetry_command<T: RPCClient + 'static>(
-        mut rpc_client: T,
-        transaction_checker: &Option<Box<TransactionChecker>>,
-        endpoints: Endpoints,
-        command: pb::telemetry_command::Command,
-    ) -> Result<(), ClientError> {
-        return match command {
-            RecoverOrphanedTransactionCommand(command) => {
-                let transaction_id = command.transaction_id;
-                let message = command.message.unwrap();
-                let message_id = message
-                    .system_properties
-                    .as_ref()
-                    .unwrap()
-                    .message_id
-                    .clone();
-                let topic = message.topic.as_ref().unwrap().clone();
-                if let Some(transaction_checker) = transaction_checker {
-                    let resolution = transaction_checker(
-                        transaction_id.clone(),
-                        MessageView::from_pb_message(message, endpoints),
-                    );
-
-                    let response = rpc_client
-                        .end_transaction(EndTransactionRequest {
-                            topic: Some(topic),
-                            message_id: message_id.to_string(),
-                            transaction_id,
-                            resolution: resolution as i32,
-                            source: TransactionSource::SourceServerCheck as i32,
-                            trace_context: "".to_string(),
-                        })
-                        .await?;
-                    Self::handle_response_status(response.status, OPERATION_END_TRANSACTION)
-                } else {
-                    Err(ClientError::new(
-                        ErrorKind::Config,
-                        "failed to get transaction checker",
-                        OPERATION_END_TRANSACTION,
-                    ))
-                }
-            }
-            Settings(_) => Ok(()),
-            _ => Err(ClientError::new(
-                ErrorKind::Config,
-                "receive telemetry command but there is no handler",
-                OPERATION_HANDLE_TELEMETRY_COMMAND,
-            )
-            .with_context("command", format!("{:?}", command))),
-        };
-    }
-
     pub(crate) fn client_id(&self) -> &str {
         &self.id
     }
@@ -704,13 +624,11 @@ pub(crate) mod tests {
     use crate::error::{ClientError, ErrorKind};
     use crate::log::terminal_logger;
     use crate::model::common::{ClientType, Route};
-    use crate::model::transaction::TransactionResolution;
     use crate::pb::receive_message_response::Content;
     use crate::pb::{
         AckMessageEntry, AckMessageResponse, ChangeInvisibleDurationResponse, Code,
-        EndTransactionResponse, FilterExpression, HeartbeatResponse, Message, MessageQueue,
-        QueryRouteResponse, ReceiveMessageResponse, Resource, SendMessageResponse, Status,
-        SystemProperties, TelemetryCommand,
+        FilterExpression, HeartbeatResponse, Message, MessageQueue, QueryRouteResponse,
+        ReceiveMessageResponse, Resource, SendMessageResponse, Status, TelemetryCommand,
     };
     use crate::session;
 
@@ -731,7 +649,6 @@ pub(crate) mod tests {
             id: Client::generate_client_id(),
             access_endpoints: Endpoints::from_url("http://localhost:8081").unwrap(),
             settings: TelemetryCommand::default(),
-            transaction_checker: None,
             telemetry_command_tx: None,
             shutdown_tx: None,
         }
@@ -747,7 +664,6 @@ pub(crate) mod tests {
             id: Client::generate_client_id(),
             access_endpoints: Endpoints::from_url("http://localhost:8081").unwrap(),
             settings: TelemetryCommand::default(),
-            transaction_checker: None,
             telemetry_command_tx: Some(tx),
             shutdown_tx: None,
         }
@@ -784,7 +700,8 @@ pub(crate) mod tests {
             .returning(|_, _, _| Ok(Session::mock()));
 
         let mut client = new_client_with_session_manager(session_manager);
-        client.start().await?;
+        let (tx, _) = mpsc::channel(16);
+        client.start(tx).await?;
 
         // TODO use countdown latch instead sleeping
         // wait for run
@@ -800,7 +717,8 @@ pub(crate) mod tests {
             .returning(|_, _, _| Ok(Session::mock()));
 
         let mut client = new_client_with_session_manager(session_manager);
-        let _ = client.start().await;
+        let (tx, _rx) = mpsc::channel(16);
+        let _ = client.start(tx).await;
         let result = client.get_session().await;
         assert!(result.is_ok());
         let result = client
@@ -1134,33 +1052,4 @@ pub(crate) mod tests {
         assert_eq!(error.message, "server return an error");
         assert_eq!(error.operation, "client.ack_message");
     }
-
-    #[tokio::test]
-    async fn client_handle_telemetry_command() {
-        let response = Ok(EndTransactionResponse {
-            status: Some(Status {
-                code: Code::Ok as i32,
-                message: "".to_string(),
-            }),
-        });
-        let mut mock = session::MockRPCClient::new();
-        mock.expect_end_transaction()
-            .return_once(|_| Box::pin(futures::future::ready(response)));
-        let result = Client::handle_telemetry_command(
-            mock,
-            &Some(Box::new(|_, _| TransactionResolution::COMMIT)),
-            Endpoints::from_url("localhost:8081").unwrap(),
-            RecoverOrphanedTransactionCommand(pb::RecoverOrphanedTransactionCommand {
-                message: Some(Message {
-                    topic: Some(Resource::default()),
-                    user_properties: Default::default(),
-                    system_properties: Some(SystemProperties::default()),
-                    body: vec![],
-                }),
-                transaction_id: "".to_string(),
-            }),
-        )
-        .await;
-        assert!(result.is_ok())
-    }
 }
diff --git a/rust/src/producer.rs b/rust/src/producer.rs
index 2a69f079..e456cbe7 100644
--- a/rust/src/producer.rs
+++ b/rust/src/producer.rs
@@ -15,20 +15,30 @@
  * limitations under the License.
  */
 
+use std::fmt::Debug;
+use std::sync::Arc;
 use std::time::{SystemTime, UNIX_EPOCH};
 
 use mockall_double::double;
 use prost_types::Timestamp;
-use slog::{info, Logger};
+use slog::{error, info, warn, Logger};
+use tokio::select;
+use tokio::sync::RwLock;
+use tokio::sync::{mpsc, oneshot};
 
 #[double]
 use crate::client::Client;
 use crate::conf::{ClientOption, ProducerOption};
 use crate::error::{ClientError, ErrorKind};
-use crate::model::common::{ClientType, SendReceipt};
-use crate::model::message::{self, MessageTypeAware};
-use crate::model::transaction::{Transaction, TransactionChecker, TransactionImpl};
-use crate::pb::{Encoding, Resource, SystemProperties};
+use crate::model::common::{ClientType, Endpoints, SendReceipt};
+use crate::model::message::{self, MessageTypeAware, MessageView};
+use crate::model::transaction::{
+    Transaction, TransactionChecker, TransactionImpl, TransactionResolution,
+};
+use crate::pb::settings::PubSub;
+use crate::pb::telemetry_command::Command::{RecoverOrphanedTransactionCommand, Settings};
+use crate::pb::{Encoding, EndTransactionRequest, Resource, SystemProperties, TransactionSource};
+use crate::session::RPCClient;
 use crate::util::{
     build_endpoints_by_message_queue, build_producer_settings, select_message_queue,
     select_message_queue_by_message_group, HOST_NAME,
@@ -41,17 +51,27 @@ use crate::{log, pb};
 /// Most of its methods take shared reference so that application developers may use it at will.
 ///
 /// [`Producer`] is `Send` and `Sync` by design, so that developers may get started easily.
-#[derive(Debug)]
 pub struct Producer {
-    option: ProducerOption,
+    option: Arc<RwLock<ProducerOption>>,
     logger: Logger,
     client: Client,
+    transaction_checker: Option<Box<TransactionChecker>>,
+    shutdown_tx: Option<oneshot::Sender<()>>,
+}
+
+impl Debug for Producer {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("Producer")
+            .field("option", &self.option)
+            .field("client", &self.client)
+            .finish()
+    }
 }
 
 impl Producer {
     const OPERATION_SEND_MESSAGE: &'static str = "producer.send_message";
     const OPERATION_SEND_TRANSACTION_MESSAGE: &'static str = "producer.send_transaction_message";
-
+    const OPERATION_END_TRANSACTION: &'static str = "producer.end_transaction";
     /// Create a new producer instance
     ///
     /// # Arguments
@@ -67,10 +87,13 @@ impl Producer {
         let logger = log::logger(option.logging_format());
         let settings = build_producer_settings(&option, &client_option);
         let client = Client::new(&logger, client_option, settings)?;
+        let option = Arc::new(RwLock::new(option));
         Ok(Producer {
             option,
             logger,
             client,
+            transaction_checker: None,
+            shutdown_tx: None,
         })
     }
 
@@ -93,23 +116,80 @@ impl Producer {
         };
         let logger = log::logger(option.logging_format());
         let settings = build_producer_settings(&option, &client_option);
-        let mut client = Client::new(&logger, client_option, settings)?;
-        client.set_transaction_checker(transaction_checker);
+        let client = Client::new(&logger, client_option, settings)?;
+        let option = Arc::new(RwLock::new(option));
         Ok(Producer {
             option,
             logger,
             client,
+            transaction_checker: Some(transaction_checker),
+            shutdown_tx: None,
         })
     }
 
+    async fn get_resource_namespace(&self) -> String {
+        let option_guard = self.option.read();
+        let resource_namespace = option_guard.await.namespace().to_string();
+        resource_namespace
+    }
+
     /// Start the producer
     pub async fn start(&mut self) -> Result<(), ClientError> {
-        self.client.start().await?;
-        if let Some(topics) = self.option.topics() {
+        let (telemetry_command_tx, mut telemetry_command_rx) = mpsc::channel(16);
+        let telemetry_command_tx: mpsc::Sender<pb::telemetry_command::Command> =
+            telemetry_command_tx;
+        self.client.start(telemetry_command_tx).await?;
+        let option_guard = self.option.read().await;
+        let topics = option_guard.topics();
+        if let Some(topics) = topics {
             for topic in topics {
                 self.client.topic_route(topic, true).await?;
             }
         }
+        drop(option_guard);
+        let transaction_checker = self.transaction_checker.take();
+        if transaction_checker.is_some() {
+            self.transaction_checker = Some(Box::new(|_, _| TransactionResolution::UNKNOWN));
+        }
+        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
+        self.shutdown_tx = Some(shutdown_tx);
+        let rpc_client = self.client.get_session().await?;
+        let endpoints = self.client.get_endpoints();
+        let logger = self.logger.clone();
+        let producer_option = Arc::clone(&self.option);
+        tokio::spawn(async move {
+            loop {
+                select! {
+                    command = telemetry_command_rx.recv() => {
+                        if let Some(command) = command {
+                            match command {
+                                RecoverOrphanedTransactionCommand(command) => {
+                                    let result = Self::handle_recover_orphaned_transaction_command(
+                                            rpc_client.shadow_session(),
+                                            command,
+                                            &transaction_checker,
+                                            endpoints.clone()).await;
+                                    if let Err(error) = result {
+                                        error!(logger, "handle trannsaction command failed: {:?}", error);
+                                    };
+                                }
+                                Settings(command) => {
+                                    let option = &mut producer_option.write().await;
+                                    Self::handle_settings_command(command, option);
+                                    info!(logger, "handle setting command success.");
+                                }
+                                _ => {
+                                    warn!(logger, "unimplemented command {:?}", command);
+                                }
+                            }
+                        }
+                    }
+                    _ = &mut shutdown_rx => {
+                       break;
+                    }
+                }
+            }
+        });
         info!(
             self.logger,
             "start producer success, client_id: {}",
@@ -118,7 +198,70 @@ impl Producer {
         Ok(())
     }
 
-    fn transform_messages_to_protobuf(
+    async fn handle_recover_orphaned_transaction_command<T: RPCClient + 'static>(
+        mut rpc_client: T,
+        command: pb::RecoverOrphanedTransactionCommand,
+        transaction_checker: &Option<Box<TransactionChecker>>,
+        endpoints: Endpoints,
+    ) -> Result<(), ClientError> {
+        let transaction_id = command.transaction_id;
+        let message = command.message.clone().ok_or_else(|| {
+            ClientError::new(
+                ErrorKind::InvalidMessage,
+                "no message in command",
+                Self::OPERATION_END_TRANSACTION,
+            )
+        })?;
+        let message_id = message
+            .system_properties
+            .as_ref()
+            .map(|props| props.message_id.clone())
+            .ok_or_else(|| {
+                ClientError::new(
+                    ErrorKind::InvalidMessage,
+                    "no message id exists",
+                    Self::OPERATION_END_TRANSACTION,
+                )
+            })?;
+        let topic = message.topic.clone().ok_or_else(|| {
+            ClientError::new(
+                ErrorKind::InvalidMessage,
+                "no topic exists in message",
+                Self::OPERATION_END_TRANSACTION,
+            )
+        })?;
+        if let Some(transaction_checker) = transaction_checker {
+            let resolution = transaction_checker(
+                transaction_id.clone(),
+                MessageView::from_pb_message(message, endpoints),
+            );
+            let response = rpc_client
+                .end_transaction(EndTransactionRequest {
+                    topic: Some(topic),
+                    message_id,
+                    transaction_id,
+                    resolution: resolution as i32,
+                    source: TransactionSource::SourceServerCheck as i32,
+                    trace_context: "".to_string(),
+                })
+                .await?;
+            Client::handle_response_status(response.status, Self::OPERATION_END_TRANSACTION)
+        } else {
+            Err(ClientError::new(
+                ErrorKind::Config,
+                "failed to get transaction checker",
+                Self::OPERATION_END_TRANSACTION,
+            ))
+        }
+    }
+
+    fn handle_settings_command(settings: pb::Settings, option: &mut ProducerOption) {
+        if let Some(PubSub::Publishing(publishing)) = settings.pub_sub {
+            option.set_validate_message_type(publishing.validate_message_type);
+        };
+    }
+
+    async fn transform_messages_to_protobuf(
         &self,
         messages: Vec<impl message::Message>,
     ) -> Result<(String, Option<String>, Vec<pb::Message>), ClientError> {
@@ -185,7 +328,7 @@ impl Producer {
             let pb_message = pb::Message {
                 topic: Some(Resource {
                     name: message.take_topic(),
-                    resource_namespace: self.option.namespace().to_string(),
+                    resource_namespace: self.get_resource_namespace().await,
                 }),
                 user_properties: message.take_properties(),
                 system_properties: Some(SystemProperties {
@@ -243,7 +386,7 @@ impl Producer {
             .collect::<Vec<_>>();
 
         let (topic, message_group, mut pb_messages) =
-            self.transform_messages_to_protobuf(messages)?;
+            self.transform_messages_to_protobuf(messages).await?;
 
         let route = self.client.topic_route(&topic, true).await?;
 
@@ -253,7 +396,10 @@ impl Producer {
             select_message_queue(route)
         };
 
-        if self.option.validate_message_type() {
+        let option_guard = self.option.read().await;
+        let validate_message_type = option_guard.validate_message_type();
+        drop(option_guard);
+        if validate_message_type {
             for message_type in message_types {
                 if !message_queue.accept_type(message_type) {
                     return Err(ClientError::new(
@@ -278,12 +424,16 @@ impl Producer {
         self.client.send_message(&endpoints, pb_messages).await
     }
 
+    pub fn has_transaction_checker(&self) -> bool {
+        self.transaction_checker.is_some()
+    }
+
     /// Send message in a transaction
     pub async fn send_transaction_message(
         &self,
         mut message: impl message::Message,
     ) -> Result<impl Transaction, ClientError> {
-        if !self.client.has_transaction_checker() {
+        if !self.has_transaction_checker() {
             return Err(ClientError::new(
                 ErrorKind::InvalidMessage,
                 "this producer can not send transaction message, please create a transaction producer using producer::new_transaction_producer",
@@ -296,14 +446,17 @@ impl Producer {
         Ok(TransactionImpl::new(
             Box::new(rpc_client),
             Resource {
-                resource_namespace: self.option.namespace().to_string(),
+                resource_namespace: self.get_resource_namespace().await,
                 name: topic,
             },
             receipt,
         ))
     }
 
-    pub async fn shutdown(self) -> Result<(), ClientError> {
+    pub async fn shutdown(mut self) -> Result<(), ClientError> {
+        if let Some(tx) = self.shutdown_tx.take() {
+            let _ = tx.send(());
+        }
         self.client.shutdown().await
     }
 }
@@ -312,13 +465,14 @@ impl Producer {
 mod tests {
     use std::sync::Arc;
 
+    use crate::client::MockClient;
     use crate::error::ErrorKind;
     use crate::log::terminal_logger;
     use crate::model::common::Route;
     use crate::model::message::{MessageBuilder, MessageImpl, MessageType};
     use crate::model::transaction::TransactionResolution;
-    use crate::pb::{Broker, MessageQueue};
-    use crate::session::Session;
+    use crate::pb::{Broker, Code, EndTransactionResponse, MessageQueue, Status};
+    use crate::session::{self, Session};
 
     use super::*;
 
@@ -327,6 +481,18 @@ mod tests {
             option: Default::default(),
             logger: terminal_logger(),
             client: Client::default(),
+            shutdown_tx: None,
+            transaction_checker: None,
+        }
+    }
+
+    fn new_transaction_producer_for_test() -> Producer {
+        Producer {
+            option: Default::default(),
+            logger: terminal_logger(),
+            client: Client::default(),
+            shutdown_tx: None,
+            transaction_checker: Some(Box::new(|_, _| TransactionResolution::COMMIT)),
         }
     }
 
@@ -343,10 +509,16 @@ mod tests {
                     queue: vec![],
                 }))
             });
-            client.expect_start().returning(|| Ok(()));
+            client.expect_start().returning(|_| Ok(()));
             client
                 .expect_client_id()
                 .return_const("fake_id".to_string());
+            client
+                .expect_get_session()
+                .return_once(|| Ok(Session::mock()));
+            client
+                .expect_get_endpoints()
+                .return_once(|| Endpoints::from_url("foobar.com:8080").unwrap());
             Ok(client)
         });
         let mut producer_option = ProducerOption::default();
@@ -370,11 +542,16 @@ mod tests {
                     queue: vec![],
                 }))
             });
-            client.expect_start().returning(|| Ok(()));
-            client.expect_set_transaction_checker().returning(|_| ());
+            client.expect_start().returning(|_| Ok(()));
             client
                 .expect_client_id()
                 .return_const("fake_id".to_string());
+            client
+                .expect_get_session()
+                .return_once(|| Ok(Session::mock()));
+            client
+                .expect_get_endpoints()
+                .return_once(|| Endpoints::from_url("foobar.com:8080").unwrap());
             Ok(client)
         });
         let mut producer_option = ProducerOption::default();
@@ -401,7 +578,7 @@ mod tests {
             .set_message_group("message_group".to_string())
             .build()
             .unwrap()];
-        let result = producer.transform_messages_to_protobuf(messages);
+        let result = producer.transform_messages_to_protobuf(messages).await;
         assert!(result.is_ok());
 
         let (topic, message_group, pb_messages) = result.unwrap();
@@ -425,7 +602,7 @@ mod tests {
         let producer = new_producer_for_test();
 
         let messages: Vec<MessageImpl> = vec![];
-        let result = producer.transform_messages_to_protobuf(messages);
+        let result = producer.transform_messages_to_protobuf(messages).await;
         assert!(result.is_err());
         let err = result.unwrap_err();
         assert_eq!(err.kind, ErrorKind::InvalidMessage);
@@ -443,7 +620,7 @@ mod tests {
             transaction_enabled: false,
             message_type: MessageType::TRANSACTION,
         }];
-        let result = producer.transform_messages_to_protobuf(messages);
+        let result = producer.transform_messages_to_protobuf(messages).await;
         assert!(result.is_err());
         let err = result.unwrap_err();
         assert_eq!(err.kind, ErrorKind::InvalidMessage);
@@ -461,7 +638,7 @@ mod tests {
                 .build()
                 .unwrap(),
         ];
-        let result = producer.transform_messages_to_protobuf(messages);
+        let result = producer.transform_messages_to_protobuf(messages).await;
         assert!(result.is_err());
         let err = result.unwrap_err();
         assert_eq!(err.kind, ErrorKind::InvalidMessage);
@@ -481,7 +658,7 @@ mod tests {
                 .build()
                 .unwrap(),
         ];
-        let result = producer.transform_messages_to_protobuf(messages);
+        let result = producer.transform_messages_to_protobuf(messages).await;
         assert!(result.is_err());
         let err = result.unwrap_err();
         assert_eq!(err.kind, ErrorKind::InvalidMessage);
@@ -538,7 +715,7 @@ mod tests {
 
     #[tokio::test]
     async fn producer_send_transaction_message() -> Result<(), ClientError> {
-        let mut producer = new_producer_for_test();
+        let mut producer = new_transaction_producer_for_test();
         producer.client.expect_topic_route().returning(|_, _| {
             Ok(Arc::new(Route {
                 index: Default::default(),
@@ -574,10 +751,6 @@ mod tests {
             .client
             .expect_get_session()
             .return_once(|| Ok(Session::mock()));
-        producer
-            .client
-            .expect_has_transaction_checker()
-            .return_once(|| true);
 
         let _ = producer
             .send_transaction_message(
@@ -588,4 +761,36 @@ mod tests {
             .await?;
         Ok(())
     }
+
+    #[tokio::test]
+    async fn client_handle_recover_orphaned_transaction_command() {
+        let response = Ok(EndTransactionResponse {
+            status: Some(Status {
+                code: Code::Ok as i32,
+                message: "".to_string(),
+            }),
+        });
+        let mut mock = session::MockRPCClient::new();
+        mock.expect_end_transaction()
+            .return_once(|_| Box::pin(futures::future::ready(response)));
+
+        let context = MockClient::handle_response_status_context();
+        context.expect().return_once(|_, _| Result::Ok(()));
+        let result = Producer::handle_recover_orphaned_transaction_command(
+            mock,
+            pb::RecoverOrphanedTransactionCommand {
+                message: Some(pb::Message {
+                    topic: Some(Resource::default()),
+                    user_properties: Default::default(),
+                    system_properties: Some(SystemProperties::default()),
+                    body: vec![],
+                }),
+                transaction_id: "".to_string(),
+            },
+            &Some(Box::new(|_, _| TransactionResolution::COMMIT)),
+            Endpoints::from_url("localhost:8081").unwrap(),
+        )
+        .await;
+        assert!(result.is_ok())
+    }
 }
diff --git a/rust/src/simple_consumer.rs b/rust/src/simple_consumer.rs
index f8a6eace..e891d384 100644
--- a/rust/src/simple_consumer.rs
+++ b/rust/src/simple_consumer.rs
@@ -18,7 +18,9 @@
 use std::time::Duration;
 
 use mockall_double::double;
-use slog::{info, Logger};
+use slog::{info, warn, Logger};
+use tokio::select;
+use tokio::sync::{mpsc, oneshot};
 
 #[double]
 use crate::client::Client;
@@ -45,6 +47,7 @@ pub struct SimpleConsumer {
     option: SimpleConsumerOption,
     logger: Logger,
     client: Client,
+    shutdown_tx: Option<oneshot::Sender<()>>,
 }
 
 impl SimpleConsumer {
@@ -78,6 +81,7 @@ impl SimpleConsumer {
             option,
             logger,
             client,
+            shutdown_tx: None,
         })
     }
 
@@ -90,12 +94,29 @@ impl SimpleConsumer {
                 Self::OPERATION_START_SIMPLE_CONSUMER,
             ));
         }
-        self.client.start().await?;
+        let (telemetry_command_tx, mut telemetry_command_rx) = mpsc::channel(16);
+        self.client.start(telemetry_command_tx).await?;
         if let Some(topics) = self.option.topics() {
             for topic in topics {
                 self.client.topic_route(topic, true).await?;
             }
         }
+        let (shutdown_tx, mut shutdown_rx) = oneshot::channel();
+        self.shutdown_tx = Some(shutdown_tx);
+        let logger = self.logger.clone();
+        tokio::spawn(async move {
+            loop {
+                select! {
+                    command = telemetry_command_rx.recv() => {
+                        warn!(logger, "command {:?} cannot be handled in simple consumer.", command);
+                    }
+
+                    _ = &mut shutdown_rx => {
+                       break;
+                    }
+                }
+            }
+        });
         info!(
             self.logger,
             "start simple consumer success, client_id: {}",
@@ -105,6 +126,9 @@ impl SimpleConsumer {
     }
 
     pub async fn shutdown(self) -> Result<(), ClientError> {
+        if let Some(shutdown_tx) = self.shutdown_tx {
+            let _ = shutdown_tx.send(());
+        };
         self.client.shutdown().await
     }
 
@@ -215,7 +239,7 @@ mod tests {
                     queue: vec![],
                 }))
             });
-            client.expect_start().returning(|| Ok(()));
+            client.expect_start().returning(|_| Ok(()));
             client
                 .expect_client_id()
                 .return_const("fake_id".to_string());
@@ -272,6 +296,7 @@ mod tests {
             option: SimpleConsumerOption::default(),
             logger: terminal_logger(),
             client,
+            shutdown_tx: None,
         };
 
         let messages = simple_consumer


Reply via email to