"""runnel.stream: partitioned Redis streams that together represent one logical event stream."""

from dataclasses import dataclass, replace
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Callable, Iterable, Type, Union

import structlog

from runnel.exceptions import Misconfigured
from runnel.interfaces import Serializer
from runnel.record import Record

# ``App`` is imported only for static type checking: at runtime it appears
# solely in the quoted ``"App"`` annotation on ``Stream.app``.
# NOTE(review): presumably this avoids a circular import with runnel.app — confirm.
if TYPE_CHECKING:
    from runnel.app import App

# Module-level structlog logger (used for debug logging in ``Stream.ack``).
logger = structlog.get_logger(__name__)


@dataclass(frozen=True)
class Stream:
    """
    A set of partitioned Redis streams, together representing a single logical
    event stream.

    Not intended to be used directly. Use :attr:`runnel.App.stream` instead.

    Partition ``i`` is a separate Redis stream stored under the key
    ``__strm:{app.name}.{name}.{i}``; a record is routed to a partition by
    hashing its partition key modulo ``partition_count``.
    """

    app: "App"
    name: str
    # Quoted forward references (like ``"App"`` above) so these annotations
    # are not evaluated when the class is created.
    record: "Type[Record]"
    # Either the name of a record field or a callable ``record -> key``.
    partition_by: Union[str, Callable]
    serializer: "Serializer"  # Not applicable for records with primitive=True.
    # Maps a partition key to an int, which is reduced modulo partition_count.
    hasher: Callable[[Any], int]
    partition_count: int
    # Passed to XADD as an approximate MAXLEN for each partition stream.
    partition_size: int

    def __post_init__(self):
        """Validate the configuration.

        Raises
        ------
        Misconfigured
            If a primitive record is given a serializer, or ``partition_by``
            is neither a field of ``record`` nor a callable.
        """
        if self.record._primitive and self.serializer:
            raise Misconfigured("Cannot serialize primitive records")

        by = self.partition_by
        if isinstance(by, str):
            if by not in self.record.__fields__:
                raise Misconfigured("Stream.partition_by must specify a record field")
        elif not callable(by):
            # Without this check _compute_key would silently yield a None key.
            raise Misconfigured(
                "Stream.partition_by must be a record field name or a callable"
            )

    def __hash__(self):
        # An explicit __hash__ is kept by @dataclass(frozen=True); streams are
        # hashed by their id.
        return hash(self.id)

    def clone(self, **kwargs):
        """Return a copy of this stream with the given fields replaced."""
        return replace(self, **kwargs)

    @property
    def id(self):
        """Identifier of the form ``"{app.name}.{name}"``."""
        return f"{self.app.name}.{self.name}"

    # NOTE(review): lru_cache on a method keeps a strong reference to every
    # Stream instance it has seen (ruff B019). Presumably acceptable because
    # streams are long-lived, app-level objects — confirm.
    @lru_cache(maxsize=None)
    def partition_key(self, i):
        """Return the Redis key of partition ``i``."""
        return f"__strm:{self.id}.{i}"

    @lru_cache()
    def all_partition_keys(self):
        """Return the Redis keys of all partitions, in partition order.

        The same cached list object is returned on every call; do not mutate it.
        """
        return [self.partition_key(i) for i in range(self.partition_count)]

    def route(self, key):
        """Return the Redis key of the partition that ``key`` maps to."""
        return self.partition_key(self.hash(key))

    def hash(self, key):
        """Return the partition index for ``key`` (``hasher(key) % partition_count``)."""
        return self.hasher(key) % self.partition_count

    async def send(self, *records: "Record", stream_ids=None):
        """
        Send records to partitions of the stream, according to their partition keys.

        Parameters
        ----------
        records : Record
            The records to send.
        stream_ids : Optional[Iterable[str]]
            A list of stream_ids corresponding to the records. Must be the same
            length as records. If ``None`` (or empty), then ``"*"`` will be used
            for all records. See `<https://redis.io/commands/xadd>`_ for more
            details.

        Raises
        ------
        ValueError
            If ``stream_ids`` does not have the same length as ``records``.
        """
        # Materialize so any iterable (e.g. a generator) is accepted and len() works.
        ids = list(stream_ids) if stream_ids else []
        if not ids:
            ids = ["*"] * len(records)
        # Raise instead of assert: asserts are stripped under ``python -O`` and
        # zip() would then silently drop the unmatched records.
        if len(ids) != len(records):
            raise ValueError(f"Got {len(ids)} stream_ids for {len(records)} records")
        if not records:
            return  # Nothing to send; skip the Redis round trip.

        async with await self.app.redis.pipeline() as pipe:
            for record, stream_id in zip(records, ids):
                await pipe.xadd(
                    name=self.route(self._compute_key(record)),
                    entry=self.serialize(record),
                    max_len=self.partition_size,
                    approximate=True,
                    stream_id=stream_id,
                )
            await pipe.execute()

    async def read(self, group, consumer, prefetch, timeout, **keys):
        """
        Read entries for ``consumer`` in consumer group ``group`` via XREADGROUP.

        ``prefetch`` is passed as ``count`` and ``timeout`` as ``block``;
        ``keys`` are forwarded unchanged to ``xreadgroup`` (presumably a mapping
        of stream key -> id to read from — confirm against the client library).
        """
        return await self.app.redis.xreadgroup(  # yapf: disable
            group=group,
            consumer_id=consumer,
            count=prefetch,
            block=timeout,
            **keys,
        )

    async def ack(self, *events):
        """
        Acknowledge ``events`` with XACK.

        A single event is acked directly. Multiple events are grouped by
        partition key and consumer group (XACK accepts many ids only for one
        key/group pair) and acked together in one pipeline.
        """
        if not events:
            return

        if len(events) == 1:
            e = events[0]
            await self.app.redis.xack(e.partition.key, e.group, e.xid)
        else:
            grouped = self._group_xids(events)
            async with await self.app.redis.pipeline() as pipe:
                for key, groups in grouped.items():
                    for group, xids in groups.items():
                        await pipe.execute_command("XACK", key, group, *xids)
                await pipe.execute()
        logger.debug("acked", events=[e.data for e in events])

    @staticmethod
    def _group_xids(events):
        """Return ``{partition_key: {group: [xid, ...]}}`` for ``events``, preserving order."""
        grouped = {}
        for e in events:
            # setdefault keeps groups already collected for this key (the
            # previous code replaced the whole per-key dict on each new group,
            # dropping earlier xids so they were never acked).
            grouped.setdefault(e.partition.key, {}).setdefault(e.group, []).append(e.xid)
        return grouped

    def _compute_key(self, record):
        """Return the partition key of ``record`` according to ``partition_by``."""
        if isinstance(self.partition_by, str):
            return getattr(record, self.partition_by)
        return self.partition_by(record)

    def serialize(self, record):
        """
        Convert ``record`` into an XADD entry.

        Primitive records are stored field-by-field with UTF-8 encoded field
        names; other records are serialized (and optionally compressed) into a
        single ``b"data"`` field.
        """
        if self.record._primitive:
            return {k.encode("utf-8"): v for k, v in record.dict().items()}

        value = self.serializer.dumps(record.dict())
        if self.serializer.compressor:
            value = self.serializer.compressor.compress(value)
        return {b"data": value}

    def deserialize(self, value):
        """Inverse of :meth:`serialize`: build a ``record`` from a stream entry dict."""
        if self.record._primitive:
            value = {k.decode("utf-8"): v for k, v in value.items()}
        else:
            value = value[b"data"]
            if self.serializer and self.serializer.compressor:
                value = self.serializer.compressor.decompress(value)
            value = self.serializer.loads(value)
        return self.record(**value)