14
14
15
15
from __future__ import annotations
16
16
17
+ import dataclasses
17
18
import json
18
19
import logging
19
20
from pathlib import Path
25
26
import uuid
26
27
27
28
try :
28
- from a2a .client import A2AClient
29
+ from a2a .client import Client as A2AClient
30
+ from a2a .client import ClientEvent as A2AClientEvent
29
31
from a2a .client .card_resolver import A2ACardResolver
32
+ from a2a .client .client import ClientConfig as A2AClientConfig
33
+ from a2a .client .client_factory import ClientFactory as A2AClientFactory
34
+ from a2a .client .errors import A2AClientError
30
35
from a2a .types import AgentCard
31
36
from a2a .types import Message as A2AMessage
32
- from a2a .types import MessageSendParams as A2AMessageSendParams
33
37
from a2a .types import Part as A2APart
34
38
from a2a .types import Role
35
- from a2a .types import SendMessageRequest
36
- from a2a .types import SendMessageSuccessResponse
37
- from a2a .types import Task as A2ATask
39
+ from a2a .types import TransportProtocol as A2ATransport
38
40
except ImportError as e :
39
41
import sys
40
42
@@ -125,6 +127,7 @@ def __init__(
125
127
timeout : float = DEFAULT_TIMEOUT ,
126
128
genai_part_converter : GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part ,
127
129
a2a_part_converter : A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part ,
130
+ a2a_client_factory : Optional [A2AClientFactory ] = None ,
128
131
** kwargs : Any ,
129
132
) -> None :
130
133
"""Initialize RemoteA2aAgent.
@@ -133,8 +136,11 @@ def __init__(
133
136
name: Agent name (must be unique identifier)
134
137
agent_card: AgentCard object, URL string, or file path string
135
138
description: Agent description (auto-populated from card if empty)
136
- httpx_client: Optional shared HTTP client (will create own if not provided)
139
+ httpx_client: Optional shared HTTP client (will create own if not
140
+ provided) [deprecated] Use a2a_client_factory instead.
137
141
timeout: HTTP timeout in seconds
142
+ a2a_client_factory: Optional A2AClientFactory object (will create own if
143
+ not provided)
138
144
**kwargs: Additional arguments passed to BaseAgent
139
145
140
146
Raises:
@@ -148,14 +154,18 @@ def __init__(
148
154
149
155
self ._agent_card : Optional [AgentCard ] = None
150
156
self ._agent_card_source : Optional [str ] = None
151
- self ._rpc_url : Optional [str ] = None
152
157
self ._a2a_client : Optional [A2AClient ] = None
158
+ # This is stored to support backward compatible usage of class.
159
+ # In future, the client is expected to be present in the factory.
153
160
self ._httpx_client = httpx_client
154
- self ._httpx_client_needs_cleanup = httpx_client is None
161
+ if a2a_client_factory and a2a_client_factory ._config .httpx_client :
162
+ self ._httpx_client = a2a_client_factory ._config .httpx_client
163
+ self ._httpx_client_needs_cleanup = self ._httpx_client is None
155
164
self ._timeout = timeout
156
165
self ._is_resolved = False
157
166
self ._genai_part_converter = genai_part_converter
158
167
self ._a2a_part_converter = a2a_part_converter
168
+ self ._a2a_client_factory : Optional [A2AClientFactory ] = a2a_client_factory
159
169
160
170
# Validate and store agent card reference
161
171
if isinstance (agent_card , AgentCard ):
@@ -177,6 +187,21 @@ async def _ensure_httpx_client(self) -> httpx.AsyncClient:
177
187
timeout = httpx .Timeout (timeout = self ._timeout )
178
188
)
179
189
self ._httpx_client_needs_cleanup = True
190
+ if self ._a2a_client_factory :
191
+ self ._a2a_client_factory = A2AClientFactory (
192
+ config = dataclasses .replace (
193
+ self ._a2a_client_factory ._config ,
194
+ httpx_client = self ._httpx_client ,
195
+ )
196
+ )
197
+ if not self ._a2a_client_factory :
198
+ client_config = A2AClientConfig (
199
+ httpx_client = self ._httpx_client ,
200
+ streaming = False ,
201
+ polling = False ,
202
+ supported_transports = [A2ATransport .jsonrpc ],
203
+ )
204
+ self ._a2a_client_factory = A2AClientFactory (config = client_config )
180
205
return self ._httpx_client
181
206
182
207
async def _resolve_agent_card_from_url (self , url : str ) -> AgentCard :
@@ -251,32 +276,29 @@ async def _validate_agent_card(self, agent_card: AgentCard) -> None:
251
276
252
277
async def _ensure_resolved (self ) -> None :
253
278
"""Ensures agent card is resolved, RPC URL is determined, and A2A client is initialized."""
254
- if self ._is_resolved :
279
+ if self ._is_resolved and self . _a2a_client :
255
280
return
256
281
257
282
try :
258
- # Resolve agent card if needed
259
283
if not self ._agent_card :
260
- self ._agent_card = await self ._resolve_agent_card ()
261
284
262
- # Validate agent card
263
- await self ._validate_agent_card (self ._agent_card )
285
+ # Resolve agent card if needed
286
+ if not self ._agent_card :
287
+ self ._agent_card = await self ._resolve_agent_card ()
264
288
265
- # Set RPC URL
266
- self . _rpc_url = str (self ._agent_card . url )
289
+ # Validate agent card
290
+ await self . _validate_agent_card (self ._agent_card )
267
291
268
- # Update description if empty
269
- if not self .description and self ._agent_card .description :
270
- self .description = self ._agent_card .description
292
+ # Update description if empty
293
+ if not self .description and self ._agent_card .description :
294
+ self .description = self ._agent_card .description
271
295
272
296
# Initialize A2A client
273
297
if not self ._a2a_client :
274
- httpx_client = await self ._ensure_httpx_client ()
275
- self ._a2a_client = A2AClient (
276
- httpx_client = httpx_client ,
277
- agent_card = self ._agent_card ,
278
- url = self ._rpc_url ,
279
- )
298
+ await self ._ensure_httpx_client ()
299
+ # This should be assured via ensure_httpx_client
300
+ if self ._a2a_client_factory :
301
+ self ._a2a_client = self ._a2a_client_factory .create (self ._agent_card )
280
302
281
303
self ._is_resolved = True
282
304
logger .info ("Successfully resolved remote A2A agent: %s" , self .name )
@@ -289,7 +311,7 @@ async def _ensure_resolved(self) -> None:
289
311
290
312
def _create_a2a_request_for_user_function_response (
291
313
self , ctx : InvocationContext
292
- ) -> Optional [SendMessageRequest ]:
314
+ ) -> Optional [A2AMessage ]:
293
315
"""Create A2A request for user function response if applicable.
294
316
295
317
Args:
@@ -323,12 +345,7 @@ def _create_a2a_request_for_user_function_response(
323
345
else None
324
346
)
325
347
326
- return SendMessageRequest (
327
- id = str (uuid .uuid4 ()),
328
- params = A2AMessageSendParams (
329
- message = a2a_message ,
330
- ),
331
- )
348
+ return a2a_message
332
349
333
350
def _construct_message_parts_from_session (
334
351
self , ctx : InvocationContext
@@ -371,7 +388,7 @@ def _construct_message_parts_from_session(
371
388
return message_parts [::- 1 ], context_id
372
389
373
390
async def _handle_a2a_response (
374
- self , a2a_response : Any , ctx : InvocationContext
391
+ self , a2a_response : A2AClientEvent | A2AMessage , ctx : InvocationContext
375
392
) -> Event :
376
393
"""Handle A2A response and convert to Event.
377
394
@@ -383,63 +400,36 @@ async def _handle_a2a_response(
383
400
Event object representing the response
384
401
"""
385
402
try :
386
- if isinstance (a2a_response .root , SendMessageSuccessResponse ):
387
- if a2a_response .root .result :
388
- if isinstance (a2a_response .root .result , A2ATask ):
389
- event = convert_a2a_task_to_event (
390
- a2a_response .root .result ,
391
- self .name ,
392
- ctx ,
393
- self ._a2a_part_converter ,
394
- )
395
- event .custom_metadata = event .custom_metadata or {}
396
- event .custom_metadata [A2A_METADATA_PREFIX + "task_id" ] = (
397
- a2a_response .root .result .id
398
- )
403
+ if isinstance (a2a_response , tuple ):
404
+ # ClientEvent is a tuple of the absolute Task state and the last update.
405
+ # We only need the Task state.
406
+ task = a2a_response [0 ]
407
+ event = convert_a2a_task_to_event (task , self .name , ctx )
408
+ event .custom_metadata = event .custom_metadata or {}
409
+ event .custom_metadata [A2A_METADATA_PREFIX + "task_id" ] = task .id
410
+ if task .context_id :
411
+ event .custom_metadata [A2A_METADATA_PREFIX + "context_id" ] = (
412
+ task .context_id
413
+ )
399
414
400
- else :
401
- event = convert_a2a_message_to_event (
402
- a2a_response .root .result ,
403
- self .name ,
404
- ctx ,
405
- self ._a2a_part_converter ,
406
- )
407
- event .custom_metadata = event .custom_metadata or {}
408
- if a2a_response .root .result .task_id :
409
- event .custom_metadata [A2A_METADATA_PREFIX + "task_id" ] = (
410
- a2a_response .root .result .task_id
411
- )
412
-
413
- if a2a_response .root .result .context_id :
414
- event .custom_metadata [A2A_METADATA_PREFIX + "context_id" ] = (
415
- a2a_response .root .result .context_id
416
- )
415
+ # Otherwise, it's a regular A2AMessage.
416
+ elif isinstance (a2a_response , A2AMessage ):
417
+ event = convert_a2a_message_to_event (a2a_response , self .name , ctx )
418
+ event .custom_metadata = event .custom_metadata or {}
417
419
418
- else :
419
- logger .warning ("A2A response has no result: %s" , a2a_response .root )
420
- event = Event (
421
- author = self .name ,
422
- invocation_id = ctx .invocation_id ,
423
- branch = ctx .branch ,
420
+ if a2a_response .context_id :
421
+ event .custom_metadata [A2A_METADATA_PREFIX + "context_id" ] = (
422
+ a2a_response .context_id
424
423
)
425
424
else :
426
- # Handle error response
427
- error_response = a2a_response .root
428
- logger .error (
429
- "A2A request failed with error: %s, data: %s" ,
430
- error_response .error .message ,
431
- error_response .error .data ,
432
- )
433
425
event = Event (
434
426
author = self .name ,
435
- error_message = error_response .error .message ,
436
- error_code = str (error_response .error .code ),
427
+ error_message = "Unknown A2A response type" ,
437
428
invocation_id = ctx .invocation_id ,
438
429
branch = ctx .branch ,
439
430
)
440
-
441
431
return event
442
- except Exception as e :
432
+ except A2AClientError as e :
443
433
logger .error ("Failed to handle A2A response: %s" , e )
444
434
return Event (
445
435
author = self .name ,
@@ -482,36 +472,33 @@ async def _run_async_impl(
482
472
)
483
473
return
484
474
485
- a2a_request = SendMessageRequest (
486
- id = str (uuid .uuid4 ()),
487
- params = A2AMessageSendParams (
488
- message = A2AMessage (
489
- message_id = str (uuid .uuid4 ()),
490
- parts = message_parts ,
491
- role = "user" ,
492
- context_id = context_id ,
493
- )
494
- ),
475
+ a2a_request = A2AMessage (
476
+ message_id = str (uuid .uuid4 ()),
477
+ parts = message_parts ,
478
+ role = "user" ,
479
+ context_id = context_id ,
495
480
)
496
481
497
482
logger .debug (build_a2a_request_log (a2a_request ))
498
483
499
484
try :
500
- a2a_response = await self ._a2a_client .send_message (request = a2a_request )
501
- logger .debug (build_a2a_response_log (a2a_response ))
485
+ async for a2a_response in self ._a2a_client .send_message (
486
+ request = a2a_request
487
+ ):
488
+ logger .debug (build_a2a_response_log (a2a_response ))
502
489
503
- event = await self ._handle_a2a_response (a2a_response , ctx )
490
+ event = await self ._handle_a2a_response (a2a_response , ctx )
504
491
505
- # Add metadata about the request and response
506
- event .custom_metadata = event .custom_metadata or {}
507
- event .custom_metadata [A2A_METADATA_PREFIX + "request" ] = (
508
- a2a_request .model_dump (exclude_none = True , by_alias = True )
509
- )
510
- event .custom_metadata [A2A_METADATA_PREFIX + "response" ] = (
511
- a2a_response . root .model_dump (exclude_none = True , by_alias = True )
512
- )
492
+ # Add metadata about the request and response
493
+ event .custom_metadata = event .custom_metadata or {}
494
+ event .custom_metadata [A2A_METADATA_PREFIX + "request" ] = (
495
+ a2a_request .model_dump (exclude_none = True , by_alias = True )
496
+ )
497
+ event .custom_metadata [A2A_METADATA_PREFIX + "response" ] = (
498
+ a2a_response .model_dump (exclude_none = True , by_alias = True )
499
+ )
513
500
514
- yield event
501
+ yield event
515
502
516
503
except Exception as e :
517
504
error_message = f"A2A request failed: { e } "
0 commit comments