Skip to content

Commit df8ae7e

Browse files
author
Shaul Kremer
committed
Add internal queues for pubsub message transmission to clients.
1 parent 8c08e11 commit df8ae7e

File tree

1 file changed

+72
-2
lines changed

1 file changed

+72
-2
lines changed

fastapi_websocket_pubsub/event_notifier.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
22
import copy
3+
from collections import OrderedDict
4+
from time import monotonic_ns
35
from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
46

57
from fastapi_websocket_rpc import RpcChannel
@@ -37,7 +39,73 @@ class Subscription(BaseModel):
3739

3840
# Publish event callback signature
3941
def EventCallback(subscription: Subscription, data: Any):
40-
pass
42+
...
43+
44+
45+
SUBSCRIPTION_TASK_CLEANUP_TIME = 60
46+
47+
48+
class SubscriptionPusher:
49+
def __init__(self):
50+
self._queues: Dict[str, asyncio.Queue] = {}
51+
self._tasks: Dict[str, asyncio.Task] = {}
52+
self._queue_flush_times: OrderedDict[str, int] = OrderedDict()
53+
self._queues_lock = asyncio.Lock()
54+
self._cleanup_task = asyncio.create_task(self._cleanup_queues())
55+
56+
async def trigger(self, subscription: Subscription, data):
57+
need_create_task = False
58+
await self._queues_lock.acquire()
59+
try:
60+
if not subscription.id in self._queues:
61+
self._queues[subscription.id] = asyncio.Queue()
62+
need_create_task = True
63+
finally:
64+
self._queues_lock.release()
65+
if need_create_task:
66+
self._queue_flush_times[subscription.id] = monotonic_ns()
67+
self._tasks[subscription.id] = asyncio.create_task(
68+
self._handle_queue(subscription)
69+
)
70+
await self._queues[subscription.id].put(data)
71+
72+
async def _handle_queue(self, subscription: Subscription):
73+
while True:
74+
try:
75+
data = await self._queues[subscription.id].get()
76+
try:
77+
await subscription.callback(subscription, data)
78+
except Exception:
79+
logger.exception(
80+
"Unable to handle subscription %r data %r:", subscription, data
81+
)
82+
if self._queues[subscription.id].empty():
83+
self._queue_flush_times[subscription.id] = monotonic_ns()
84+
self._queue_flush_times.move_to_end(subscription.id)
85+
except Exception:
86+
logger.exception("Handle failed:")
87+
88+
async def _cleanup_queues(self):
89+
while True:
90+
most_recent_flush_to_delete = monotonic_ns() - SUBSCRIPTION_TASK_CLEANUP_TIME
91+
await self._queues_lock.acquire()
92+
try:
93+
# This code is safe because there are no any awaits except for the lock, if this changes then we need to more locks
94+
subscriptions_to_delete = []
95+
for subscription_id, last_flush in self._queue_flush_times.items():
96+
if last_flush > most_recent_flush_to_delete:
97+
# We're done
98+
break
99+
if self._queues[subscription_id].empty():
100+
subscriptions_to_delete.append(subscription_id)
101+
for subscription_id in subscriptions_to_delete:
102+
self._tasks[subscription_id].cancel()
103+
del self._tasks[subscription_id]
104+
del self._queues[subscription_id]
105+
del self._queue_flush_times[subscription_id]
106+
finally:
107+
self._queues_lock.release()
108+
await asyncio.sleep(1)
41109

42110

43111
class EventNotifier:
@@ -70,6 +138,7 @@ def __init__(self):
70138
self._on_unsubscribe_events = []
71139
# List of restriction checks to perform on every action on the channel
72140
self._channel_restrictions = []
141+
self._subscription_pusher = SubscriptionPusher()
73142

74143
def gen_subscriber_id(self):
75144
return gen_uid()
@@ -175,7 +244,8 @@ async def trigger_callback(
175244
subscriber_id: SubscriberId,
176245
subscription: Subscription,
177246
):
178-
await subscription.callback(subscription, data)
247+
logger.info("trigger on %s %s", subscription.id, topic)
248+
await self._subscription_pusher.trigger(subscription, data)
179249

180250
async def callback_subscribers(
181251
self,

0 commit comments

Comments
 (0)