This is an automated email from the ASF dual-hosted git repository. aaronai pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/rocketmq-clients.git
The following commit(s) were added to refs/heads/master by this push: new d4506486 Implementation of Simple Consumer for Python Client (#588) d4506486 is described below commit d4506486a38b9d064bf8cf55548ebb17f4387b01 Author: Yan Chao Mei <1653720...@qq.com> AuthorDate: Sat Aug 26 11:50:53 2023 +0800 Implementation of Simple Consumer for Python Client (#588) * finish simple_consumer * fix style issues * delete private info * convert comments to English * add state enum & change_invisible_duration * extract example * add more tests * fix style issue --- python/examples/simple_consumer_example.py | 58 ++++ python/rocketmq/client.py | 25 +- python/rocketmq/consumer.py | 73 ++++ python/rocketmq/definition.py | 13 +- python/rocketmq/filter_expression.py | 35 ++ python/rocketmq/message.py | 89 ++++- python/rocketmq/producer.py | 19 +- python/rocketmq/rpc_client.py | 3 +- python/rocketmq/simple_consumer.py | 423 ++++++++++++++++++++++++ python/rocketmq/simple_subscription_settings.py | 89 +++++ python/rocketmq/state.py | 25 ++ 11 files changed, 830 insertions(+), 22 deletions(-) diff --git a/python/examples/simple_consumer_example.py b/python/examples/simple_consumer_example.py new file mode 100644 index 00000000..07ff20b9 --- /dev/null +++ b/python/examples/simple_consumer_example.py @@ -0,0 +1,58 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio + +from rocketmq.client_config import ClientConfig +from rocketmq.filter_expression import FilterExpression +from rocketmq.log import logger +from rocketmq.protocol.definition_pb2 import Resource +from rocketmq.rpc_client import Endpoints +from rocketmq.session_credentials import (SessionCredentials, + SessionCredentialsProvider) +from rocketmq.simple_consumer import SimpleConsumer + + +async def test(): + credentials = SessionCredentials("username", "password") + credentials_provider = SessionCredentialsProvider(credentials) + client_config = ClientConfig( + endpoints=Endpoints("endpoint"), + session_credentials_provider=credentials_provider, + ssl_enabled=True, + ) + topic = Resource() + topic.name = "normal_topic" + + consumer_group = "yourConsumerGroup" + subscription = {topic.name: FilterExpression("*")} + simple_consumer = (await SimpleConsumer.Builder() + .set_client_config(client_config) + .set_consumer_group(consumer_group) + .set_await_duration(15) + .set_subscription_expression(subscription) + .build()) + logger.info(simple_consumer) + # while True: + message_views = await simple_consumer.receive(16, 15) + logger.info(message_views) + for message in message_views: + logger.info(message.body) + logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}") + await simple_consumer.ack(message) + logger.info(f"Message is acknowledged successfully, message-id={message.message_id}") + +if __name__ == "__main__": + asyncio.run(test()) diff --git a/python/rocketmq/client.py b/python/rocketmq/client.py index 509d991d..fd6a7f79 100644 --- a/python/rocketmq/client.py +++ b/python/rocketmq/client.py @@ -15,7 +15,6 @@ import asyncio import threading -from typing import Set from protocol import definition_pb2, service_pb2 from protocol.definition_pb2 import Code as ProtoCode @@ -60,7 
+59,7 @@ class Client: """ Main client class which handles interaction with the server. """ - def __init__(self, client_config: ClientConfig, topics: Set[str]): + def __init__(self, client_config: ClientConfig): """ Initialization method for the Client class. @@ -70,7 +69,6 @@ class Client: self.client_config = client_config self.client_id = ClientIdEncoder.generate() self.endpoints = client_config.endpoints - self.topics = topics #: A cache to store topic routes. self.topic_route_cache = {} @@ -83,13 +81,16 @@ class Client: #: A dictionary to store isolated items. self.isolated = dict() + def get_topics(self): + raise NotImplementedError("This method should be implemented by the subclass.") + async def start(self): """ Start method which initiates fetching of topic routes and schedules heartbeats. """ # get topic route logger.debug(f"Begin to start the rocketmq client, client_id={self.client_id}") - for topic in self.topics: + for topic in self.get_topics(): self.topic_route_cache[topic] = await self.fetch_topic_route(topic) scheduler = ScheduleWithFixedDelay(self.heartbeat, 3, 12) scheduler_sync_settings = ScheduleWithFixedDelay(self.sync_settings, 3, 12) @@ -489,6 +490,22 @@ class ClientManager: request, metadata, timeout_seconds ) + async def receive_message( + self, + endpoints: Endpoints, + request: service_pb2.ReceiveMessageRequest, + timeout_seconds: int, + ): + rpc_client = self.__get_rpc_client( + endpoints, self.__client.client_config.ssl_enabled + ) + metadata = Signature.sign(self.__client.client_config, self.__client.client_id) + + response = await rpc_client.receive_message( + request, metadata, timeout_seconds + ) + return response + def telemetry( self, endpoints: Endpoints, diff --git a/python/rocketmq/consumer.py b/python/rocketmq/consumer.py new file mode 100644 index 00000000..d81d8972 --- /dev/null +++ b/python/rocketmq/consumer.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license 
agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import List + +from filter_expression import ExpressionType +from google.protobuf.duration_pb2 import Duration +from message import MessageView +from rocketmq.client import Client +from rocketmq.protocol.definition_pb2 import \ + FilterExpression as ProtoFilterExpression +from rocketmq.protocol.definition_pb2 import FilterType +from rocketmq.protocol.definition_pb2 import Resource as ProtoResource +from rocketmq.protocol.service_pb2 import \ + ReceiveMessageRequest as ProtoReceiveMessageRequest + + +class ReceiveMessageResult: + def __init__(self, endpoints, messages: List['MessageView']): + self.endpoints = endpoints + self.messages = messages + + +class Consumer(Client): + CONSUMER_GROUP_REGEX = re.compile(r"^[%a-zA-Z0-9_-]+$") + + def __init__(self, client_config, consumer_group): + super().__init__(client_config) + self.consumer_group = consumer_group + + async def receive_message(self, request, mq, await_duration): + tolerance = self.client_config.request_timeout + timeout = tolerance + await_duration + results = await self.client_manager.receive_message(mq.broker.endpoints, request, timeout) + + messages = [MessageView.from_protobuf(message, mq) for message in results] + return ReceiveMessageResult(mq.broker.endpoints, messages) + + @staticmethod + 
def _wrap_filter_expression(filter_expression): + filter_type = FilterType.TAG + if filter_expression.type == ExpressionType.Sql92: + filter_type = FilterType.SQL + return ProtoFilterExpression( + type=filter_type, + expression=filter_expression.expression + ) + + def wrap_receive_message_request(self, batch_size, mq, filter_expression, await_duration, invisible_duration): + group = ProtoResource() + group.name = self.consumer_group + return ProtoReceiveMessageRequest( + group=group, + message_queue=mq.to_protobuf(), + filter_expression=self._wrap_filter_expression(filter_expression), + long_polling_timeout=Duration(seconds=await_duration), + batch_size=batch_size, + auto_renew=False, + invisible_duration=Duration(seconds=invisible_duration) + ) diff --git a/python/rocketmq/definition.py b/python/rocketmq/definition.py index 3d63748c..498fc6d9 100644 --- a/python/rocketmq/definition.py +++ b/python/rocketmq/definition.py @@ -62,7 +62,7 @@ class Broker: :return: The protobuf representation of the broker. """ return ProtoBroker( - Name=self.name, Id=self.id, Endpoints=self.endpoints.to_protobuf() + name=self.name, id=self.id, endpoints=self.endpoints.to_protobuf() ) @@ -76,8 +76,8 @@ class Resource: :param resource: The resource object. """ if resource is not None: - self.namespace = resource.ResourceNamespace - self.name = resource.Name + self.namespace = resource.resource_namespace + self.name = resource.name else: self.namespace = "" self.name = name @@ -87,7 +87,10 @@ class Resource: :return: The protobuf representation of the resource. """ - return ProtoResource(ResourceNamespace=self.namespace, Name=self.name) + resource = ProtoResource() + resource.name = self.name + resource.resource_namespace = self.namespace + return resource def __str__(self): return f"{self.namespace}.{self.name}" if self.namespace else self.name @@ -219,7 +222,7 @@ class MessageQueue: :param message_queue: The initial message queue to be encapsulated. 
""" - self._topic_resource = Resource(message_queue.topic) + self._topic_resource = Resource(message_queue.topic.name, message_queue.topic) self.queue_id = message_queue.id self.permission = PermissionHelper.from_protobuf(message_queue.permission) self.accept_message_types = [ diff --git a/python/rocketmq/filter_expression.py b/python/rocketmq/filter_expression.py new file mode 100644 index 00000000..9e3e5117 --- /dev/null +++ b/python/rocketmq/filter_expression.py @@ -0,0 +1,35 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class ExpressionType(Enum): + Tag = 1 + Sql92 = 2 + + +class FilterExpression: + def __init__(self, expression, expression_type=ExpressionType.Tag): + self._expression = expression + self._type = expression_type + + @property + def type(self): + return self._type + + @property + def expression(self): + return self._expression diff --git a/python/rocketmq/message.py b/python/rocketmq/message.py index f20e8651..5076da91 100644 --- a/python/rocketmq/message.py +++ b/python/rocketmq/message.py @@ -14,7 +14,14 @@ # limitations under the License. 
-from rocketmq.message_id import MessageId +import binascii +import gzip +import hashlib +from typing import Dict, List + +from rocketmq.definition import MessageQueue +from rocketmq.protocol.definition_pb2 import DigestType as ProtoDigestType +from rocketmq.protocol.definition_pb2 import Encoding as ProtoEncoding class Message: @@ -70,16 +77,21 @@ class Message: class MessageView: def __init__( self, - message_id: MessageId, + message_id: str, topic: str, body: bytes, - properties: map, tag: str, - keys: str, message_group: str, delivery_timestamp: int, + keys: List[str], + properties: Dict[str, str], born_host: str, + born_time: int, delivery_attempt: int, + message_queue: MessageQueue, + receipt_handle: str, + offset: int, + corrupted: bool ): self.__message_id = message_id self.__topic = topic @@ -91,11 +103,28 @@ class MessageView: self.__delivery_timestamp = delivery_timestamp self.__born_host = born_host self.__delivery_attempt = delivery_attempt + self.__receipt_handle = receipt_handle + self.__born_time = born_time + self.__message_queue = message_queue + self.__offset = offset + self.__corrupted = corrupted + + @property + def message_queue(self): + return self.__message_queue + + @property + def receipt_handle(self): + return self.__receipt_handle @property def topic(self): return self.__topic + @property + def body(self): + return self.__body + @property def message_id(self): return self.__message_id @@ -123,3 +152,55 @@ class MessageView: @property def delivery_timestamp(self): return self.__delivery_timestamp + + @classmethod + def from_protobuf(cls, message, message_queue=None): + topic = message.topic.name + system_properties = message.system_properties + message_id = system_properties.message_id + body_digest = system_properties.body_digest + check_sum = body_digest.checksum + raw = message.body + corrupted = False + digest_type = body_digest.type + + # Digest Type check + if digest_type == ProtoDigestType.CRC32: + expected_check_sum = 
format(binascii.crc32(raw) & 0xFFFFFFFF, '08X') + if not expected_check_sum == check_sum: + corrupted = True + elif digest_type == ProtoDigestType.MD5: + expected_check_sum = hashlib.md5(raw).hexdigest() + if not expected_check_sum == check_sum: + corrupted = True + elif digest_type == ProtoDigestType.SHA1: + expected_check_sum = hashlib.sha1(raw).hexdigest() + if not expected_check_sum == check_sum: + corrupted = True + elif digest_type in [ProtoDigestType.unspecified, None]: + print(f"Unsupported message body digest algorithm, digestType={digest_type}, topic={topic}, messageId={message_id}") + + # Body Encoding check + body_encoding = system_properties.body_encoding + body = raw + if body_encoding == ProtoEncoding.GZIP: + body = gzip.decompress(message.body) + elif body_encoding in [ProtoEncoding.IDENTITY, None]: + pass + else: + print(f"Unsupported message encoding algorithm, topic={topic}, messageId={message_id}, bodyEncoding={body_encoding}") + + tag = system_properties.tag + message_group = system_properties.message_group + delivery_time = system_properties.delivery_timestamp + keys = list(system_properties.keys) + + born_host = system_properties.born_host + born_time = system_properties.born_timestamp + delivery_attempt = system_properties.delivery_attempt + queue_offset = system_properties.queue_offset + properties = {key: value for key, value in message.user_properties.items()} + receipt_handle = system_properties.receipt_handle + + return cls(message_id, topic, body, tag, message_group, delivery_time, keys, properties, born_host, + born_time, delivery_attempt, message_queue, receipt_handle, queue_offset, corrupted) diff --git a/python/rocketmq/producer.py b/python/rocketmq/producer.py index 9e10a3db..378bf8eb 100644 --- a/python/rocketmq/producer.py +++ b/python/rocketmq/producer.py @@ -179,7 +179,8 @@ class Producer(Client): :param client_config: The configuration for the client. :param topics: The set of topics to which the producer can send messages. 
""" - super().__init__(client_config, topics) + super().__init__(client_config) + self.publish_topics = topics retry_policy = ExponentialBackoffRetryPolicy.immediately_retry_policy(10) #: Set up the publishing settings with the given parameters. self.publish_settings = PublishingSettings( @@ -196,6 +197,9 @@ class Producer(Client): """Provide an asynchronous context manager for the producer.""" await self.shutdown() + def get_topics(self): + return self.publish_topics + async def start(self): """Start the RocketMQ producer and log the operation.""" logger.info(f"Begin to start the rocketmq producer, client_id={self.client_id}") @@ -364,7 +368,7 @@ async def test(): credentials = SessionCredentials("username", "password") credentials_provider = SessionCredentialsProvider(credentials) client_config = ClientConfig( - endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"), + endpoints=Endpoints("endpoint"), session_credentials_provider=credentials_provider, ssl_enabled=True, ) @@ -375,6 +379,7 @@ async def test(): msg.body = b"My Normal Message Body" sysperf = SystemProperties() sysperf.message_id = MessageIdCodec.next_message_id() + sysperf.message_group = "yourConsumerGroup" msg.system_properties.CopyFrom(sysperf) producer = Producer(client_config, topics={"normal_topic"}) message = Message(topic.name, msg.body) @@ -388,7 +393,7 @@ async def test_delay_message(): credentials = SessionCredentials("username", "password") credentials_provider = SessionCredentialsProvider(credentials) client_config = ClientConfig( - endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"), + endpoints=Endpoints("endpoint"), session_credentials_provider=credentials_provider, ssl_enabled=True, ) @@ -417,7 +422,7 @@ async def test_fifo_message(): credentials = SessionCredentials("username", "password") credentials_provider = SessionCredentialsProvider(credentials) client_config = ClientConfig( - 
endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"), + endpoints=Endpoints("endpoint"), session_credentials_provider=credentials_provider, ssl_enabled=True, ) @@ -431,7 +436,7 @@ async def test_fifo_message(): msg.system_properties.CopyFrom(sysperf) logger.debug(f"{msg}") producer = Producer(client_config, topics={"fifo_topic"}) - message = Message(topic.name, msg.body, message_group="yourMessageGroup") + message = Message(topic.name, msg.body, message_group="yourConsumerGroup") await producer.start() await asyncio.sleep(10) send_receipt = await producer.send(message) @@ -442,7 +447,7 @@ async def test_transaction_message(): credentials = SessionCredentials("username", "password") credentials_provider = SessionCredentialsProvider(credentials) client_config = ClientConfig( - endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"), + endpoints=Endpoints("endpoint"), session_credentials_provider=credentials_provider, ssl_enabled=True, ) @@ -469,7 +474,7 @@ async def test_retry_and_isolation(): credentials = SessionCredentials("username", "password") credentials_provider = SessionCredentialsProvider(credentials) client_config = ClientConfig( - endpoints=Endpoints("rmq-cn-jaj390gga04.cn-hangzhou.rmq.aliyuncs.com:8080"), + endpoints=Endpoints("endpoint"), session_credentials_provider=credentials_provider, ssl_enabled=True, ) diff --git a/python/rocketmq/rpc_client.py b/python/rocketmq/rpc_client.py index 6c1107ab..d907632a 100644 --- a/python/rocketmq/rpc_client.py +++ b/python/rocketmq/rpc_client.py @@ -23,9 +23,8 @@ from functools import reduce import certifi from grpc import aio, ssl_channel_credentials -from protocol import service_pb2 from rocketmq.log import logger -from rocketmq.protocol import service_pb2_grpc +from rocketmq.protocol import service_pb2, service_pb2_grpc from rocketmq.protocol.definition_pb2 import Address as ProtoAddress from rocketmq.protocol.definition_pb2 import \ AddressScheme as ProtoAddressScheme 
diff --git a/python/rocketmq/simple_consumer.py b/python/rocketmq/simple_consumer.py new file mode 100644 index 00000000..a85eb809 --- /dev/null +++ b/python/rocketmq/simple_consumer.py @@ -0,0 +1,423 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +import random +import re +import threading +from datetime import timedelta +from threading import Lock +from typing import Dict + +import rocketmq +from google.protobuf.duration_pb2 import Duration +from rocketmq.client_config import ClientConfig +from rocketmq.consumer import Consumer +from rocketmq.definition import PermissionHelper +from rocketmq.filter_expression import FilterExpression +from rocketmq.log import logger +from rocketmq.message import MessageView +from rocketmq.protocol.definition_pb2 import Resource +from rocketmq.protocol.definition_pb2 import Resource as ProtoResource +from rocketmq.protocol.service_pb2 import \ + AckMessageEntry as ProtoAckMessageEntry +from rocketmq.protocol.service_pb2 import \ + AckMessageRequest as ProtoAckMessageRequest +from rocketmq.protocol.service_pb2 import \ + ChangeInvisibleDurationRequest as ProtoChangeInvisibleDurationRequest +from rocketmq.rpc_client import Endpoints +from rocketmq.session_credentials import 
(SessionCredentials, + SessionCredentialsProvider) +from rocketmq.simple_subscription_settings import SimpleSubscriptionSettings +from rocketmq.state import State +from utils import get_positive_mod + + +class SubscriptionLoadBalancer: + """This class serves as a load balancer for message subscription. + It keeps track of a rotating index to help distribute the load evenly. + """ + + def __init__(self, topic_route_data): + #: current index for message queue selection + self._index = random.randint(0, 10000) # assuming a range of 0-10000 + #: thread lock to ensure atomic update to the index + self._index_lock = threading.Lock() + + #: filter the message queues which are readable and from the master broker + self._message_queues = [ + mq for mq in topic_route_data.message_queues + if PermissionHelper().is_readable(mq.permission) + and mq.broker.id == rocketmq.utils.master_broker_id + ] + + def update(self, topic_route_data): + """Updates the message queues based on the new topic route data.""" + self._index += 1 + self._message_queues = [ + mq for mq in topic_route_data.message_queues + if PermissionHelper().is_readable(mq.permission) + and mq.broker.id == rocketmq.utils.master_broker_id + ] + return self + + def take_message_queue(self): + """Fetches the next message queue based on the current index.""" + with self._index_lock: + index = get_positive_mod(self._index, len(self._message_queues)) + self._index += 1 + return self._message_queues[index] + + +class SimpleConsumer(Consumer): + """The SimpleConsumer class extends the Client class and is used to consume + messages from specific topics in RocketMQ. + """ + + def __init__(self, client_config: ClientConfig, consumer_group: str, await_duration: int, subscription_expressions: Dict[str, FilterExpression]): + """Create a new SimpleConsumer. + + :param client_config: The configuration for the client. + :param consumer_group: The consumer group. + :param await_duration: The await duration. 
+ :param subscription_expressions: The subscription expressions. + """ + super().__init__(client_config, consumer_group) + + self._consumer_group = consumer_group + self._await_duration = await_duration + self._subscription_expressions = subscription_expressions + + self._simple_subscription_settings = SimpleSubscriptionSettings(self.client_id, self.endpoints, self._consumer_group, timedelta(seconds=10), 10, self._subscription_expressions) + self._subscription_route_data_cache = {} + self._topic_round_robin_index = 0 + self._state_lock = Lock() + self._state = State.New + self._subscription_load_balancer = {} # A dictionary to keep subscription load balancers + + def get_topics(self): + return set(self._subscription_expressions.keys()) + + def get_settings(self): + return self._simple_subscription_settings + + async def subscribe(self, topic: str, filter_expression: FilterExpression): + if self._state != State.Running: + raise Exception("Simple consumer is not running") + + await self.get_subscription_load_balancer(topic) + self._subscription_expressions[topic] = filter_expression + + def unsubscribe(self, topic: str): + if self._state != State.Running: + raise Exception("Simple consumer is not running") + try: + self._subscription_expressions.pop(topic) + except KeyError: + pass + + async def start(self): + """Start the RocketMQ consumer and log the operation.""" + logger.info(f"Begin to start the rocketmq consumer, client_id={self.client_id}") + with self._state_lock: + if self._state != State.New: + raise Exception("Consumer already started") + await super().start() + # Start all necessary operations + self._state = State.Running + logger.info(f"The rocketmq consumer starts successfully, client_id={self.client_id}") + + async def shutdown(self): + """Shutdown the RocketMQ consumer and log the operation.""" + logger.info(f"Begin to shutdown the rocketmq consumer, client_id={self.client_id}") + with self._state_lock: + if self._state != State.Running: + raise 
Exception("Consumer is not running") + # Shutdown all necessary operations + self._state = State.Terminated + await super().shutdown() + logger.info(f"Shutdown the rocketmq consumer successfully, client_id={self.client_id}") + + def update_subscription_load_balancer(self, topic, topic_route_data): + # if a load balancer for this topic already exists in the subscription routing data cache, update it + subscription_load_balancer = self._subscription_route_data_cache.get(topic) + if subscription_load_balancer: + subscription_load_balancer.update(topic_route_data) + # otherwise, create a new subscription load balancer + else: + subscription_load_balancer = SubscriptionLoadBalancer(topic_route_data) + + # store new or updated subscription load balancers in the subscription routing data cache + self._subscription_route_data_cache[topic] = subscription_load_balancer + return subscription_load_balancer + + async def get_subscription_load_balancer(self, topic): + # if a load balancer for this topic already exists in the subscription routing data cache, return it + subscription_load_balancer = self._subscription_route_data_cache.get(topic) + if subscription_load_balancer: + return subscription_load_balancer + + # otherwise, obtain the routing data for the topic + topic_route_data = await self.get_route_data(topic) + # update subscription load balancer + return self.update_subscription_load_balancer(topic, topic_route_data) + + async def receive(self, max_message_num, invisible_duration): + if self._state != State.Running: + raise Exception("Simple consumer is not running") + if max_message_num <= 0: + raise Exception("maxMessageNum must be greater than 0") + copy = dict(self._subscription_expressions) + topics = list(copy.keys()) + if len(topics) == 0: + raise ValueError("There is no topic to receive message") + + index = (self._topic_round_robin_index + 1) % len(topics) + self._topic_round_robin_index = index + topic = topics[index] + filter_expression = 
self._subscription_expressions[topic] + subscription_load_balancer = await self.get_subscription_load_balancer(topic) + mq = subscription_load_balancer.take_message_queue() + request = self.wrap_receive_message_request(max_message_num, mq, filter_expression, self._await_duration, invisible_duration) + result = await self.receive_message(request, mq, self._await_duration) + return result.messages + + def wrap_change_invisible_duration(self, message_view: MessageView, invisible_duration): + topic_resource = ProtoResource() + topic_resource.name = message_view.topic + + request = ProtoChangeInvisibleDurationRequest() + request.topic.CopyFrom(topic_resource) + group = ProtoResource() + group.name = message_view.message_group + logger.debug(message_view.message_group) + request.group.CopyFrom(group) + request.receipt_handle = message_view.receipt_handle + request.invisible_duration.CopyFrom(Duration(seconds=invisible_duration)) + request.message_id = message_view.message_id + + return request + + async def change_invisible_duration(self, message_view: MessageView, invisible_duration): + if self._state != State.Running: + raise Exception("Simple consumer is not running") + + request = self.wrap_change_invisible_duration(message_view, invisible_duration) + result = await self.client_manager.change_invisible_duration( + message_view.message_queue.broker.endpoints, + request, + self.client_config.request_timeout + ) + logger.debug(result) + + async def ack(self, message_view: MessageView): + if self._state != State.Running: + raise Exception("Simple consumer is not running") + request = self.wrap_ack_message_request(message_view) + result = await self.client_manager.ack_message(message_view.message_queue.broker.endpoints, request=request, timeout_seconds=self.client_config.request_timeout) + logger.info(result) + + def get_protobuf_group(self): + return ProtoResource(name=self.consumer_group) + + def wrap_ack_message_request(self, message_view: MessageView): + 
topic_resource = ProtoResource() + topic_resource.name = message_view.topic + entry = ProtoAckMessageEntry() + entry.message_id = message_view.message_id + entry.receipt_handle = message_view.receipt_handle + + request = ProtoAckMessageRequest(group=self.get_protobuf_group(), topic=topic_resource, entries=[entry]) + return request + + class Builder: + def __init__(self): + self._consumer_group_regex = re.compile(r"^[%a-zA-Z0-9_-]+$") + self._clientConfig = None + self._consumerGroup = None + self._awaitDuration = None + self._subscriptionExpressions = {} + + def set_client_config(self, client_config: ClientConfig): + if client_config is None: + raise ValueError("clientConfig should not be null") + self._clientConfig = client_config + return self + + def set_consumer_group(self, consumer_group: str): + if consumer_group is None: + raise ValueError("consumerGroup should not be null") + # Assuming CONSUMER_GROUP_REGEX is defined in the outer scope + if not re.match(self._consumer_group_regex, consumer_group): + raise ValueError(f"topic does not match the regex {self._consumer_group_regex}") + self._consumerGroup = consumer_group + return self + + def set_await_duration(self, await_duration: int): + self._awaitDuration = await_duration + return self + + def set_subscription_expression(self, subscription_expressions: Dict[str, FilterExpression]): + if subscription_expressions is None: + raise ValueError("subscriptionExpressions should not be null") + if len(subscription_expressions) == 0: + raise ValueError("subscriptionExpressions should not be empty") + self._subscriptionExpressions = subscription_expressions + return self + + async def build(self): + if self._clientConfig is None: + raise ValueError("clientConfig has not been set yet") + if self._consumerGroup is None: + raise ValueError("consumerGroup has not been set yet") + if len(self._subscriptionExpressions) == 0: + raise ValueError("subscriptionExpressions has not been set yet") + + simple_consumer = 
SimpleConsumer(self._clientConfig, self._consumerGroup, self._awaitDuration, self._subscriptionExpressions) + await simple_consumer.start() + return simple_consumer + + +async def test(): + credentials = SessionCredentials("username", "password") + credentials_provider = SessionCredentialsProvider(credentials) + client_config = ClientConfig( + endpoints=Endpoints("endpoint"), + session_credentials_provider=credentials_provider, + ssl_enabled=True, + ) + topic = Resource() + topic.name = "normal_topic" + + consumer_group = "yourConsumerGroup" + subscription = {topic.name: FilterExpression("*")} + simple_consumer = (await SimpleConsumer.Builder() + .set_client_config(client_config) + .set_consumer_group(consumer_group) + .set_await_duration(15) + .set_subscription_expression(subscription) + .build()) + logger.info(simple_consumer) + # while True: + message_views = await simple_consumer.receive(16, 15) + logger.info(message_views) + for message in message_views: + logger.info(message.body) + logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}") + await simple_consumer.ack(message) + logger.info(f"Message is acknowledged successfully, message-id={message.message_id}") + + +async def test_fifo_message(): + credentials = SessionCredentials("username", "password") + credentials_provider = SessionCredentialsProvider(credentials) + client_config = ClientConfig( + endpoints=Endpoints("endpoint"), + session_credentials_provider=credentials_provider, + ssl_enabled=True, + ) + topic = Resource() + topic.name = "fifo_topic" + + consumer_group = "yourConsumerGroup" + subscription = {topic.name: FilterExpression("*")} + simple_consumer = (await SimpleConsumer.Builder() + .set_client_config(client_config) + .set_consumer_group(consumer_group) + .set_await_duration(15) + .set_subscription_expression(subscription) + .build()) + logger.info(simple_consumer) + # while True: + message_views = await 
simple_consumer.receive(16, 15) + # logger.info(message_views) + for message in message_views: + logger.info(message.body) + logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}") + await simple_consumer.ack(message) + logger.info(f"Message is acknowledged successfully, message-id={message.message_id}") + + +async def test_change_invisible_duration(): + credentials = SessionCredentials("username", "password") + credentials_provider = SessionCredentialsProvider(credentials) + client_config = ClientConfig( + endpoints=Endpoints("endpoint"), + session_credentials_provider=credentials_provider, + ssl_enabled=True, + ) + topic = Resource() + topic.name = "fifo_topic" + + consumer_group = "yourConsumerGroup" + subscription = {topic.name: FilterExpression("*")} + simple_consumer = (await SimpleConsumer.Builder() + .set_client_config(client_config) + .set_consumer_group(consumer_group) + .set_await_duration(15) + .set_subscription_expression(subscription) + .build()) + logger.info(simple_consumer) + # while True: + message_views = await simple_consumer.receive(16, 15) + # logger.info(message_views) + for message in message_views: + await simple_consumer.change_invisible_duration(message_view=message, invisible_duration=3) + logger.info(message.body) + logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}") + await simple_consumer.ack(message) + logger.info(f"Message is acknowledged successfully, message-id={message.message_id}") + + +async def test_subscribe_unsubscribe(): + credentials = SessionCredentials("username", "password") + credentials_provider = SessionCredentialsProvider(credentials) + client_config = ClientConfig( + endpoints=Endpoints("endpoint"), + session_credentials_provider=credentials_provider, + ssl_enabled=True, + ) + topic = Resource() + topic.name = "normal_topic" + + consumer_group = "yourConsumerGroup" + subscription = 
{topic.name: FilterExpression("*")} + simple_consumer = (await SimpleConsumer.Builder() + .set_client_config(client_config) + .set_consumer_group(consumer_group) + .set_await_duration(15) + .set_subscription_expression(subscription) + .build()) + logger.info(simple_consumer) + # while True: + message_views = await simple_consumer.receive(16, 15) + logger.info(message_views) + for message in message_views: + logger.info(message.body) + logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}") + await simple_consumer.ack(message) + logger.info(f"Message is acknowledged successfully, message-id={message.message_id}") + simple_consumer.unsubscribe('normal_topic') + await simple_consumer.subscribe('fifo_topic', FilterExpression("*")) + message_views = await simple_consumer.receive(16, 15) + logger.info(message_views) + for message in message_views: + logger.info(message.body) + logger.info(f"Received a message, topic={message.topic}, message-id={message.message_id}, body-size={len(message.body)}") + await simple_consumer.ack(message) + logger.info(f"Message is acknowledged successfully, message-id={message.message_id}") + +if __name__ == "__main__": + asyncio.run(test_subscribe_unsubscribe()) diff --git a/python/rocketmq/simple_subscription_settings.py b/python/rocketmq/simple_subscription_settings.py new file mode 100644 index 00000000..6d193008 --- /dev/null +++ b/python/rocketmq/simple_subscription_settings.py @@ -0,0 +1,89 @@ +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

from google.protobuf.duration_pb2 import Duration
from rocketmq.filter_expression import ExpressionType
from rocketmq.log import logger
from rocketmq.protocol.definition_pb2 import \
    FilterExpression as ProtoFilterExpression
from rocketmq.protocol.definition_pb2 import FilterType as ProtoFilterType
from rocketmq.protocol.definition_pb2 import Resource as ProtoResource
from rocketmq.protocol.definition_pb2 import Settings as ProtoSettings
from rocketmq.protocol.definition_pb2 import Subscription as ProtoSubscription
from rocketmq.protocol.definition_pb2 import \
    SubscriptionEntry as ProtoSubscriptionEntry

from .settings import ClientType, ClientTypeHelper, Settings


# Assuming a simple representation of FilterExpression for the purpose of this example
class FilterExpression:
    """Placeholder filter expression holding a filter type and its expression.

    The values are stored under the original ``Type`` / ``Expression``
    attribute names. Read-only ``type`` / ``expression`` aliases are provided
    as well, because ``SimpleSubscriptionSettings.to_protobuf`` reads the
    lower-case names.
    """

    def __init__(self, type, expression):
        # The parameter name ``type`` shadows the builtin, but it is part of
        # the existing signature and is kept so keyword callers still work.
        self.Type = type
        self.Expression = expression

    @property
    def type(self):
        """Alias of ``Type``."""
        return self.Type

    @property
    def expression(self):
        """Alias of ``Expression``."""
        return self.Expression


class SimpleSubscriptionSettings(Settings):
    """Settings of a simple consumer: group, long-polling timeout, subscriptions."""

    def __init__(self, clientId, endpoints, consumerGroup, requestTimeout, longPollingTimeout,
                 subscriptionExpressions: Dict[str, FilterExpression]):
        """Store the consumer-specific settings on top of the base ``Settings``.

        :param clientId: forwarded to ``Settings.__init__``.
        :param endpoints: forwarded to ``Settings.__init__``.
        :param consumerGroup: consumer group name; used as the protobuf
            resource name in ``to_protobuf``.
        :param requestTimeout: forwarded to ``Settings.__init__``.
        :param longPollingTimeout: passed as ``seconds`` to a protobuf
            ``Duration`` in ``to_protobuf``.
        :param subscriptionExpressions: mapping of topic name to the filter
            expression applied to that topic.
        """
        super().__init__(clientId, ClientType.SimpleConsumer, endpoints, None, requestTimeout)
        self._group = consumerGroup  # Simplified as string for now
        self._longPollingTimeout = longPollingTimeout
        self._subscriptionExpressions = subscriptionExpressions

    def Sync(self, settings: ProtoSettings):
        """Check the type of settings issued to this client.

        Only validates that ``settings`` is a ``ProtoSettings`` and logs an
        error otherwise; nothing from ``settings`` is applied here.
        """
        if not isinstance(settings, ProtoSettings):
            logger.error(f"[Bug] Issued settings doesn't match with the client type, clientId={self.ClientId}, clientType={self.ClientType}")

    def to_protobuf(self):
        """Build the protobuf ``Settings`` message for this simple consumer.

        One ``SubscriptionEntry`` (topic plus filter expression) is produced
        per subscribed topic.

        :return: the message returned by ``Settings.to_protobuf`` with access
            point, client type, request timeout and subscription filled in.
        """
        subscription_entries = []

        for topic_name, filter_expression in self._subscriptionExpressions.items():
            topic = ProtoResource()
            topic.name = topic_name

            proto_expression = ProtoFilterExpression()
            if filter_expression.type == ExpressionType.Tag:
                proto_expression.type = ProtoFilterType.TAG
            elif filter_expression.type == ExpressionType.Sql92:
                proto_expression.type = ProtoFilterType.SQL
            else:
                # Best effort: the entry is still sent, with the filter type
                # left at its protobuf default. The message reads the same
                # lower-case attribute as the comparisons above, so this
                # branch cannot raise AttributeError by itself.
                logger.warning(f"[Bug] Unrecognized filter type={filter_expression.type} for simple consumer")
            proto_expression.expression = filter_expression.expression

            entry = ProtoSubscriptionEntry()
            entry.topic.CopyFrom(topic)
            # Attach the filter to the entry; without this the expression
            # built above is discarded and never reaches the wire.
            entry.expression.CopyFrom(proto_expression)
            subscription_entries.append(entry)

        group = ProtoResource()
        group.name = self._group

        subscription = ProtoSubscription()
        subscription.group.CopyFrom(group)
        subscription.subscriptions.extend(subscription_entries)
        # NOTE(review): assumes longPollingTimeout is a whole number of
        # seconds -- confirm against the caller.
        subscription.long_polling_timeout.CopyFrom(Duration(seconds=self._longPollingTimeout))

        settings = super().to_protobuf()
        settings.access_point.CopyFrom(self.Endpoints.to_protobuf())  # Assuming Endpoints has a to_protobuf method
        settings.client_type = ClientTypeHelper.to_protobuf(self.ClientType)
        # Sub-second precision of the request timeout is truncated by int().
        settings.request_timeout.CopyFrom(Duration(seconds=int(self.RequestTimeout.total_seconds())))
        settings.subscription.CopyFrom(subscription)

        return settings


# diff --git a/python/rocketmq/state.py b/python/rocketmq/state.py
# new file mode 100644
# index 00000000..e8f2d010
# --- /dev/null
# +++ b/python/rocketmq/state.py
# @@ -0,0 +1,25 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import Enum, unique


@unique
class State(Enum):
    """Named states with fixed integer values.

    ``@unique`` makes a duplicated value fail at class-creation time instead
    of silently turning the later member into an alias of the earlier one.

    NOTE(review): the member names suggest a client lifecycle
    (New -> Starting -> Running -> Stopping -> Terminated, or Failed);
    confirm against the code that consumes this enum.
    """

    New = 1
    Starting = 2
    Running = 3
    Stopping = 4
    Terminated = 5
    Failed = 6