|
1 | 1 | import asyncio
|
2 | 2 | import copy
|
| 3 | +from collections import OrderedDict |
| 4 | +from time import monotonic_ns |
3 | 5 | from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
|
4 | 6 |
|
5 | 7 | from fastapi_websocket_rpc import RpcChannel
|
@@ -37,7 +39,65 @@ class Subscription(BaseModel):
|
37 | 39 |
|
38 | 40 | # Publish event callback signature
|
39 | 41 | def EventCallback(subscription: Subscription, data: Any):
|
40 |
| - pass |
| 42 | + ... |
| 43 | + |
| 44 | + |
| 45 | +SUBSCRIPTION_TASK_CLEANUP_TIME = 60 |
| 46 | + |
| 47 | + |
| 48 | +class SubscriptionPusher: |
| 49 | + def __init__(self): |
| 50 | + self._queues: Dict[Subscription, asyncio.Queue] = {} |
| 51 | + self._tasks: Dict[Subscription, asyncio.Task] = {} |
| 52 | + self._queue_flush_times: OrderedDict[Subscription, int] = OrderedDict() |
| 53 | + self._queues_lock = asyncio.Lock() |
| 54 | + self._cleanup_task = asyncio.create_task(self._cleanup_queues) |
| 55 | + |
| 56 | + async def trigger(self, subscription: Subscription, data): |
| 57 | + need_create_task = False |
| 58 | + await self._queues_lock.acquire() |
| 59 | + try: |
| 60 | + if not subscription in self._queues: |
| 61 | + self._queues[subscription] = asyncio.Queue() |
| 62 | + need_create_task = True |
| 63 | + finally: |
| 64 | + self._queues_lock.release() |
| 65 | + if need_create_task: |
| 66 | + self._queue_flush_times[subscription] = monotonic_ns() |
| 67 | + self._tasks[subscription] = asyncio.create_task( |
| 68 | + self._handle_queue(subscription) |
| 69 | + ) |
| 70 | + await self._queues[subscription].put(data) |
| 71 | + |
| 72 | + async def _handle_queue(self, subscription: Subscription): |
| 73 | + while True: |
| 74 | + data = await self._queues[subscription].get() |
| 75 | + try: |
| 76 | + await subscription.callback(data) |
| 77 | + except Exception: |
| 78 | + logger.opt(exception=True).warning( |
| 79 | + "Unable to handle subscription {} data {}:", subscription, data |
| 80 | + ) |
| 81 | + if self._queues[subscription].empty(): |
| 82 | + self._queue_flush_times[subscription] = monotonic_ns() |
| 83 | + self._queue_flush_times.move_to_end(subscription) |
| 84 | + |
| 85 | + async def _cleanup_queues(self): |
| 86 | + most_recent_flush_to_delete = monotonic_ns() - SUBSCRIPTION_TASK_CLEANUP_TIME |
| 87 | + await self._queues_lock.acquire() |
| 88 | + try: |
| 89 | + subscriptions_to_delete = [] |
| 90 | + for subscription, last_flush in self._queue_flush_times.items(): |
| 91 | + if last_flush > most_recent_flush_to_delete: |
| 92 | + # We're done |
| 93 | + break |
| 94 | + for subscription in subscriptions_to_delete: |
| 95 | + self._tasks[subscription].cancel() |
| 96 | + del self._tasks[subscription] |
| 97 | + del self._queues[subscription] |
| 98 | + del self._queue_flush_times[subscription] |
| 99 | + finally: |
| 100 | + self._queues_lock.release() |
41 | 101 |
|
42 | 102 |
|
43 | 103 | class EventNotifier:
|
@@ -70,6 +130,7 @@ def __init__(self):
|
70 | 130 | self._on_unsubscribe_events = []
|
71 | 131 | # List of restriction checks to perform on every action on the channel
|
72 | 132 | self._channel_restrictions = []
|
| 133 | + self._subscription_pusher = SubscriptionPusher() |
73 | 134 |
|
74 | 135 | def gen_subscriber_id(self):
|
75 | 136 | return gen_uid()
|
@@ -175,7 +236,7 @@ async def trigger_callback(
|
175 | 236 | subscriber_id: SubscriberId,
|
176 | 237 | subscription: Subscription,
|
177 | 238 | ):
|
178 |
| - await subscription.callback(subscription, data) |
| 239 | + await self._subscription_pusher.trigger(subscription, data) |
179 | 240 |
|
180 | 241 | async def callback_subscribers(
|
181 | 242 | self,
|
|
0 commit comments