Skip to content

Commit 43f84cb

Browse files
author
Shaul Kremer
committed
Add internal queues for pubsub message transmission to clients.
1 parent 8c08e11 commit 43f84cb

File tree

1 file changed

+72
-2
lines changed

1 file changed

+72
-2
lines changed

fastapi_websocket_pubsub/event_notifier.py

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
22
import copy
3+
from collections import OrderedDict
4+
from time import monotonic_ns
35
from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
46

57
from fastapi_websocket_rpc import RpcChannel
@@ -37,7 +39,74 @@ class Subscription(BaseModel):
3739

3840
# Publish event callback signature
3941
def EventCallback(subscription: Subscription, data: Any):
40-
pass
42+
...
43+
44+
45+
SUBSCRIPTION_TASK_CLEANUP_TIME = 60
46+
47+
48+
class SubscriptionPusher:
49+
def __init__(self):
50+
self._queues: Dict[str, asyncio.Queue] = {}
51+
self._tasks: Dict[str, asyncio.Task] = {}
52+
self._queue_flush_times: OrderedDict[str, int] = OrderedDict()
53+
self._queues_lock = asyncio.Lock()
54+
self._cleanup_task = asyncio.create_task(self._cleanup_queues)
55+
56+
async def trigger(self, subscription: Subscription, data):
57+
need_create_task = False
58+
await self._queues_lock.acquire()
59+
try:
60+
if not subscription.id in self._queues:
61+
self._queues[subscription.id] = asyncio.Queue()
62+
need_create_task = True
63+
finally:
64+
self._queues_lock.release()
65+
if need_create_task:
66+
self._queue_flush_times[subscription.id] = monotonic_ns()
67+
self._tasks[subscription.id] = asyncio.create_task(
68+
self._handle_queue(subscription.id)
69+
)
70+
await self._queues[subscription].put(data)
71+
72+
async def _handle_queue(self, subscription: Subscription):
73+
while True:
74+
data = await self._queues[subscription.id].get()
75+
try:
76+
await subscription.callback(data)
77+
except Exception:
78+
logger.opt(exception=True).warning(
79+
"Unable to handle subscription {} data {}:", subscription, data
80+
)
81+
if self._queues[subscription.id].empty():
82+
self._queue_flush_times[subscription.id] = monotonic_ns()
83+
self._queue_flush_times.move_to_end(subscription.id)
84+
85+
async def _cleanup_queues(self):
86+
while True:
87+
most_recent_flush_to_delete = monotonic_ns() - SUBSCRIPTION_TASK_CLEANUP_TIME
88+
await self._queues_lock.acquire()
89+
try:
90+
# This code is safe because there are no any awaits except
91+
# for the lock, if this changes then we need to also lock
92+
# the unlocked segment of trigger and _handle_queue with a
93+
# per-subscription lock and hold it while we delete the
94+
# subscription so we don't miss any messages
95+
subscriptions_to_delete = []
96+
for subscription_id, last_flush in self._queue_flush_times.items():
97+
if last_flush > most_recent_flush_to_delete:
98+
# We're done
99+
break
100+
if self._queues[subscription_id].empty():
101+
subscriptions_to_delete.append(subscription_id)
102+
for subscription_id in subscriptions_to_delete:
103+
self._tasks[subscription_id].cancel()
104+
del self._tasks[subscription_id]
105+
del self._queues[subscription_id]
106+
del self._queue_flush_times[subscription_id]
107+
finally:
108+
self._queues_lock.release()
109+
await asyncio.sleep(1)
41110

42111

43112
class EventNotifier:
@@ -70,6 +139,7 @@ def __init__(self):
70139
self._on_unsubscribe_events = []
71140
# List of restriction checks to perform on every action on the channel
72141
self._channel_restrictions = []
142+
self._subscription_pusher = SubscriptionPusher()
73143

74144
def gen_subscriber_id(self):
75145
return gen_uid()
@@ -175,7 +245,7 @@ async def trigger_callback(
175245
subscriber_id: SubscriberId,
176246
subscription: Subscription,
177247
):
178-
await subscription.callback(subscription, data)
248+
await self._subscription_pusher.trigger(subscription, data)
179249

180250
async def callback_subscribers(
181251
self,

0 commit comments

Comments
 (0)