Skip to content

Commit 8015b9d

Browse files
author
Shaul Kremer
committed
Add internal queues for pubsub message transmission to clients.
1 parent 8c08e11 commit 8015b9d

File tree

1 file changed

+63
-2
lines changed

1 file changed

+63
-2
lines changed

fastapi_websocket_pubsub/event_notifier.py

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import asyncio
22
import copy
3+
from collections import OrderedDict
4+
from time import monotonic_ns
35
from typing import Any, Callable, Coroutine, Dict, List, Optional, Union
46

57
from fastapi_websocket_rpc import RpcChannel
@@ -37,7 +39,65 @@ class Subscription(BaseModel):
3739

3840
# Publish event callback signature
3941
def EventCallback(subscription: Subscription, data: Any):
40-
pass
42+
...
43+
44+
45+
SUBSCRIPTION_TASK_CLEANUP_TIME = 60
46+
47+
48+
class SubscriptionPusher:
49+
def __init__(self):
50+
self._queues: Dict[Subscription, asyncio.Queue] = {}
51+
self._tasks: Dict[Subscription, asyncio.Task] = {}
52+
self._queue_flush_times: OrderedDict[Subscription, int] = OrderedDict()
53+
self._queues_lock = asyncio.Lock()
54+
self._cleanup_task = asyncio.create_task(self._cleanup_queues)
55+
56+
async def trigger(self, subscription: Subscription, data):
57+
need_create_task = False
58+
await self._queues_lock.acquire()
59+
try:
60+
if not subscription in self._queues:
61+
self._queues[subscription] = asyncio.Queue()
62+
need_create_task = True
63+
finally:
64+
self._queues_lock.release()
65+
if need_create_task:
66+
self._queue_flush_times[subscription] = monotonic_ns()
67+
self._tasks[subscription] = asyncio.create_task(
68+
self._handle_queue(subscription)
69+
)
70+
await self._queues[subscription].put(data)
71+
72+
async def _handle_queue(self, subscription: Subscription):
73+
while True:
74+
data = await self._queues[subscription].get()
75+
try:
76+
await subscription.callback(data)
77+
except Exception:
78+
logger.opt(exception=True).warning(
79+
"Unable to handle subscription {} data {}:", subscription, data
80+
)
81+
if self._queues[subscription].empty():
82+
self._queue_flush_times[subscription] = monotonic_ns()
83+
self._queue_flush_times.move_to_end(subscription)
84+
85+
async def _cleanup_queues(self):
86+
most_recent_flush_to_delete = monotonic_ns() - SUBSCRIPTION_TASK_CLEANUP_TIME
87+
await self._queues_lock.acquire()
88+
try:
89+
subscriptions_to_delete = []
90+
for subscription, last_flush in self._queue_flush_times.items():
91+
if last_flush > most_recent_flush_to_delete:
92+
# We're done
93+
break
94+
for subscription in subscriptions_to_delete:
95+
self._tasks[subscription].cancel()
96+
del self._tasks[subscription]
97+
del self._queues[subscription]
98+
del self._queue_flush_times[subscription]
99+
finally:
100+
self._queues_lock.release()
41101

42102

43103
class EventNotifier:
@@ -70,6 +130,7 @@ def __init__(self):
70130
self._on_unsubscribe_events = []
71131
# List of restriction checks to perform on every action on the channel
72132
self._channel_restrictions = []
133+
self._subscription_pusher = SubscriptionPusher()
73134

74135
def gen_subscriber_id(self):
75136
return gen_uid()
@@ -175,7 +236,7 @@ async def trigger_callback(
175236
subscriber_id: SubscriberId,
176237
subscription: Subscription,
177238
):
178-
await subscription.callback(subscription, data)
239+
await self._subscription_pusher.trigger(subscription, data)
179240

180241
async def callback_subscribers(
181242
self,

0 commit comments

Comments
 (0)