|
1 | 1 | import asyncio
|
2 | 2 | import copy
|
| 3 | +from collections import OrderedDict |
| 4 | +from time import monotonic_ns |
3 | 5 | from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
|
4 | 6 |
|
5 | 7 | from fastapi_websocket_rpc import RpcChannel
|
@@ -37,7 +39,73 @@ class Subscription(BaseModel):
|
37 | 39 |
|
38 | 40 | # Publish event callback signature
|
39 | 41 | def EventCallback(subscription: Subscription, data: Any):
|
40 |
| - pass |
| 42 | + ... |
| 43 | + |
| 44 | + |
| 45 | +SUBSCRIPTION_TASK_CLEANUP_TIME = 60 |
| 46 | + |
| 47 | + |
| 48 | +class SubscriptionPusher: |
| 49 | + def __init__(self): |
| 50 | + self._queues: Dict[str, asyncio.Queue] = {} |
| 51 | + self._tasks: Dict[str, asyncio.Task] = {} |
| 52 | + self._queue_flush_times: OrderedDict[str, int] = OrderedDict() |
| 53 | + self._queues_lock = asyncio.Lock() |
| 54 | + self._cleanup_task = asyncio.create_task(self._cleanup_queues()) |
| 55 | + |
| 56 | + async def trigger(self, subscription: Subscription, data): |
| 57 | + need_create_task = False |
| 58 | + await self._queues_lock.acquire() |
| 59 | + try: |
| 60 | + if not subscription.id in self._queues: |
| 61 | + self._queues[subscription.id] = asyncio.Queue() |
| 62 | + need_create_task = True |
| 63 | + finally: |
| 64 | + self._queues_lock.release() |
| 65 | + if need_create_task: |
| 66 | + self._queue_flush_times[subscription.id] = monotonic_ns() |
| 67 | + self._tasks[subscription.id] = asyncio.create_task( |
| 68 | + self._handle_queue(subscription) |
| 69 | + ) |
| 70 | + await self._queues[subscription.id].put(data) |
| 71 | + |
| 72 | + async def _handle_queue(self, subscription: Subscription): |
| 73 | + while True: |
| 74 | + try: |
| 75 | + data = await self._queues[subscription.id].get() |
| 76 | + try: |
| 77 | + await subscription.callback(subscription, data) |
| 78 | + except Exception: |
| 79 | + logger.exception( |
| 80 | + "Unable to handle subscription %r data %r:", subscription, data |
| 81 | + ) |
| 82 | + if self._queues[subscription.id].empty(): |
| 83 | + self._queue_flush_times[subscription.id] = monotonic_ns() |
| 84 | + self._queue_flush_times.move_to_end(subscription.id) |
| 85 | + except Exception: |
| 86 | + logger.exception("Handle failed:") |
| 87 | + |
| 88 | + async def _cleanup_queues(self): |
| 89 | + while True: |
| 90 | + most_recent_flush_to_delete = monotonic_ns() - SUBSCRIPTION_TASK_CLEANUP_TIME |
| 91 | + await self._queues_lock.acquire() |
| 92 | + try: |
| 93 | + # This code is safe because there are no any awaits except for the lock, if this changes then we need to more locks |
| 94 | + subscriptions_to_delete = [] |
| 95 | + for subscription_id, last_flush in self._queue_flush_times.items(): |
| 96 | + if last_flush > most_recent_flush_to_delete: |
| 97 | + # We're done |
| 98 | + break |
| 99 | + if self._queues[subscription_id].empty(): |
| 100 | + subscriptions_to_delete.append(subscription_id) |
| 101 | + for subscription_id in subscriptions_to_delete: |
| 102 | + self._tasks[subscription_id].cancel() |
| 103 | + del self._tasks[subscription_id] |
| 104 | + del self._queues[subscription_id] |
| 105 | + del self._queue_flush_times[subscription_id] |
| 106 | + finally: |
| 107 | + self._queues_lock.release() |
| 108 | + await asyncio.sleep(1) |
41 | 109 |
|
42 | 110 |
|
43 | 111 | class EventNotifier:
|
@@ -70,6 +138,7 @@ def __init__(self):
|
70 | 138 | self._on_unsubscribe_events = []
|
71 | 139 | # List of restriction checks to perform on every action on the channel
|
72 | 140 | self._channel_restrictions = []
|
| 141 | + self._subscription_pusher = SubscriptionPusher() |
73 | 142 |
|
74 | 143 | def gen_subscriber_id(self):
|
75 | 144 | return gen_uid()
|
@@ -175,7 +244,8 @@ async def trigger_callback(
|
175 | 244 | subscriber_id: SubscriberId,
|
176 | 245 | subscription: Subscription,
|
177 | 246 | ):
|
178 |
| - await subscription.callback(subscription, data) |
| 247 | + logger.info("trigger on %s %s", subscription.id, topic) |
| 248 | + await self._subscription_pusher.trigger(subscription, data) |
179 | 249 |
|
180 | 250 | async def callback_subscribers(
|
181 | 251 | self,
|
|
0 commit comments