This is an automated email from the ASF dual-hosted git repository.
aaronai pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/rocketmq-clients.git
The following commit(s) were added to refs/heads/master by this push:
new f1d5e959 Initial implementation of rocketmq producer (#546)
f1d5e959 is described below
commit f1d5e959067d44038da5a29ea69ee9b44b00b15e
Author: Yan Chao Mei <[email protected]>
AuthorDate: Mon Jun 26 10:49:12 2023 +0800
Initial implementation of rocketmq producer (#546)
* add client & producer
* Finish Producer basic logic
Now it can produce messages
* fix style problem
* Using black to format code
* Remove username and password
* Fix style issue
---------
Co-authored-by: Aaron Ai <[email protected]>
---
python/rocketmq/client.py | 255 +++++++++++++++++++++
python/rocketmq/client_config.py | 11 +-
python/rocketmq/client_manager.py | 15 +-
python/rocketmq/definition.py | 178 ++++++++++++++
python/rocketmq/log.py | 3 +-
python/rocketmq/producer.py | 133 +++++++++++
python/rocketmq/publish_settings.py | 85 +++++++
python/rocketmq/rpc_client.py | 222 +++++++++++++++---
python/rocketmq/session.py | 44 ++++
.../rocketmq/{utils.py => session_credentials.py} | 35 ++-
python/rocketmq/settings.py | 80 +++++++
python/rocketmq/signature.py | 6 +-
python/rocketmq/utils.py | 2 +
13 files changed, 1005 insertions(+), 64 deletions(-)
diff --git a/python/rocketmq/client.py b/python/rocketmq/client.py
new file mode 100644
index 00000000..4ccda2b9
--- /dev/null
+++ b/python/rocketmq/client.py
@@ -0,0 +1,255 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import threading
+from typing import Set
+
+from protocol import service_pb2
+from protocol.service_pb2 import QueryRouteRequest
+from rocketmq.client_config import ClientConfig
+from rocketmq.client_id_encoder import ClientIdEncoder
+from rocketmq.definition import TopicRouteData
+from rocketmq.rpc_client import Endpoints, RpcClient
+from rocketmq.session import Session
+from rocketmq.signature import Signature
+
+
class Client:
    """Shared state for RocketMQ clients: cached topic routes and per-endpoint sessions."""

    def __init__(self, client_config: ClientConfig, topics: Set[str]):
        """Create a client.

        :param client_config: connection settings (endpoints, credentials, SSL flag).
        :param topics: topics whose routes are fetched by :meth:`start_up`.
        """
        self.client_config = client_config
        self.client_id = ClientIdEncoder.generate()
        self.endpoints = client_config.endpoints
        self.topics = topics

        # topic name -> TopicRouteData
        self.topic_route_cache = {}

        # Endpoints -> Session, guarded by sessionsLock.
        self.sessions_table = {}
        self.sessionsLock = threading.Lock()
        self.client_manager = ClientManager(self)

    async def start_up(self):
        """Fetch and cache the route of every configured topic."""
        for topic in self.topics:
            self.topic_route_cache[topic] = await self.fetch_topic_route(topic)

    def GetTotalRouteEndpoints(self):
        """Return the set of broker endpoints referenced by all cached routes."""
        return {
            mq.broker.endpoints
            for route in self.topic_route_cache.values()
            for mq in route.message_queues
        }

    def get_client_config(self):
        """Return the ClientConfig this client was created with."""
        return self.client_config

    async def OnTopicRouteDataFetched(self, topic, topicRouteData):
        """Open sessions for brokers not seen in any cached route, then cache the route."""
        route_endpoints = {mq.broker.endpoints for mq in topicRouteData.message_queues}
        new_endpoints = route_endpoints.difference(self.GetTotalRouteEndpoints())

        for endpoints in new_endpoints:
            created, session = await self.GetSession(endpoints)
            if not created:
                continue
            await session.sync_settings(True)

        self.topic_route_cache[topic] = topicRouteData

    async def fetch_topic_route0(self, topic):
        """Query the configured access point for *topic*'s route and wrap the result."""
        req = QueryRouteRequest()
        req.topic.name = topic
        address = req.endpoints.addresses.add()
        address.host = self.endpoints.Addresses[0].host
        address.port = self.endpoints.Addresses[0].port
        req.endpoints.scheme = self.endpoints.scheme.to_protobuf(self.endpoints.scheme)
        response = await self.client_manager.query_route(
            self.endpoints, req, self.client_config.request_timeout
        )
        return TopicRouteData(response.message_queues)

    async def fetch_topic_route(self, topic):
        """Fetch *topic*'s route, register it (opening new sessions) and return it."""
        topic_route_data = await self.fetch_topic_route0(topic)
        await self.OnTopicRouteDataFetched(topic, topic_route_data)
        return topic_route_data

    async def GetSession(self, endpoints):
        """Return ``(created, session)`` for *endpoints*, creating the session if absent.

        ``created`` is True only when this call created the session.
        """
        # A single critical section is enough: the previous separate
        # "check first" acquisition of the same lock added nothing, since the
        # check had to be repeated under the second acquisition anyway.
        # Nothing inside is awaited, so holding a threading.Lock here is safe.
        with self.sessionsLock:
            if endpoints in self.sessions_table:
                return False, self.sessions_table[endpoints]

            stream = self.client_manager.telemetry(
                endpoints, self.client_config.request_timeout
            )
            created = Session(endpoints, stream, self)
            self.sessions_table[endpoints] = created
            return True, created
+
+
class ClientManager:
    """Dispatch RPCs to per-endpoint RpcClient instances, signing every call."""

    def __init__(self, client: Client):
        self.__client = client
        # Endpoints -> RpcClient, guarded by __rpc_clients_lock.
        self.__rpc_clients = {}
        self.__rpc_clients_lock = threading.Lock()

    def __get_rpc_client(self, endpoints: Endpoints, ssl_enabled: bool):
        """Return the cached RpcClient for *endpoints*, creating it on first use."""
        with self.__rpc_clients_lock:
            rpc_client = self.__rpc_clients.get(endpoints)
            if rpc_client is None:
                rpc_client = RpcClient(endpoints.grpc_target(ssl_enabled), ssl_enabled)
                self.__rpc_clients[endpoints] = rpc_client
            return rpc_client

    def __prepare_call(self, endpoints: Endpoints):
        """Return ``(rpc_client, metadata)`` for one call to *endpoints*.

        Metadata is signed freshly on every call.
        """
        config = self.__client.client_config
        rpc_client = self.__get_rpc_client(endpoints, config.ssl_enabled)
        metadata = Signature.sign(config, self.__client.client_id)
        return rpc_client, metadata

    async def query_route(
        self,
        endpoints: Endpoints,
        request: service_pb2.QueryRouteRequest,
        timeout_seconds: int,
    ):
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.query_route(request, metadata, timeout_seconds)

    async def heartbeat(
        self,
        endpoints: Endpoints,
        request: service_pb2.HeartbeatRequest,
        timeout_seconds: int,
    ):
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.heartbeat(request, metadata, timeout_seconds)

    async def send_message(
        self,
        endpoints: Endpoints,
        request: service_pb2.SendMessageRequest,
        timeout_seconds: int,
    ):
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.send_message(request, metadata, timeout_seconds)

    async def query_assignment(
        self,
        endpoints: Endpoints,
        request: service_pb2.QueryAssignmentRequest,
        timeout_seconds: int,
    ):
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.query_assignment(request, metadata, timeout_seconds)

    async def ack_message(
        self,
        endpoints: Endpoints,
        request: service_pb2.AckMessageRequest,
        timeout_seconds: int,
    ):
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.ack_message(request, metadata, timeout_seconds)

    async def forward_message_to_dead_letter_queue(
        self,
        endpoints: Endpoints,
        request: service_pb2.ForwardMessageToDeadLetterQueueRequest,
        timeout_seconds: int,
    ):
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.forward_message_to_dead_letter_queue(
            request, metadata, timeout_seconds
        )

    async def end_transaction(
        self,
        endpoints: Endpoints,
        request: service_pb2.EndTransactionRequest,
        timeout_seconds: int,
    ):
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.end_transaction(request, metadata, timeout_seconds)

    async def notify_client_termination(
        self,
        endpoints: Endpoints,
        request: service_pb2.NotifyClientTerminationRequest,
        timeout_seconds: int,
    ):
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.notify_client_termination(
            request, metadata, timeout_seconds
        )

    async def change_invisible_duration(
        self,
        endpoints: Endpoints,
        request: service_pb2.ChangeInvisibleDurationRequest,
        timeout_seconds: int,
    ):
        rpc_client, metadata = self.__prepare_call(endpoints)
        return await rpc_client.change_invisible_duration(
            request, metadata, timeout_seconds
        )

    def telemetry(
        self,
        endpoints: Endpoints,
        timeout_seconds: int,
    ):
        """Open the telemetry stream to *endpoints* (returned as-is, not awaited)."""
        rpc_client, metadata = self.__prepare_call(endpoints)
        return rpc_client.telemetry(metadata, timeout_seconds)
diff --git a/python/rocketmq/client_config.py b/python/rocketmq/client_config.py
index 74f9a8ee..41e691c4 100644
--- a/python/rocketmq/client_config.py
+++ b/python/rocketmq/client_config.py
@@ -13,12 +13,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from rocketmq.rpc_client import Endpoints
+from rocketmq.session_credentials import SessionCredentialsProvider
+
class ClientConfig:
- def __init__(self, endpoints: str, session_credentials_provider,
ssl_enabled: bool):
+ def __init__(
+ self,
+ endpoints: Endpoints,
+ session_credentials_provider: SessionCredentialsProvider,
+ ssl_enabled: bool,
+ ):
self.__endpoints = endpoints
self.__session_credentials_provider = session_credentials_provider
self.__ssl_enabled = ssl_enabled
+ self.request_timeout = 10
@property
def session_credentials_provider(self):
diff --git a/python/rocketmq/client_manager.py
b/python/rocketmq/client_manager.py
index dd13b2f5..a1a87b91 100644
--- a/python/rocketmq/client_manager.py
+++ b/python/rocketmq/client_manager.py
@@ -26,12 +26,12 @@ class ClientManager:
self.__rpc_clients = {}
self.__rpc_clients_lock = threading.Lock()
- def __get_rpc_client(self, endpoints: Endpoints, ssl_enabled: bool) ->
RpcClient:
+ def __get_rpc_client(self, endpoints: Endpoints, ssl_enabled: bool):
with self.__rpc_clients_lock:
rpc_client = self.__rpc_clients.get(endpoints)
if rpc_client:
return rpc_client
- rpc_client = RpcClient(endpoints.get_target(), ssl_enabled)
+ rpc_client = RpcClient(endpoints, ssl_enabled)
self.__rpc_clients[endpoints] = rpc_client
return rpc_client
@@ -135,3 +135,14 @@ class ClientManager:
endpoints, self.__client.client_config.ssl_enabled
)
return await rpc_client.change_invisible_duration(request,
timeout_seconds)
+
async def telemetry(
    self,
    endpoints: Endpoints,
    request: service_pb2.TelemetryCommand,
    timeout_seconds: int,
):
    """Open the telemetry stream to *endpoints*.

    ``RpcClient.telemetry`` takes ``(metadata, timeout_seconds)``; calling it
    with no arguments raised TypeError. *request* is accepted for interface
    compatibility but is not sent here.
    """
    # Local import: this module's imports are not visible from this hunk.
    from rocketmq.signature import Signature

    rpc_client = self.__get_rpc_client(
        endpoints, self.__client.client_config.ssl_enabled
    )
    metadata = Signature.sign(self.__client.client_config, self.__client.client_id)
    # NOTE(review): RpcClient.telemetry is a plain (non-async) method that
    # builds the stream call; returning it without ``await`` assumes the
    # stream object itself is what callers want — confirm against rpc_client.py.
    return rpc_client.telemetry(metadata, timeout_seconds)
diff --git a/python/rocketmq/definition.py b/python/rocketmq/definition.py
new file mode 100644
index 00000000..b2115b60
--- /dev/null
+++ b/python/rocketmq/definition.py
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+from typing import List
+
+from protocol.definition_pb2 import Broker as ProtoBroker
+from protocol.definition_pb2 import MessageQueue as ProtoMessageQueue
+from protocol.definition_pb2 import MessageType as ProtoMessageType
+from protocol.definition_pb2 import Permission as ProtoPermission
+from protocol.definition_pb2 import Resource as ProtoResource
+from rocketmq.protocol import definition_pb2
+from rocketmq.rpc_client import Endpoints
+
+
class Broker:
    """Client-side view of a protobuf ``Broker``: name, id and endpoints."""

    def __init__(self, broker):
        self.name = broker.name
        self.id = broker.id
        self.endpoints = Endpoints(broker.endpoints)

    def to_protobuf(self):
        """Convert back to a protobuf ``Broker``."""
        # Protobuf constructors only accept the field names declared in the
        # .proto file (snake_case); PascalCase keywords raise ValueError.
        return ProtoBroker(
            name=self.name, id=self.id, endpoints=self.endpoints.to_protobuf()
        )
+
+
class Resource:
    """A namespaced resource name (for example a topic).

    Build it either from a plain *name* (empty namespace) or from a protobuf
    ``Resource`` passed as *resource*.
    """

    def __init__(self, name=None, resource=None):
        if resource is not None:
            # Protobuf field names are snake_case: resource_namespace / name.
            self.namespace = resource.resource_namespace
            self.name = resource.name
        else:
            self.namespace = ""
            self.name = name

    def to_protobuf(self):
        """Convert to a protobuf ``Resource``."""
        return ProtoResource(resource_namespace=self.namespace, name=self.name)

    def __str__(self):
        return f"{self.namespace}.{self.name}" if self.namespace else self.name
+
+
class Permission(Enum):
    """Access permission of a message queue."""

    NONE = 0
    READ = 1
    WRITE = 2
    READ_WRITE = 3

    # Convenience predicates so callers can write ``mq.permission.is_writable()``.
    def is_writable(self) -> bool:
        """Return True for WRITE and READ_WRITE."""
        return PermissionHelper.is_writable(self)

    def is_readable(self) -> bool:
        """Return True for READ and READ_WRITE."""
        return PermissionHelper.is_readable(self)


class PermissionHelper:
    """Conversions and predicates for :class:`Permission`."""

    @staticmethod
    def from_protobuf(permission):
        """Map a protobuf ``Permission`` to :class:`Permission`; None if unmapped."""
        if permission == ProtoPermission.READ:
            return Permission.READ
        if permission == ProtoPermission.WRITE:
            return Permission.WRITE
        if permission == ProtoPermission.READ_WRITE:
            return Permission.READ_WRITE
        if permission == ProtoPermission.NONE:
            return Permission.NONE
        return None

    @staticmethod
    def to_protobuf(permission):
        """Map a :class:`Permission` to protobuf; None for non-members."""
        if permission == Permission.READ:
            return ProtoPermission.READ
        if permission == Permission.WRITE:
            return ProtoPermission.WRITE
        if permission == Permission.READ_WRITE:
            return ProtoPermission.READ_WRITE
        # NONE was previously unmapped, breaking the from/to round trip.
        if permission == Permission.NONE:
            return ProtoPermission.NONE
        return None

    @staticmethod
    def is_writable(permission):
        """Return True if *permission* allows writing."""
        return permission in (Permission.WRITE, Permission.READ_WRITE)

    @staticmethod
    def is_readable(permission):
        """Return True if *permission* allows reading."""
        return permission in (Permission.READ, Permission.READ_WRITE)
+
+
class MessageType(Enum):
    """Kind of message a queue accepts."""

    NORMAL = 0
    FIFO = 1
    DELAY = 2
    TRANSACTION = 3


class MessageTypeHelper:
    """Conversions between :class:`MessageType` and protobuf ``MessageType``."""

    @staticmethod
    def from_protobuf(message_type):
        """Map a protobuf ``MessageType`` to :class:`MessageType`; None if unmapped."""
        if message_type == ProtoMessageType.NORMAL:
            return MessageType.NORMAL
        if message_type == ProtoMessageType.FIFO:
            return MessageType.FIFO
        if message_type == ProtoMessageType.DELAY:
            return MessageType.DELAY
        if message_type == ProtoMessageType.TRANSACTION:
            return MessageType.TRANSACTION
        return None

    @staticmethod
    def to_protobuf(message_type):
        """Map a :class:`MessageType` to protobuf, falling back to "unspecified"."""
        if message_type == MessageType.NORMAL:
            return ProtoMessageType.NORMAL
        if message_type == MessageType.FIFO:
            return ProtoMessageType.FIFO
        if message_type == MessageType.DELAY:
            return ProtoMessageType.DELAY
        if message_type == MessageType.TRANSACTION:
            return ProtoMessageType.TRANSACTION
        # The proto's zero value follows the XXX_UNSPECIFIED convention seen on
        # ADDRESS_SCHEME_UNSPECIFIED; a bare ``UNSPECIFIED`` attribute does not
        # exist. NOTE(review): confirm against definition.proto.
        return ProtoMessageType.MESSAGE_TYPE_UNSPECIFIED
+
+
class MessageQueue:
    """Client-side view of a protobuf ``MessageQueue``."""

    def __init__(self, message_queue):
        # The protobuf topic is itself a Resource message. Passing it
        # positionally made it the Resource *name*, so ``topic`` returned a
        # protobuf object; copy its fields explicitly instead.
        proto_topic = message_queue.topic
        self._topic_resource = Resource(name=proto_topic.name)
        self._topic_resource.namespace = proto_topic.resource_namespace
        self.queue_id = message_queue.id
        self.permission = PermissionHelper.from_protobuf(message_queue.permission)
        self.accept_message_types = [
            MessageTypeHelper.from_protobuf(mt)
            for mt in message_queue.accept_message_types
        ]
        self.broker = Broker(message_queue.broker)

    @property
    def topic(self):
        """Topic name (without namespace)."""
        return self._topic_resource.name

    def __str__(self):
        return f"{self.broker.name}.{self._topic_resource}.{self.queue_id}"

    def to_protobuf(self):
        """Convert back to a protobuf ``MessageQueue``."""
        message_types = [
            MessageTypeHelper.to_protobuf(mt) for mt in self.accept_message_types
        ]
        return ProtoMessageQueue(
            topic=self._topic_resource.to_protobuf(),
            id=self.queue_id,
            permission=PermissionHelper.to_protobuf(self.permission),
            broker=self.broker.to_protobuf(),
            accept_message_types=message_types,
        )
+
+
class TopicRouteData:
    """Route of one topic: the list of :class:`MessageQueue` objects serving it."""

    def __init__(self, message_queues: List[definition_pb2.MessageQueue]):
        # Wrap each protobuf queue once, up front.
        self.__message_queue_list = list(map(MessageQueue, message_queues))

    @property
    def message_queues(self) -> List[MessageQueue]:
        """The wrapped message queues, in the order the server returned them."""
        return self.__message_queue_list
diff --git a/python/rocketmq/log.py b/python/rocketmq/log.py
index f3e4eae3..a9c82e4b 100644
--- a/python/rocketmq/log.py
+++ b/python/rocketmq/log.py
@@ -29,7 +29,8 @@ console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter(
- "%(asctime)s [%(levelname)s] [%(process)d] [%(threadName)s]
[%(filename)s#%(funcName)s:%(lineno)d] %(message)s"
+ "%(asctime)s [%(levelname)s] [%(process)d] [%(threadName)s] [%(filename)s#\
+ %(funcName)s:%(lineno)d] %(message)s"
)
file_handler.setFormatter(formatter)
console_handler.setFormatter(formatter)
diff --git a/python/rocketmq/producer.py b/python/rocketmq/producer.py
new file mode 100644
index 00000000..c9c0d351
--- /dev/null
+++ b/python/rocketmq/producer.py
@@ -0,0 +1,133 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import threading
+from typing import Set
+
+import rocketmq
+from rocketmq.client import Client
+from rocketmq.client_config import ClientConfig
+from rocketmq.definition import TopicRouteData
+from rocketmq.message_id_codec import MessageIdCodec
+from rocketmq.protocol.definition_pb2 import Message as ProtoMessage
+from rocketmq.protocol.definition_pb2 import Resource, SystemProperties
+from rocketmq.protocol.service_pb2 import SendMessageRequest
+from rocketmq.publish_settings import PublishingSettings
+from rocketmq.rpc_client import Endpoints
+from rocketmq.session_credentials import (SessionCredentials,
+ SessionCredentialsProvider)
+
+
class PublishingLoadBalancer:
    """Round-robin selector over the writable master-broker queues of a topic."""

    def __init__(self, topic_route_data: TopicRouteData, index: int = 0):
        # Local imports: PermissionHelper holds the writability predicate, and
        # importing the utils submodule explicitly guarantees it is loaded.
        from rocketmq import utils
        from rocketmq.definition import PermissionHelper

        self.__index = index
        self.__index_lock = threading.Lock()
        self.__message_queues = [
            mq
            for mq in topic_route_data.message_queues
            if PermissionHelper.is_writable(mq.permission)
            # Compare ids by value; ``is not`` on ints relies on interning.
            and mq.broker.id == utils.master_broker_id
        ]

    @property
    def index(self):
        return self.__index

    def get_and_increment_index(self):
        """Return the current index and advance it by one (thread-safe)."""
        with self.__index_lock:
            temp = self.__index
            self.__index += 1
            return temp

    def take_message_queues(self, excluded: Set[Endpoints], count: int):
        """Pick up to *count* queues on distinct brokers, round-robin.

        Queues whose broker endpoints are in *excluded* are skipped. If that
        leaves no candidate at all, the exclusion set is ignored.
        """
        next_index = self.get_and_increment_index()
        candidates = []
        candidate_broker_names = set()
        queue_num = len(self.__message_queues)

        for _ in range(queue_num):
            mq = self.__message_queues[next_index % queue_num]
            next_index += 1
            if (
                mq.broker.endpoints not in excluded
                and mq.broker.name not in candidate_broker_names
            ):
                candidate_broker_names.add(mq.broker.name)
                candidates.append(mq)
                if len(candidates) >= count:
                    return candidates
        if candidates:
            return candidates

        # Every endpoint is isolated: fall back to ignoring the exclusion set.
        for _ in range(queue_num):
            mq = self.__message_queues[next_index % queue_num]
            # Advance the cursor; without this the same queue was re-read.
            next_index += 1
            if mq.broker.name not in candidate_broker_names:
                candidate_broker_names.add(mq.broker.name)
                candidates.append(mq)
                if len(candidates) >= count:
                    return candidates
        return candidates
+
+
class Producer(Client):
    """Client that publishes messages to RocketMQ."""

    def __init__(self, client_config: ClientConfig, topics: Set[str]):
        super().__init__(client_config, topics)
        self.publish_settings = PublishingSettings(
            self.client_id, self.endpoints, None, 10, topics
        )
        # Round-robin cursor over writable queues; advanced by send_message.
        self._send_index = 0

    async def start_up(self):
        await super().start_up()

    async def send_message(self, message):
        """Send one protobuf ``Message`` to a writable master queue of its topic.

        The route comes from the cache filled by :meth:`start_up`. It is
        fetched on demand for topics that were not declared up front.

        :raises ValueError: if the topic has no writable master-broker queue.
        """
        from rocketmq import utils
        from rocketmq.definition import PermissionHelper

        # Route by the message's own topic instead of a hard-coded one.
        topic = message.topic.name
        topic_data = self.topic_route_cache.get(topic)
        if topic_data is None:
            topic_data = await self.fetch_topic_route(topic)

        candidates = [
            mq
            for mq in topic_data.message_queues
            if PermissionHelper.is_writable(mq.permission)
            and mq.broker.id == utils.master_broker_id
        ]
        if not candidates:
            raise ValueError(f"no writable master message queue for topic {topic!r}")
        # No await between read and increment, so this is safe on one event loop.
        mq = candidates[self._send_index % len(candidates)]
        self._send_index += 1

        req = SendMessageRequest()
        req.messages.extend([message])
        return await self.client_manager.send_message(
            mq.broker.endpoints, req, self.client_config.request_timeout
        )

    def get_settings(self):
        """Return this producer's PublishingSettings."""
        return self.publish_settings
+
+
async def test():
    """Manual smoke test: start a producer and send one message.

    Each connection setting can be overridden through an environment variable
    (ROCKETMQ_ENDPOINTS, ROCKETMQ_USERNAME, ROCKETMQ_PASSWORD, ROCKETMQ_TOPIC).
    The defaults are the demo values below.
    """
    import os

    endpoints = os.environ.get(
        "ROCKETMQ_ENDPOINTS", "rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"
    )
    topic_name = os.environ.get("ROCKETMQ_TOPIC", "normal_topic")
    creds = SessionCredentials(
        os.environ.get("ROCKETMQ_USERNAME", "username"),
        os.environ.get("ROCKETMQ_PASSWORD", "password"),
    )
    creds_provider = SessionCredentialsProvider(creds)
    client_config = ClientConfig(
        endpoints=Endpoints(endpoints),
        session_credentials_provider=creds_provider,
        ssl_enabled=True,
    )
    producer = Producer(client_config, topics={topic_name})
    topic = Resource()
    topic.name = topic_name
    msg = ProtoMessage()
    msg.topic.CopyFrom(topic)
    msg.body = b"My Message Body"
    sysperf = SystemProperties()
    sysperf.message_id = MessageIdCodec.next_message_id()
    msg.system_properties.CopyFrom(sysperf)
    print(msg)
    await producer.start_up()
    result = await producer.send_message(msg)
    print(result)


if __name__ == "__main__":
    asyncio.run(test())
diff --git a/python/rocketmq/publish_settings.py
b/python/rocketmq/publish_settings.py
new file mode 100644
index 00000000..c629d514
--- /dev/null
+++ b/python/rocketmq/publish_settings.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import platform
+import socket
+from typing import Dict
+
+from rocketmq.protocol.definition_pb2 import UA
+from rocketmq.protocol.definition_pb2 import Publishing as ProtoPublishing
+from rocketmq.protocol.definition_pb2 import Resource as ProtoResource
+from rocketmq.protocol.definition_pb2 import Settings as ProtoSettings
+from rocketmq.rpc_client import Endpoints
+from rocketmq.settings import (ClientType, ClientTypeHelper, IRetryPolicy,
+ Settings)
+from rocketmq.signature import Signature
+
+
class UserAgent:
    """Client identification (version, platform, hostname) for protobuf ``UA``."""

    def __init__(self):
        # NOTE(review): this reads Signature's name-mangled private
        # __CLIENT_VERSION_KEY. Going by the name, that is a metadata *key*
        # rather than a version string — confirm against signature.py.
        version = Signature()._Signature__CLIENT_VERSION_KEY
        host_platform = platform.platform()
        hostname = socket.gethostname()
        self._version, self._platform, self._hostname = (
            version,
            host_platform,
            hostname,
        )

    def to_protobuf(self) -> UA:
        """Build the protobuf ``UA`` message from the captured values."""
        fields = {
            "version": self._version,
            "hostname": self._hostname,
            "platform": self._platform,
        }
        return UA(**fields)
+
+
class PublishingSettings(Settings):
    """Producer settings: publishing topics, body-size limit and type validation."""

    def __init__(
        self,
        client_id: str,
        endpoints: Endpoints,
        retry_policy: IRetryPolicy,
        request_timeout: int,
        topics: Dict[str, bool],
    ):
        super().__init__(
            client_id, ClientType.Producer, endpoints, retry_policy, request_timeout
        )
        # Local defaults; sync() overwrites them from a protobuf Settings.
        self._max_body_size_bytes = 4 * 1024 * 1024
        self._validate_message_type = True
        # Only iterated for topic names. NOTE(review): Producer passes a set of
        # names, which does not match the Dict[str, bool] annotation.
        self._topics = topics

    def get_max_body_size_bytes(self) -> int:
        return self._max_body_size_bytes

    def is_validate_message_type(self) -> bool:
        return self._validate_message_type

    def sync(self, settings: ProtoSettings) -> None:
        """Adopt publishing values from *settings*.

        Does nothing unless the ``pub_sub`` oneof holds ``publishing``.
        """
        # Python protobuf exposes oneofs via WhichOneof(). ``pub_sub_case`` and
        # ``PubSubOneofCase`` are C# generated-code APIs and do not exist here.
        if settings.WhichOneof("pub_sub") != "publishing":
            return

        # Producer constructs these settings with retry_policy=None.
        # NOTE(review): attribute naming depends on the Settings base class.
        if self.retry_policy is not None:
            self.retry_policy = self.retry_policy.inherit_backoff(
                settings.backoff_policy
            )
        self._validate_message_type = settings.publishing.validate_message_type
        self._max_body_size_bytes = settings.publishing.max_body_size

    def to_protobuf(self):
        """Build the protobuf ``Settings`` describing this producer."""
        topics = [ProtoResource(name=topic_name) for topic_name in self._topics]

        publishing = ProtoPublishing(
            topics=topics,
            validate_message_type=self._validate_message_type,
            max_body_size=self._max_body_size_bytes,
        )
        return ProtoSettings(
            publishing=publishing,
            access_point=self.Endpoints.to_protobuf(),
            client_type=ClientTypeHelper.to_protobuf(self.ClientType),
            user_agent=UserAgent().to_protobuf(),
        )
diff --git a/python/rocketmq/rpc_client.py b/python/rocketmq/rpc_client.py
index 545419c7..a737d3d0 100644
--- a/python/rocketmq/rpc_client.py
+++ b/python/rocketmq/rpc_client.py
@@ -14,14 +14,147 @@
# limitations under the License.
import asyncio
+import operator
+import socket
import time
from datetime import timedelta
+from enum import Enum
+from functools import reduce
import certifi
from grpc import aio, ssl_channel_credentials
from protocol import service_pb2
from rocketmq import logger
from rocketmq.protocol import service_pb2_grpc
+from rocketmq.protocol.definition_pb2 import Address as ProtoAddress
+from rocketmq.protocol.definition_pb2 import \
+ AddressScheme as ProtoAddressScheme
+from rocketmq.protocol.definition_pb2 import Endpoints as ProtoEndpoints
+
+
class AddressScheme(Enum):
    """How an Endpoints' addresses should be interpreted."""

    Unspecified = 0
    Ipv4 = 1
    Ipv6 = 2
    DomainName = 3

    @staticmethod
    def to_protobuf(scheme):
        """Translate *scheme* to protobuf ``AddressScheme``; unknown -> unspecified."""
        mapping = (
            (AddressScheme.Ipv4, ProtoAddressScheme.IPV4),
            (AddressScheme.Ipv6, ProtoAddressScheme.IPV6),
            (AddressScheme.DomainName, ProtoAddressScheme.DOMAIN_NAME),
        )
        for local_scheme, proto_scheme in mapping:
            if scheme == local_scheme:
                return proto_scheme
        return ProtoAddressScheme.ADDRESS_SCHEME_UNSPECIFIED
+
+
class Address:
    """A host/port pair with value semantics (equal and hashable by content)."""

    def __init__(self, host, port):
        self.host = host
        self.port = port

    def to_protobuf(self):
        proto_address = ProtoAddress()
        proto_address.host = self.host
        proto_address.port = self.port
        return proto_address

    # Endpoints are used as dict/set keys (RpcClient cache, session table,
    # route diffing) and compare their address lists element-wise, so
    # addresses must compare by value rather than identity.
    def __eq__(self, other):
        if not isinstance(other, Address):
            return NotImplemented
        return (self.host, self.port) == (other.host, other.port)

    def __hash__(self):
        return hash((self.host, self.port))

    def __repr__(self):
        return f"Address({self.host!r}, {self.port!r})"


class Endpoints:
    """Addresses of a server plus their address scheme.

    Built from a ``"[http(s)://]host[:port]"`` string or a protobuf
    ``Endpoints`` message.
    """

    HttpPrefix = "http://"
    HttpsPrefix = "https://"
    DefaultPort = 80
    EndpointSeparator = ":"

    def __init__(self, endpoints):
        self.Addresses = []
        self.scheme = AddressScheme.Unspecified

        if isinstance(endpoints, str):
            self._init_from_string(endpoints)
        else:
            self._init_from_protobuf(endpoints)

        self._hash = self._calculate_hash()

    def _init_from_string(self, endpoints):
        """Parse a single ``host[:port]`` address, stripping an http(s) prefix."""
        if endpoints.startswith(self.HttpPrefix):
            endpoints = endpoints[len(self.HttpPrefix):]
        if endpoints.startswith(self.HttpsPrefix):
            endpoints = endpoints[len(self.HttpsPrefix):]

        index = endpoints.find(self.EndpointSeparator)
        port = int(endpoints[index + 1:]) if index > 0 else self.DefaultPort
        host = endpoints[:index] if index > 0 else endpoints
        # Append once; the address used to be added twice.
        self.Addresses.append(Address(host, port))
        self.scheme = self._detect_scheme(host)

    @staticmethod
    def _detect_scheme(host):
        """Classify *host* as an IPv4 literal, an IPv6 literal or a domain name."""
        for family, scheme in (
            (socket.AF_INET, AddressScheme.Ipv4),
            (socket.AF_INET6, AddressScheme.Ipv6),
        ):
            try:
                socket.inet_pton(family, host)
            except (OSError, ValueError):
                continue
            return scheme
        return AddressScheme.DomainName

    def _init_from_protobuf(self, endpoints):
        """Copy addresses and scheme from a protobuf ``Endpoints``."""
        self.Addresses = [
            Address(addr.host, addr.port) for addr in endpoints.addresses
        ]
        if not self.Addresses:
            raise Exception("No available address")

        # ``scheme`` is a protobuf enum value (an int); comparing it with the
        # strings "Ipv4"/"Ipv6" never matched.
        if endpoints.scheme == ProtoAddressScheme.IPV4:
            self.scheme = AddressScheme.Ipv4
        elif endpoints.scheme == ProtoAddressScheme.IPV6:
            self.scheme = AddressScheme.Ipv6
        else:
            self.scheme = AddressScheme.DomainName
            if len(self.Addresses) > 1:
                raise Exception("Multiple addresses are not allowed in domain scheme")

    def _calculate_hash(self):
        # Must agree with __eq__: ordered address list plus scheme.
        return hash((tuple(self.Addresses), self.scheme))

    def __str__(self):
        return ",".join(f"{a.host}:{a.port}" for a in self.Addresses)

    def __repr__(self):
        return f"Endpoints({str(self)!r}, scheme={self.scheme})"

    def grpc_target(self, sslEnabled):
        """Return ``host:port`` of the first address (*sslEnabled* is unused)."""
        for address in self.Addresses:
            return address.host + ":" + str(address.port)
        raise ValueError("No available address")

    def __eq__(self, other):
        if self is other:
            return True
        if not isinstance(other, Endpoints):
            return NotImplemented
        return self.Addresses == other.Addresses and self.scheme == other.scheme

    def __hash__(self):
        return self._hash

    def to_protobuf(self):
        proto_endpoints = ProtoEndpoints()
        proto_endpoints.scheme = self.scheme.to_protobuf(self.scheme)
        proto_endpoints.addresses.extend([a.to_protobuf() for a in self.Addresses])
        return proto_endpoints
class RpcClient:
@@ -53,28 +186,37 @@ class RpcClient:
)
async def query_route(
- self, request: service_pb2.QueryRouteRequest, timeout_seconds: int
+ self, request: service_pb2.QueryRouteRequest, metadata,
timeout_seconds: int
):
- return await self.__stub.QueryRoute(request, timeout=timeout_seconds)
+ # metadata = [('x-mq-client-id', 'value1')]
+ return await self.__stub.QueryRoute(
+ request, timeout=timeout_seconds, metadata=metadata
+ )
async def heartbeat(
- self, request: service_pb2.HeartbeatRequest, timeout_seconds: int
+ self, request: service_pb2.HeartbeatRequest, metadata,
timeout_seconds: int
):
- return await self.__stub.Heartbeat(request, timeout=timeout_seconds)
+ return await self.__stub.Heartbeat(
+ request, metadata=metadata, timeout=timeout_seconds
+ )
async def send_message(
- self, request: service_pb2.SendMessageRequest, timeout_seconds: int
+ self, request: service_pb2.SendMessageRequest, metadata,
timeout_seconds: int
):
- return await self.__stub.SendMessage(request, timeout=timeout_seconds)
+ return await self.__stub.SendMessage(
+ request, metadata=metadata, timeout=timeout_seconds
+ )
async def receive_message(
- self, request: service_pb2.ReceiveMessageRequest, timeout_seconds: int
+ self, request: service_pb2.ReceiveMessageRequest, metadata,
timeout_seconds: int
):
- results = self.__stub.ReceiveMessage(request, timeout=timeout_seconds)
+ results = self.__stub.ReceiveMessage(
+ request, metadata=metadata, timeout=timeout_seconds
+ )
response = []
try:
async for result in results:
- if result.HasField('message'):
+ if result.HasField("message"):
response.append(result.message)
except Exception as e:
logger.info("An error occurred: %s", e)
@@ -82,68 +224,72 @@ class RpcClient:
return response
async def query_assignment(
- self, request: service_pb2.QueryAssignmentRequest, timeout_seconds: int
+ self,
+ request: service_pb2.QueryAssignmentRequest,
+ metadata,
+ timeout_seconds: int,
):
- return await self.__stub.QueryAssignment(request, timeout=timeout_seconds)
+ """Invoke the QueryAssignment RPC with call metadata and timeout."""
+ return await self.__stub.QueryAssignment(
+ request, metadata=metadata, timeout=timeout_seconds
+ )
async def ack_message(
- self, request: service_pb2.AckMessageRequest, timeout_seconds: int
+ self, request: service_pb2.AckMessageRequest, metadata, timeout_seconds: int
):
- return await self.__stub.AckMessage(request, timeout=timeout_seconds)
+ """Invoke the AckMessage RPC with call metadata and timeout."""
+ return await self.__stub.AckMessage(
+ request, metadata=metadata, timeout=timeout_seconds
+ )
async def forward_message_to_dead_letter_queue(
self,
request: service_pb2.ForwardMessageToDeadLetterQueueRequest,
+ metadata,
timeout_seconds: int,
):
+ """Invoke ForwardMessageToDeadLetterQueue with call metadata and timeout."""
return await self.__stub.ForwardMessageToDeadLetterQueue(
- request, timeout=timeout_seconds
+ request, metadata=metadata, timeout=timeout_seconds
)
async def end_transaction(
- self, request: service_pb2.EndTransactionRequest, timeout_seconds: int
+ self, request: service_pb2.EndTransactionRequest, metadata, timeout_seconds: int
):
- return await self.__stub.EndTransaction(request, timeout=timeout_seconds)
+ """Invoke the EndTransaction RPC with call metadata and timeout."""
+ return await self.__stub.EndTransaction(
+ request, metadata=metadata, timeout=timeout_seconds
+ )
async def notify_client_termination(
- self, request: service_pb2.NotifyClientTerminationRequest, timeout_seconds: int
+ self,
+ request: service_pb2.NotifyClientTerminationRequest,
+ metadata,
+ timeout_seconds: int,
):
+ """Invoke NotifyClientTermination with call metadata and timeout."""
return await self.__stub.NotifyClientTermination(
- request, timeout=timeout_seconds
+ request, metadata=metadata, timeout=timeout_seconds
)
async def change_invisible_duration(
- self, request: service_pb2.ChangeInvisibleDurationRequest, timeout_seconds: int
+ self,
+ request: service_pb2.ChangeInvisibleDurationRequest,
+ metadata,
+ timeout_seconds: int,
):
+ """Invoke ChangeInvisibleDuration with call metadata and timeout."""
return await self.__stub.ChangeInvisibleDuration(
- request, timeout=timeout_seconds
+ request, metadata=metadata, timeout=timeout_seconds
)
async def send_requests(self, requests, stream):
for request in requests:
await stream.send_message(request)
- async def telemetry(
- self, timeout_seconds: int, requests
- ):
- responses = []
- async with self.__stub.Telemetry() as stream:
- # Create a task for sending requests
- send_task = asyncio.create_task(self.send_requests(requests, stream))
- # Receiving responses
- async for response in stream:
- responses.append(response)
-
- # Await the send task to ensure all requests have been sent
- await send_task
-
- return responses
+ def telemetry(self, metadata, timeout_seconds: int):
+ """Open the Telemetry streaming call and hand it back unread.
+
+ Nothing is written to or read from the call here; the caller is
+ responsible for driving the returned call object.
+ """
+ return self.__stub.Telemetry(metadata=metadata, timeout=timeout_seconds)
async def test():
- client = RpcClient("rmq-cn-72u353icd01.cn-hangzhou.rmq.aliyuncs.com:8080")
- request = service_pb2.QueryRouteRequest()
- response = await client.query_route(request, 3)
+ """Manual smoke test: send an empty SendMessageRequest and log the reply."""
+ # NOTE(review): hard-coded instance endpoint; consider making it configurable.
+ client = RpcClient("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8081")
+ request = service_pb2.SendMessageRequest()
+ # send_message() is (request, metadata, timeout_seconds); pass both by
+ # keyword so the timeout is not bound to ``metadata``.
+ response = await client.send_message(request, metadata=None, timeout_seconds=3)
logger.info(response)
diff --git a/python/rocketmq/session.py b/python/rocketmq/session.py
new file mode 100644
index 00000000..50c11d98
--- /dev/null
+++ b/python/rocketmq/session.py
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import logging
+from threading import Event
+
+from rocketmq.protocol.service_pb2 import \
+ TelemetryCommand as ProtoTelemetryCommand
+
+
+class Session:
+ """Send client settings over a single streaming call, one sync at a time.
+
+ ``streaming_call`` must provide awaitable ``write()`` and ``read()``
+ methods. ``client`` must provide ``get_settings()`` returning an object
+ with a ``to_protobuf()`` method. ``endpoints`` is stored but not used by
+ the methods below.
+ """
+
+ def __init__(self, endpoints, streaming_call, client):
+ self._endpoints = endpoints
+ # Limits sync_settings() to one concurrent caller.
+ self._semaphore = asyncio.Semaphore(1)
+ self._streaming_call = streaming_call
+ self._client = client
+ # NOTE(review): this threading.Event is not awaitable and is unused in
+ # this class; an asyncio.Event may be intended -- confirm with its users.
+ self._event = Event()
+
+ async def write_async(self, telemetry_command: ProtoTelemetryCommand):
+ """Write ``telemetry_command``, await one reply and log it at DEBUG level."""
+ await self._streaming_call.write(telemetry_command)
+ response = await self._streaming_call.read()
+ # Library code should not print to stdout; route the reply to logging.
+ logging.getLogger(__name__).debug("Telemetry response: %s", response)
+
+ async def sync_settings(self, await_resp):
+ """Send the client's current settings as a TelemetryCommand.
+
+ ``await_resp`` is currently ignored: write_async() always awaits one
+ response before returning.
+ """
+ # ``async with`` releases the semaphore even if building or writing the
+ # command raises.
+ async with self._semaphore:
+ settings = self._client.get_settings()
+ telemetry_command = ProtoTelemetryCommand()
+ telemetry_command.settings.CopyFrom(settings.to_protobuf())
+ await self.write_async(telemetry_command)
diff --git a/python/rocketmq/utils.py b/python/rocketmq/session_credentials.py
similarity index 51%
copy from python/rocketmq/utils.py
copy to python/rocketmq/session_credentials.py
index adec636b..828e293c 100644
--- a/python/rocketmq/utils.py
+++ b/python/rocketmq/session_credentials.py
@@ -13,27 +13,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import hashlib
-import hmac
-def number_to_base(number, base):
- alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
- if number == 0:
- return alphabet[0]
- result = []
- while number:
- number, remainder = divmod(number, base)
- result.append(alphabet[remainder])
- return "".join(reversed(result))
-
-def sign(access_secret: str, datetime: str) -> str:
- digester = hmac.new(
- bytes(access_secret, encoding="UTF-8"),
- bytes(datetime, encoding="UTF-8"),
- hashlib.sha1,
- )
- return digester.hexdigest().upper()
+class SessionCredentials:
+ """Access key / secret pair plus an optional security token.
+
+ ``access_key`` and ``access_secret`` default to None only in the
+ signature; passing None for either raises ValueError.
+ """
+
+ def __init__(self, access_key=None, access_secret=None, security_token=None):
+ # Checked in order, so a missing access key is reported first.
+ for value, message in (
+ (access_key, "accessKey should not be None"),
+ (access_secret, "accessSecret should not be None"),
+ ):
+ if value is None:
+ raise ValueError(message)
+ self.access_key, self.access_secret, self.security_token = (
+ access_key,
+ access_secret,
+ security_token,
+ )
+
+
+class SessionCredentialsProvider:
+ """Hold one SessionCredentials instance and return it on request."""
+
+ def __init__(self, credentials):
+ is_valid = isinstance(credentials, SessionCredentials)
+ if not is_valid:
+ raise ValueError("credentials should be an instance of SessionCredentials")
+ self.credentials = credentials
+
+ def get_credentials(self):
+ """Return the credentials supplied at construction time."""
+ return self.credentials
diff --git a/python/rocketmq/settings.py b/python/rocketmq/settings.py
new file mode 100644
index 00000000..a270dd15
--- /dev/null
+++ b/python/rocketmq/settings.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import datetime
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from datetime import timedelta
+from enum import Enum
+from typing import Optional
+
+from rocketmq.protocol.definition_pb2 import ClientType as ProtoClientType
+from rocketmq.protocol.definition_pb2 import RetryPolicy as ProtoRetryPolicy
+from rocketmq.protocol.definition_pb2 import Settings as ProtoSettings
+from rocketmq.rpc_client import Endpoints
+
+
+class IRetryPolicy(ABC):
+ """Abstract retry policy; Settings.RetryPolicy holds an instance (or None).
+
+ Subclasses must implement all four methods below.
+ """
+
+ @abstractmethod
+ def GetMaxAttempts(self) -> int:
+ """Return the attempt limit (presumably including the first try -- confirm)."""
+ pass
+
+ @abstractmethod
+ def GetNextAttemptDelay(self, attempt: int) -> datetime.timedelta:
+ """Return the delay to wait before attempt number ``attempt``.
+
+ NOTE(review): whether ``attempt`` is 0- or 1-based is not fixed here;
+ confirm against the implementations.
+ """
+ pass
+
+ @abstractmethod
+ def ToProtobuf(self) -> ProtoRetryPolicy:
+ """Return this policy as a ProtoRetryPolicy message."""
+ pass
+
+ @abstractmethod
+ def InheritBackoff(self, retryPolicy: ProtoRetryPolicy):
+ """Take backoff parameters from ``retryPolicy``.
+
+ NOTE(review): presumably a policy pushed by the server; confirm the
+ expected source and whether implementations return a new policy.
+ """
+ pass
+
+
+class ClientType(Enum):
+ """Client roles; ClientTypeHelper maps each one to a ProtoClientType."""
+
+ Producer = 1
+ SimpleConsumer = 2
+ PushConsumer = 3
+
+
+class ClientTypeHelper:
+ """Static conversions for ClientType values."""
+
+ @staticmethod
+ def to_protobuf(clientType):
+ """Map a ClientType member to its ProtoClientType value.
+
+ Values not in the table (including None) yield CLIENT_TYPE_UNSPECIFIED.
+ """
+ proto_by_type = {
+ ClientType.Producer: ProtoClientType.PRODUCER,
+ ClientType.SimpleConsumer: ProtoClientType.SIMPLE_CONSUMER,
+ ClientType.PushConsumer: ProtoClientType.PUSH_CONSUMER,
+ }
+ fallback = ProtoClientType.CLIENT_TYPE_UNSPECIFIED
+ return proto_by_type.get(clientType, fallback)
+
+
+@dataclass
+class Settings:
+ """Settings shared by every client type, convertible to ProtoSettings."""
+
+ ClientId: str
+ ClientType: ClientType
+ Endpoints: Endpoints
+ RetryPolicy: Optional[IRetryPolicy]
+ RequestTimeout: timedelta
+
+ def to_protobuf(self):
+ """Build a ProtoSettings message from the common fields.
+
+ Fills ``client_type`` (via ClientTypeHelper) and, when set,
+ ``request_timeout``. Endpoint, retry-policy and publishing/subscription
+ fields are not encoded here.
+ """
+ settings = ProtoSettings()
+ settings.client_type = ClientTypeHelper.to_protobuf(self.ClientType)
+ if self.RequestTimeout is not None:
+ # Settings.request_timeout is a google.protobuf.Duration in the
+ # RocketMQ proto definition, so it is filled from the timedelta.
+ settings.request_timeout.FromTimedelta(self.RequestTimeout)
+ return settings
+
+ def Sync(self, settings):
+ """Apply settings received as a protobuf message (not implemented yet)."""
+ # Sync the settings properties from the Protobuf message
+ pass
+
+ def GetRetryPolicy(self):
+ """Return the configured retry policy, which may be None."""
+ return self.RetryPolicy
diff --git a/python/rocketmq/signature.py b/python/rocketmq/signature.py
index c77a355a..3b507ded 100644
--- a/python/rocketmq/signature.py
+++ b/python/rocketmq/signature.py
@@ -50,13 +50,13 @@ class Signature:
Signature.__DATE_TIME_KEY,
date_time,
),
- (Signature.__REQUEST_ID_KEY, uuid.uuid4()),
+ (Signature.__REQUEST_ID_KEY, str(uuid.uuid4())),
(Signature.__CLIENT_ID_KEY, client_id),
]
if not client_config.session_credentials_provider:
return metadata
session_credentials = (
- client_config.session_credentials_provider.session_credentials()
+ client_config.session_credentials_provider.get_credentials()
)
if not session_credentials:
return metadata
@@ -68,7 +68,7 @@ class Signature:
not session_credentials.access_secret
):
return metadata
- signature = sign(session_credentials.access_key, date_time)
+ signature = sign(session_credentials.access_secret, date_time)
authorization = (
Signature.__ALGORITHM
+ " "
diff --git a/python/rocketmq/utils.py b/python/rocketmq/utils.py
index adec636b..dd2a4a98 100644
--- a/python/rocketmq/utils.py
+++ b/python/rocketmq/utils.py
@@ -16,6 +16,8 @@
import hashlib
import hmac
+# NOTE(review): presumably the broker id reserved for master brokers in
+# route data; confirm against its usages.
+master_broker_id = 0
+
def number_to_base(number, base):
alphabet = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"