Skip to content

Commit a3b0302

Browse files
committed
Buffer messages to wait for reconnect
Signed-off-by: Lucas ONeil <lucasoneil@gmail.com>
1 parent f6c4c6c commit a3b0302

File tree

5 files changed

+85
-33
lines changed

5 files changed

+85
-33
lines changed

oidc-controller/api/routers/acapy_handler.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ..db.session import get_db
1212

1313
from ..core.config import settings
14-
from ..routers.socketio import sio, connections_reload
14+
from ..routers.socketio import buffered_emit, connections_reload
1515

1616
logger: structlog.typing.FilteringBoundLogger = structlog.getLogger(__name__)
1717

@@ -39,9 +39,6 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db
3939

4040
# Get the saved websocket session
4141
pid = str(auth_session.id)
42-
connections = connections_reload()
43-
sid = connections.get(pid)
44-
logger.debug(f"sid: {sid} found for pid: {pid}")
4542

4643
if webhook_body["state"] == "presentation-received":
4744
logger.info("presentation-received")
@@ -51,12 +48,10 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db
5148
if webhook_body["verified"] == "true":
5249
auth_session.proof_status = AuthSessionState.VERIFIED
5350
auth_session.presentation_exchange = webhook_body["by_format"]
54-
if sid:
55-
await sio.emit("status", {"status": "verified"}, to=sid)
51+
await buffered_emit("status", {"status": "verified"}, to_pid=pid)
5652
else:
5753
auth_session.proof_status = AuthSessionState.FAILED
58-
if sid:
59-
await sio.emit("status", {"status": "failed"}, to=sid)
54+
await buffered_emit("status", {"status": "failed"}, to_pid=pid)
6055

6156
await AuthSessionCRUD(db).patch(
6257
str(auth_session.id), AuthSessionPatch(**auth_session.model_dump())
@@ -67,8 +62,7 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db
6762
logger.info("ABANDONED")
6863
logger.info(webhook_body["error_msg"])
6964
auth_session.proof_status = AuthSessionState.ABANDONED
70-
if sid:
71-
await sio.emit("status", {"status": "abandoned"}, to=sid)
65+
await buffered_emit("status", {"status": "abandoned"}, to_pid=pid)
7266

7367
await AuthSessionCRUD(db).patch(
7468
str(auth_session.id), AuthSessionPatch(**auth_session.model_dump())
@@ -93,8 +87,7 @@ async def post_topic(request: Request, topic: str, db: Database = Depends(get_db
9387
):
9488
logger.info("EXPIRED")
9589
auth_session.proof_status = AuthSessionState.EXPIRED
96-
if sid:
97-
await sio.emit("status", {"status": "expired"}, to=sid)
90+
await buffered_emit("status", {"status": "expired"}, to_pid=pid)
9891

9992
await AuthSessionCRUD(db).patch(
10093
str(auth_session.id), AuthSessionPatch(**auth_session.model_dump())

oidc-controller/api/routers/oidc.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from ..db.session import get_db
3232

3333
# Access to the websocket
34-
from ..routers.socketio import connections_reload, sio
34+
from ..routers.socketio import buffered_emit, connections_reload
3535

3636
from ..verificationConfigs.crud import VerificationConfigCRUD
3737
from ..verificationConfigs.helpers import VariableSubstitutionError
@@ -58,8 +58,6 @@ async def poll_pres_exch_complete(pid: str, db: Database = Depends(get_db)):
5858
auth_session = await AuthSessionCRUD(db).get(pid)
5959

6060
pid = str(auth_session.id)
61-
connections = connections_reload()
62-
sid = connections.get(pid)
6361

6462
"""
6563
Check if proof is expired. But only if the proof has not been started.
@@ -75,8 +73,7 @@ async def poll_pres_exch_complete(pid: str, db: Database = Depends(get_db)):
7573
str(auth_session.id), AuthSessionPatch(**auth_session.model_dump())
7674
)
7775
# Send message through the websocket.
78-
if sid:
79-
await sio.emit("status", {"status": "expired"}, to=sid)
76+
await buffered_emit("status", {"status": "expired"}, to_pid=pid)
8077

8178
return {"proof_status": auth_session.proof_status}
8279

oidc-controller/api/routers/presentation_request.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from ..authSessions.models import AuthSession, AuthSessionState
1010

1111
from ..core.config import settings
12-
from ..routers.socketio import sio, connections_reload
12+
from ..routers.socketio import buffered_emit, connections_reload
1313
from ..routers.oidc import gen_deep_link
1414
from ..db.session import get_db
1515

@@ -49,16 +49,11 @@ async def send_connectionless_proof_req(
4949
pres_exch_id
5050
)
5151

52-
# Get the websocket session
53-
connections = connections_reload()
54-
sid = connections.get(str(auth_session.id))
55-
5652
# If the qrcode has been scanned, toggle the verified flag
5753
if auth_session.proof_status is AuthSessionState.NOT_STARTED:
5854
auth_session.proof_status = AuthSessionState.PENDING
5955
await AuthSessionCRUD(db).patch(auth_session.id, auth_session)
60-
if sid:
61-
await sio.emit("status", {"status": "pending"}, to=sid)
56+
await buffered_emit("status", {"status": "pending"}, to_pid=auth_session.id)
6257

6358
msg = auth_session.presentation_request_msg
6459

oidc-controller/api/routers/socketio.py

Lines changed: 66 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,14 @@
11
import socketio # For using websockets
22
import logging
3+
import time
34

45
logger = logging.getLogger(__name__)
56

6-
77
connections = {}
8+
message_buffers = {}
9+
buffer_timeout = 60 # Timeout in seconds
810

911
sio = socketio.AsyncServer(async_mode="asgi", cors_allowed_origins="*")
10-
1112
sio_app = socketio.ASGIApp(socketio_server=sio, socketio_path="/ws/socket.io")
1213

1314

@@ -18,18 +19,74 @@ async def connect(sid, socket):
1819

1920
@sio.event
2021
async def initialize(sid, data):
21-
global connections
22-
# Store websocket session matched to the presentation exchange id
23-
connections[data.get("pid")] = sid
22+
global connections, message_buffers
23+
pid = data.get("pid")
24+
connections[pid] = sid
25+
# Initialize buffer if it doesn't exist
26+
if pid not in message_buffers:
27+
message_buffers[pid] = []
2428

2529

2630
@sio.event
2731
async def disconnect(sid):
28-
global connections
32+
global connections, message_buffers
2933
logger.info(f">>> disconnect : sid={sid}")
30-
# Remove websocket session from the store
31-
if len(connections) > 0:
32-
connections = {k: v for k, v in connections.items() if v != sid}
34+
# Find the pid associated with the sid
35+
pid = next((k for k, v in connections.items() if v == sid), None)
36+
if pid:
37+
# Remove pid from connections
38+
del connections[pid]
39+
40+
41+
async def buffered_emit(event, data, to_pid=None):
42+
global connections, message_buffers
43+
44+
connections = connections_reload()
45+
sid = connections.get(to_pid)
46+
logger.debug(f"sid: {sid} found for pid: {to_pid}")
47+
48+
if sid:
49+
try:
50+
await sio.emit(event, data, room=sid)
51+
except:
52+
# If send fails, buffer the message
53+
buffer_message(to_pid, event, data)
54+
else:
55+
# Buffer the message if the target is not connected
56+
buffer_message(to_pid, event, data)
57+
58+
59+
def buffer_message(pid, event, data):
60+
global message_buffers
61+
current_time = time.time()
62+
if pid not in message_buffers:
63+
message_buffers[pid] = []
64+
# Add message with timestamp and event name
65+
message_buffers[pid].append((event, data, current_time))
66+
# Clean up old messages
67+
message_buffers[pid] = [
68+
(msg_event, msg_data, timestamp)
69+
for msg_event, msg_data, timestamp in message_buffers[pid]
70+
if current_time - timestamp <= buffer_timeout
71+
]
72+
73+
74+
@sio.event
75+
async def fetch_buffered_messages(sid, pid):
76+
global message_buffers
77+
current_time = time.time()
78+
if pid in message_buffers:
79+
# Filter messages that are still valid (i.e., within the buffer_timeout)
80+
valid_messages = [
81+
(msg_event, msg_data, timestamp)
82+
for msg_event, msg_data, timestamp in message_buffers[pid]
83+
if current_time - timestamp <= buffer_timeout
84+
]
85+
# Emit each valid message
86+
for event, data, _ in valid_messages:
87+
await sio.emit(event, data, room=sid)
88+
# Reassign the valid_messages back to message_buffers[pid] to clean up old messages
89+
message_buffers[pid] = valid_messages
3390

3491

3592
def connections_reload():

oidc-controller/api/templates/verified_credentials.html

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,14 @@ <h1 class="mb-3 fw-bolder fs-1">Continue with:</h1>
112112
>
113113
DEBUG Disconnect Web Socket
114114
</button>
115+
116+
<button
117+
class="btn btn-primary mt-4"
118+
v-on:click="socket.connect()"
119+
title="Reconnect Websocket"
120+
>
121+
DEBUG Reconnect Web Socket
122+
</button>
115123
</div>
116124

117125
<hr v-if="mobileDevice" />
@@ -383,6 +391,8 @@ <h5 v-if="state.showScanned" class="fw-bolder mb-3">
383391
`Socket connecting. SID: ${this.socket.id}. PID: {{pid}}. Recovered? ${this.socket.recovered} `
384392
);
385393
this.socket.emit("initialize", { pid: "{{pid}}" });
394+
// Emit the `fetch_buffered_messages` event with `pid` as a string using Jinja templating
395+
this.socket.emit('fetch_buffered_messages', '{{ pid }}');
386396
});
387397

388398
this.socket.on("connect_error", (error) => {

0 commit comments

Comments
 (0)