File: /usr/local/lib/python3.9/site-packages/kombu/transport/gcpubsub.py
"""GCP Pub/Sub transport module for kombu.
More information about GCP Pub/Sub:
https://cloud.google.com/pubsub
Features
========
* Type: Virtual
* Supports Direct: Yes
* Supports Topic: No
* Supports Fanout: Yes
* Supports Priority: No
* Supports TTL: No
Connection String
=================
Connection string has the following formats:
.. code-block::
gcpubsub://projects/project-name
Transport Options
=================
* ``queue_name_prefix``: (str) Prefix for queue names.
* ``ack_deadline_seconds``: (int) The maximum time after receiving a message
and acknowledging it before pub/sub redelivers the message.
* ``expiration_seconds``: (int) Subscriptions without any subscriber
activity or changes made to their properties are removed after this period.
Examples of subscriber activities include open connections,
active pulls, or successful pushes.
* ``wait_time_seconds``: (int) The maximum time to wait for new messages.
Defaults to 10.
* ``retry_timeout_seconds``: (int) The maximum time to wait before retrying.
* ``bulk_max_messages``: (int) The maximum number of messages to pull in bulk.
Defaults to 32.
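
Usage
=====
A minimal sending example (a sketch; ``my-project`` is a placeholder
project id, and default Google application credentials are assumed):

.. code-block:: python

    from kombu import Connection

    with Connection(
        'gcpubsub://projects/my-project',
        transport_options={'wait_time_seconds': 5},
    ) as conn:
        queue = conn.SimpleQueue('my-queue')
        queue.put({'hello': 'world'})
        queue.close()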
"""
from __future__ import annotations
import dataclasses
import datetime
import string
import threading
from concurrent.futures import (FIRST_COMPLETED, Future, ThreadPoolExecutor,
wait)
from contextlib import suppress
from os import getpid
from queue import Empty
from threading import Lock
from time import monotonic, sleep
from uuid import NAMESPACE_OID, uuid3
from socket import gethostname
from socket import timeout as socket_timeout
from google.api_core.exceptions import (AlreadyExists, DeadlineExceeded,
PermissionDenied)
from google.api_core.retry import Retry
from google.cloud import monitoring_v3
from google.cloud.monitoring_v3 import query
from google.cloud.pubsub_v1 import PublisherClient, SubscriberClient
from google.cloud.pubsub_v1 import exceptions as pubsub_exceptions
from google.cloud.pubsub_v1.publisher import exceptions as publisher_exceptions
from google.cloud.pubsub_v1.subscriber import (
    exceptions as subscriber_exceptions)
from google.pubsub_v1 import gapic_version as package_version
from kombu.entity import TRANSIENT_DELIVERY_MODE
from kombu.log import get_logger
from kombu.utils.encoding import bytes_to_str, safe_str
from kombu.utils.json import dumps, loads
from kombu.utils.objects import cached_property
from . import virtual
logger = get_logger('kombu.transport.gcpubsub')
# Dots are replaced by dashes; all other punctuation is replaced
# by underscores.
PUNCTUATIONS_TO_REPLACE = set(string.punctuation) - {'_', '.', '-'}
CHARS_REPLACE_TABLE = {
ord('.'): ord('-'),
**{ord(c): ord('_') for c in PUNCTUATIONS_TO_REPLACE},
}
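# e.g. 'celery.pidbox' -> 'celery-pidbox', 'foo@bar' -> 'foo_bar'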
class UnackedIds:
"""Threadsafe list of ack_ids."""
def __init__(self):
self._list = []
self._lock = Lock()
def append(self, val):
# append is atomic
self._list.append(val)
def extend(self, vals: list):
# extend is atomic
self._list.extend(vals)
def pop(self, index=-1):
with self._lock:
return self._list.pop(index)
def remove(self, val):
with self._lock, suppress(ValueError):
self._list.remove(val)
def __len__(self):
with self._lock:
return len(self._list)
def __getitem__(self, item):
# getitem is atomic
return self._list[item]
class AtomicCounter:
"""Threadsafe counter.
Returns the value after inc/dec operations.
"""
def __init__(self, initial=0):
self._value = initial
self._lock = Lock()
def inc(self, n=1):
with self._lock:
self._value += n
return self._value
def dec(self, n=1):
with self._lock:
self._value -= n
return self._value
def get(self):
with self._lock:
return self._value
@dataclasses.dataclass
class QueueDescriptor:
"""Pub/Sub queue descriptor."""
name: str
topic_path: str # projects/{project_id}/topics/{topic_id}
subscription_id: str
subscription_path: str # projects/{project_id}/subscriptions/{subscription_id}
unacked_ids: UnackedIds = dataclasses.field(default_factory=UnackedIds)
class Channel(virtual.Channel):
"""GCP Pub/Sub channel."""
supports_fanout = True
do_restore = False # pub/sub does that for us
default_wait_time_seconds = 10
default_ack_deadline_seconds = 240
default_expiration_seconds = 86400
default_retry_timeout_seconds = 300
default_bulk_max_messages = 32
_min_ack_deadline = 10
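    # The mutable state below is class-level, i.e. shared by all channel
    # instances in this process.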
_fanout_exchanges = set()
    _unacked_extender: threading.Thread | None = None
_stop_extender = threading.Event()
_n_channels = AtomicCounter()
_queue_cache: dict[str, QueueDescriptor] = {}
_tmp_subscriptions: set[str] = set()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.pool = ThreadPoolExecutor()
logger.info('new GCP pub/sub channel: %s', self.conninfo.hostname)
self.project_id = Transport.parse_uri(self.conninfo.hostname)
if self._n_channels.inc() == 1:
Channel._unacked_extender = threading.Thread(
target=self._extend_unacked_deadline,
daemon=True,
)
self._stop_extender.clear()
Channel._unacked_extender.start()
def entity_name(self, name: str, table=CHARS_REPLACE_TABLE) -> str:
"""Format AMQP queue name into a valid Pub/Sub queue name."""
if not name.startswith(self.queue_name_prefix):
name = self.queue_name_prefix + name
return str(safe_str(name)).translate(table)
def _queue_bind(self, exchange, routing_key, pattern, queue):
exchange_type = self.typeof(exchange).type
queue = self.entity_name(queue)
logger.debug(
'binding queue: %s to %s exchange: %s with routing_key: %s',
queue,
exchange_type,
exchange,
routing_key,
)
filter_args = {}
if exchange_type == 'direct':
            # A direct exchange is implemented as a single subscription.
            # E.g. for exchange 'test_direct':
            #   - topic: 'test_direct'
            #   - bound queue 'direct1':
            #     - subscription 'direct1' on topic 'test_direct',
            #       with a filter on the 'routing_key' attribute
filter_args = {
'filter': f'attributes.routing_key="{routing_key}"'
}
subscription_path = self.subscriber.subscription_path(
self.project_id, queue
)
message_retention_duration = self.expiration_seconds
elif exchange_type == 'fanout':
            # A fanout exchange is implemented as one subscription per
            # bound queue.
            # E.g. for exchange 'test_fanout':
            #   - topic: 'test_fanout'
            #   - bound queue 'fanout1':
            #     - subscription 'fanout1-uuid' on topic 'test_fanout'
            #   - bound queue 'fanout2':
            #     - subscription 'fanout2-uuid' on topic 'test_fanout'
uid = f'{uuid3(NAMESPACE_OID, f"{gethostname()}.{getpid()}")}'
uniq_sub_name = f'{queue}-{uid}'
subscription_path = self.subscriber.subscription_path(
self.project_id, uniq_sub_name
)
self._tmp_subscriptions.add(subscription_path)
self._fanout_exchanges.add(exchange)
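            # Fanout subscriptions are ephemeral (unique per host+pid), so a
            # short 10-minute message retention is enough.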
message_retention_duration = 600
else:
raise NotImplementedError(
f'exchange type {exchange_type} not implemented'
)
exchange_topic = self._create_topic(
self.project_id, exchange, message_retention_duration
)
self._create_subscription(
topic_path=exchange_topic,
subscription_path=subscription_path,
filter_args=filter_args,
msg_retention=message_retention_duration,
)
qdesc = QueueDescriptor(
name=queue,
topic_path=exchange_topic,
subscription_id=queue,
subscription_path=subscription_path,
)
self._queue_cache[queue] = qdesc
    def _create_topic(
        self,
        project_id: str,
        topic_id: str,
        message_retention_duration: int | None = None,
    ) -> str:
topic_path = self.publisher.topic_path(project_id, topic_id)
if self._is_topic_exists(topic_path):
# topic creation takes a while, so skip if possible
logger.debug('topic: %s exists', topic_path)
return topic_path
try:
logger.debug('creating topic: %s', topic_path)
request = {'name': topic_path}
if message_retention_duration:
request[
'message_retention_duration'
] = f'{message_retention_duration}s'
self.publisher.create_topic(request=request)
except AlreadyExists:
pass
return topic_path
def _is_topic_exists(self, topic_path: str) -> bool:
topics = self.publisher.list_topics(
request={"project": f'projects/{self.project_id}'}
)
for t in topics:
if t.name == topic_path:
return True
return False
    def _create_subscription(
        self,
        project_id: str | None = None,
        topic_id: str | None = None,
        topic_path: str | None = None,
        subscription_path: str | None = None,
        filter_args=None,
        msg_retention: int | None = None,
    ) -> str:
subscription_path = (
subscription_path
or self.subscriber.subscription_path(self.project_id, topic_id)
)
topic_path = topic_path or self.publisher.topic_path(
project_id, topic_id
)
try:
logger.debug(
'creating subscription: %s, topic: %s, filter: %s',
subscription_path,
topic_path,
filter_args,
)
msg_retention = msg_retention or self.expiration_seconds
self.subscriber.create_subscription(
request={
"name": subscription_path,
"topic": topic_path,
'ack_deadline_seconds': self.ack_deadline_seconds,
'expiration_policy': {
'ttl': f'{self.expiration_seconds}s'
},
'message_retention_duration': f'{msg_retention}s',
**(filter_args or {}),
}
)
except AlreadyExists:
pass
return subscription_path
def _delete(self, queue, *args, **kwargs):
"""Delete a queue by name."""
queue = self.entity_name(queue)
logger.info('deleting queue: %s', queue)
qdesc = self._queue_cache.get(queue)
if not qdesc:
return
self.subscriber.delete_subscription(
request={"subscription": qdesc.subscription_path}
)
self._queue_cache.pop(queue, None)
def _put(self, queue, message, **kwargs):
"""Put a message onto the queue."""
queue = self.entity_name(queue)
qdesc = self._queue_cache[queue]
routing_key = self._get_routing_key(message)
logger.debug(
'putting message to queue: %s, topic: %s, routing_key: %s',
queue,
qdesc.topic_path,
routing_key,
)
encoded_message = dumps(message)
self.publisher.publish(
qdesc.topic_path,
encoded_message.encode("utf-8"),
routing_key=routing_key,
)
def _put_fanout(self, exchange, message, routing_key, **kwargs):
"""Put a message onto fanout exchange."""
self._lookup(exchange, routing_key)
topic_path = self.publisher.topic_path(self.project_id, exchange)
logger.debug(
'putting msg to fanout exchange: %s, topic: %s',
exchange,
topic_path,
)
encoded_message = dumps(message)
self.publisher.publish(
topic_path,
encoded_message.encode("utf-8"),
retry=Retry(deadline=self.retry_timeout_seconds),
)
    def _get(self, queue: str, timeout: float | None = None):
        """Retrieve a single message from a queue."""
queue = self.entity_name(queue)
qdesc = self._queue_cache[queue]
try:
response = self.subscriber.pull(
request={
'subscription': qdesc.subscription_path,
'max_messages': 1,
},
retry=Retry(deadline=self.retry_timeout_seconds),
timeout=timeout or self.wait_time_seconds,
)
except DeadlineExceeded:
raise Empty()
if len(response.received_messages) == 0:
raise Empty()
message = response.received_messages[0]
ack_id = message.ack_id
payload = loads(message.message.data)
delivery_info = payload['properties']['delivery_info']
logger.debug(
'queue:%s got message, ack_id: %s, payload: %s',
queue,
ack_id,
payload['properties'],
)
if self._is_auto_ack(payload['properties']):
logger.debug('auto acking message ack_id: %s', ack_id)
self._do_ack([ack_id], qdesc.subscription_path)
else:
delivery_info['gcpubsub_message'] = {
'queue': queue,
'ack_id': ack_id,
'message_id': message.message.message_id,
'subscription_path': qdesc.subscription_path,
}
qdesc.unacked_ids.append(ack_id)
return payload
def _is_auto_ack(self, payload_properties: dict):
exchange = payload_properties['delivery_info']['exchange']
delivery_mode = payload_properties['delivery_mode']
return (
delivery_mode == TRANSIENT_DELIVERY_MODE
or exchange in self._fanout_exchanges
)
def _get_bulk(self, queue: str, timeout: float):
"""Retrieves bulk of messages from a queue."""
prefixed_queue = self.entity_name(queue)
qdesc = self._queue_cache[prefixed_queue]
max_messages = self._get_max_messages_estimate()
if not max_messages:
raise Empty()
try:
response = self.subscriber.pull(
request={
'subscription': qdesc.subscription_path,
'max_messages': max_messages,
},
retry=Retry(deadline=self.retry_timeout_seconds),
timeout=timeout or self.wait_time_seconds,
)
except DeadlineExceeded:
raise Empty()
received_messages = response.received_messages
if len(received_messages) == 0:
raise Empty()
auto_ack_ids = []
ret_payloads = []
logger.debug(
'batching %d messages from queue: %s',
len(received_messages),
prefixed_queue,
)
for message in received_messages:
ack_id = message.ack_id
payload = loads(bytes_to_str(message.message.data))
delivery_info = payload['properties']['delivery_info']
delivery_info['gcpubsub_message'] = {
'queue': prefixed_queue,
'ack_id': ack_id,
'message_id': message.message.message_id,
'subscription_path': qdesc.subscription_path,
}
if self._is_auto_ack(payload['properties']):
auto_ack_ids.append(ack_id)
else:
qdesc.unacked_ids.append(ack_id)
ret_payloads.append(payload)
if auto_ack_ids:
logger.debug('auto acking ack_ids: %s', auto_ack_ids)
self._do_ack(auto_ack_ids, qdesc.subscription_path)
return queue, ret_payloads
def _get_max_messages_estimate(self) -> int:
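        # Respect the QoS prefetch window when one is set; otherwise fall
        # back to the configured bulk pull size (bulk_max_messages).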
max_allowed = self.qos.can_consume_max_estimate()
max_if_unlimited = self.bulk_max_messages
return max_if_unlimited if max_allowed is None else max_allowed
def _lookup(self, exchange, routing_key, default=None):
exchange_info = self.state.exchanges.get(exchange, {})
if not exchange_info:
return super()._lookup(exchange, routing_key, default)
ret = self.typeof(exchange).lookup(
self.get_table(exchange),
exchange,
routing_key,
default,
)
if ret:
return ret
logger.debug(
'no queues bound to exchange: %s, binding on the fly',
exchange,
)
self.queue_bind(exchange, exchange, routing_key)
return [exchange]
def _size(self, queue: str) -> int:
"""Return the number of messages in a queue.
This is a *rough* estimation, as Pub/Sub doesn't provide
an exact API.
"""
queue = self.entity_name(queue)
if queue not in self._queue_cache:
return 0
qdesc = self._queue_cache[queue]
result = query.Query(
self.monitor,
self.project_id,
'pubsub.googleapis.com/subscription/num_undelivered_messages',
            end_time=datetime.datetime.now(datetime.timezone.utc),
minutes=1,
).select_resources(subscription_id=qdesc.subscription_id)
# monitoring API requires the caller to have the monitoring.viewer
# role. Since we can live without the exact number of messages
# in the queue, we can ignore the exception and allow users to
# use the transport without this role.
with suppress(PermissionDenied):
return sum(
content.points[0].value.int64_value for content in result
)
return -1
def basic_ack(self, delivery_tag, multiple=False):
"""Acknowledge one message."""
if multiple:
raise NotImplementedError('multiple acks not implemented')
delivery_info = self.qos.get(delivery_tag).delivery_info
pubsub_message = delivery_info['gcpubsub_message']
ack_id = pubsub_message['ack_id']
queue = pubsub_message['queue']
logger.debug('ack message. queue: %s ack_id: %s', queue, ack_id)
subscription_path = pubsub_message['subscription_path']
self._do_ack([ack_id], subscription_path)
qdesc = self._queue_cache[queue]
qdesc.unacked_ids.remove(ack_id)
super().basic_ack(delivery_tag)
def _do_ack(self, ack_ids: list[str], subscription_path: str):
self.subscriber.acknowledge(
request={"subscription": subscription_path, "ack_ids": ack_ids},
retry=Retry(deadline=self.retry_timeout_seconds),
)
def _purge(self, queue: str):
"""Delete all current messages in a queue."""
queue = self.entity_name(queue)
qdesc = self._queue_cache.get(queue)
if not qdesc:
return
n = self._size(queue)
self.subscriber.seek(
request={
"subscription": qdesc.subscription_path,
"time": datetime.datetime.now(),
}
)
return n
def _extend_unacked_deadline(self):
thread_id = threading.get_native_id()
logger.info(
'unacked deadline extension thread: [%s] started',
thread_id,
)
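        # Wake at a quarter of the ack deadline (but no less than half the
        # minimum deadline) so deadlines are extended before they expire.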
min_deadline_sleep = self._min_ack_deadline / 2
sleep_time = max(min_deadline_sleep, self.ack_deadline_seconds / 4)
while not self._stop_extender.wait(sleep_time):
for qdesc in self._queue_cache.values():
if len(qdesc.unacked_ids) == 0:
logger.debug(
'thread [%s]: no unacked messages for %s',
thread_id,
qdesc.subscription_path,
)
continue
logger.debug(
'thread [%s]: extend ack deadline for %s: %d msgs [%s]',
thread_id,
qdesc.subscription_path,
len(qdesc.unacked_ids),
list(qdesc.unacked_ids),
)
self.subscriber.modify_ack_deadline(
request={
"subscription": qdesc.subscription_path,
"ack_ids": list(qdesc.unacked_ids),
"ack_deadline_seconds": self.ack_deadline_seconds,
}
)
logger.info(
'unacked deadline extension thread [%s] stopped', thread_id
)
def after_reply_message_received(self, queue: str):
queue = self.entity_name(queue)
sub = self.subscriber.subscription_path(self.project_id, queue)
logger.debug(
'after_reply_message_received: queue: %s, sub: %s', queue, sub
)
self._tmp_subscriptions.add(sub)
@cached_property
def subscriber(self):
return SubscriberClient()
@cached_property
def publisher(self):
return PublisherClient()
@cached_property
def monitor(self):
return monitoring_v3.MetricServiceClient()
@property
def conninfo(self):
return self.connection.client
@property
def transport_options(self):
return self.connection.client.transport_options
@cached_property
def wait_time_seconds(self):
return self.transport_options.get(
'wait_time_seconds', self.default_wait_time_seconds
)
@cached_property
def retry_timeout_seconds(self):
return self.transport_options.get(
'retry_timeout_seconds', self.default_retry_timeout_seconds
)
@cached_property
def ack_deadline_seconds(self):
return self.transport_options.get(
'ack_deadline_seconds', self.default_ack_deadline_seconds
)
@cached_property
def queue_name_prefix(self):
return self.transport_options.get('queue_name_prefix', 'kombu-')
@cached_property
def expiration_seconds(self):
return self.transport_options.get(
'expiration_seconds', self.default_expiration_seconds
)
@cached_property
def bulk_max_messages(self):
return self.transport_options.get(
'bulk_max_messages', self.default_bulk_max_messages
)
def close(self):
"""Close the channel."""
logger.debug('closing channel')
while self._tmp_subscriptions:
sub = self._tmp_subscriptions.pop()
with suppress(Exception):
logger.debug('deleting subscription: %s', sub)
self.subscriber.delete_subscription(
request={"subscription": sub}
)
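        # Last channel closing: stop the shared ack-deadline extender thread.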
if not self._n_channels.dec():
self._stop_extender.set()
Channel._unacked_extender.join()
super().close()
@staticmethod
def _get_routing_key(message):
routing_key = (
message['properties']
.get('delivery_info', {})
.get('routing_key', '')
)
return routing_key
class Transport(virtual.Transport):
"""GCP Pub/Sub transport."""
Channel = Channel
can_parse_url = True
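    # Sleep interval between drain_events polls when no message arrived.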
polling_interval = 0.1
connection_errors = virtual.Transport.connection_errors + (
pubsub_exceptions.TimeoutError,
)
channel_errors = (
virtual.Transport.channel_errors
+ (
publisher_exceptions.FlowControlLimitError,
publisher_exceptions.MessageTooLargeError,
publisher_exceptions.PublishError,
publisher_exceptions.TimeoutError,
publisher_exceptions.PublishToPausedOrderingKeyException,
)
+ (subscriber_exceptions.AcknowledgeError,)
)
driver_type = 'gcpubsub'
driver_name = 'pubsub_v1'
implements = virtual.Transport.implements.extend(
exchange_type=frozenset(['direct', 'fanout']),
)
def __init__(self, client, **kwargs):
super().__init__(client, **kwargs)
self._pool = ThreadPoolExecutor()
self._get_bulk_future_to_queue: dict[Future, str] = dict()
def driver_version(self):
return package_version.__version__
@staticmethod
def parse_uri(uri: str) -> str:
# URL like:
# gcpubsub://projects/project-name
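        # e.g. 'gcpubsub://projects/my-project' -> 'my-project'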
project = uri.split('gcpubsub://projects/')[1]
return project.strip('/')
    @classmethod
    def as_uri(cls, uri: str, include_password=False, mask='**') -> str:
return uri or 'gcpubsub://'
def drain_events(self, connection, timeout=None):
time_start = monotonic()
polling_interval = self.polling_interval
if timeout and polling_interval and polling_interval > timeout:
polling_interval = timeout
while 1:
try:
self._drain_from_active_queues(timeout=timeout)
except Empty:
if timeout and monotonic() - time_start >= timeout:
raise socket_timeout()
if polling_interval:
sleep(polling_interval)
else:
break
def _drain_from_active_queues(self, timeout):
# cleanup empty requests from prev run
self._rm_empty_bulk_requests()
# submit new requests for all active queues
# longer timeout means less frequent polling
# and more messages in a single bulk
self._submit_get_bulk_requests(timeout=10)
done, _ = wait(
self._get_bulk_future_to_queue,
timeout=timeout,
return_when=FIRST_COMPLETED,
)
empty = {f for f in done if f.exception()}
done -= empty
for f in empty:
self._get_bulk_future_to_queue.pop(f, None)
if not done:
raise Empty()
logger.debug('got %d done get_bulk tasks', len(done))
for f in done:
queue, payloads = f.result()
for payload in payloads:
logger.debug('consuming message from queue: %s', queue)
if queue not in self._callbacks:
logger.warning(
'Message for queue %s without consumers', queue
)
continue
self._deliver(payload, queue)
self._get_bulk_future_to_queue.pop(f, None)
def _rm_empty_bulk_requests(self):
empty = {
f
for f in self._get_bulk_future_to_queue
if f.done() and f.exception()
}
for f in empty:
self._get_bulk_future_to_queue.pop(f, None)
def _submit_get_bulk_requests(self, timeout):
queues_with_submitted_get_bulk = set(
self._get_bulk_future_to_queue.values()
)
for channel in self.channels:
for queue in channel._active_queues:
if queue in queues_with_submitted_get_bulk:
continue
future = self._pool.submit(channel._get_bulk, queue, timeout)
self._get_bulk_future_to_queue[future] = queue