Skip to content

Commit fa55354

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Update to ADK + A2A Remote Client to use A2A SDK ClientFactory
Use the A2A Python SDK for client support for A2A Remote clients. This enables A2A based agents that use gRPC or RESTful interfaces, as well as the jsonrpc support. This also simplifies creation of clients and provides simpler mechanisms to inject credentials and observability into the remote agent interactions. PiperOrigin-RevId: 804711466
1 parent 64f11a6 commit fa55354

File tree

3 files changed

+723
-264
lines changed

3 files changed

+723
-264
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ dev = [
8484

8585
a2a = [
8686
# go/keep-sorted start
87-
"a2a-sdk>=0.3.0,<0.4.0;python_version>='3.10'",
87+
"a2a-sdk>=0.3.4,<0.4.0;python_version>='3.10'",
8888
# go/keep-sorted end
8989
]
9090

src/google/adk/agents/remote_a2a_agent.py

Lines changed: 88 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import dataclasses
1718
import json
1819
import logging
1920
from pathlib import Path
@@ -25,16 +26,17 @@
2526
import uuid
2627

2728
try:
28-
from a2a.client import A2AClient
29+
from a2a.client import Client as A2AClient
30+
from a2a.client import ClientEvent as A2AClientEvent
2931
from a2a.client.card_resolver import A2ACardResolver
32+
from a2a.client.client import ClientConfig as A2AClientConfig
33+
from a2a.client.client_factory import ClientFactory as A2AClientFactory
34+
from a2a.client.errors import A2AClientError
3035
from a2a.types import AgentCard
3136
from a2a.types import Message as A2AMessage
32-
from a2a.types import MessageSendParams as A2AMessageSendParams
3337
from a2a.types import Part as A2APart
3438
from a2a.types import Role
35-
from a2a.types import SendMessageRequest
36-
from a2a.types import SendMessageSuccessResponse
37-
from a2a.types import Task as A2ATask
39+
from a2a.types import TransportProtocol as A2ATransport
3840
except ImportError as e:
3941
import sys
4042

@@ -125,6 +127,7 @@ def __init__(
125127
timeout: float = DEFAULT_TIMEOUT,
126128
genai_part_converter: GenAIPartToA2APartConverter = convert_genai_part_to_a2a_part,
127129
a2a_part_converter: A2APartToGenAIPartConverter = convert_a2a_part_to_genai_part,
130+
a2a_client_factory: Optional[A2AClientFactory] = None,
128131
**kwargs: Any,
129132
) -> None:
130133
"""Initialize RemoteA2aAgent.
@@ -133,8 +136,11 @@ def __init__(
133136
name: Agent name (must be unique identifier)
134137
agent_card: AgentCard object, URL string, or file path string
135138
description: Agent description (auto-populated from card if empty)
136-
httpx_client: Optional shared HTTP client (will create own if not provided)
139+
httpx_client: Optional shared HTTP client (will create own if not
140+
provided) [deprecated] Use a2a_client_factory instead.
137141
timeout: HTTP timeout in seconds
142+
a2a_client_factory: Optional A2AClientFactory object (will create own if
143+
not provided)
138144
**kwargs: Additional arguments passed to BaseAgent
139145
140146
Raises:
@@ -148,14 +154,18 @@ def __init__(
148154

149155
self._agent_card: Optional[AgentCard] = None
150156
self._agent_card_source: Optional[str] = None
151-
self._rpc_url: Optional[str] = None
152157
self._a2a_client: Optional[A2AClient] = None
158+
# This is stored to support backward compatible usage of class.
159+
# In future, the client is expected to be present in the factory.
153160
self._httpx_client = httpx_client
154-
self._httpx_client_needs_cleanup = httpx_client is None
161+
if a2a_client_factory and a2a_client_factory._config.httpx_client:
162+
self._httpx_client = a2a_client_factory._config.httpx_client
163+
self._httpx_client_needs_cleanup = self._httpx_client is None
155164
self._timeout = timeout
156165
self._is_resolved = False
157166
self._genai_part_converter = genai_part_converter
158167
self._a2a_part_converter = a2a_part_converter
168+
self._a2a_client_factory: Optional[A2AClientFactory] = a2a_client_factory
159169

160170
# Validate and store agent card reference
161171
if isinstance(agent_card, AgentCard):
@@ -177,6 +187,21 @@ async def _ensure_httpx_client(self) -> httpx.AsyncClient:
177187
timeout=httpx.Timeout(timeout=self._timeout)
178188
)
179189
self._httpx_client_needs_cleanup = True
190+
if self._a2a_client_factory:
191+
self._a2a_client_factory = A2AClientFactory(
192+
config=dataclasses.replace(
193+
self._a2a_client_factory._config,
194+
httpx_client=self._httpx_client,
195+
)
196+
)
197+
if not self._a2a_client_factory:
198+
client_config = A2AClientConfig(
199+
httpx_client=self._httpx_client,
200+
streaming=False,
201+
polling=False,
202+
supported_transports=[A2ATransport.jsonrpc],
203+
)
204+
self._a2a_client_factory = A2AClientFactory(config=client_config)
180205
return self._httpx_client
181206

182207
async def _resolve_agent_card_from_url(self, url: str) -> AgentCard:
@@ -251,32 +276,29 @@ async def _validate_agent_card(self, agent_card: AgentCard) -> None:
251276

252277
async def _ensure_resolved(self) -> None:
253278
"""Ensures agent card is resolved, RPC URL is determined, and A2A client is initialized."""
254-
if self._is_resolved:
279+
if self._is_resolved and self._a2a_client:
255280
return
256281

257282
try:
258-
# Resolve agent card if needed
259283
if not self._agent_card:
260-
self._agent_card = await self._resolve_agent_card()
261284

262-
# Validate agent card
263-
await self._validate_agent_card(self._agent_card)
285+
# Resolve agent card if needed
286+
if not self._agent_card:
287+
self._agent_card = await self._resolve_agent_card()
264288

265-
# Set RPC URL
266-
self._rpc_url = str(self._agent_card.url)
289+
# Validate agent card
290+
await self._validate_agent_card(self._agent_card)
267291

268-
# Update description if empty
269-
if not self.description and self._agent_card.description:
270-
self.description = self._agent_card.description
292+
# Update description if empty
293+
if not self.description and self._agent_card.description:
294+
self.description = self._agent_card.description
271295

272296
# Initialize A2A client
273297
if not self._a2a_client:
274-
httpx_client = await self._ensure_httpx_client()
275-
self._a2a_client = A2AClient(
276-
httpx_client=httpx_client,
277-
agent_card=self._agent_card,
278-
url=self._rpc_url,
279-
)
298+
await self._ensure_httpx_client()
299+
# This should be assured via ensure_httpx_client
300+
if self._a2a_client_factory:
301+
self._a2a_client = self._a2a_client_factory.create(self._agent_card)
280302

281303
self._is_resolved = True
282304
logger.info("Successfully resolved remote A2A agent: %s", self.name)
@@ -289,7 +311,7 @@ async def _ensure_resolved(self) -> None:
289311

290312
def _create_a2a_request_for_user_function_response(
291313
self, ctx: InvocationContext
292-
) -> Optional[SendMessageRequest]:
314+
) -> Optional[A2AMessage]:
293315
"""Create A2A request for user function response if applicable.
294316
295317
Args:
@@ -323,12 +345,7 @@ def _create_a2a_request_for_user_function_response(
323345
else None
324346
)
325347

326-
return SendMessageRequest(
327-
id=str(uuid.uuid4()),
328-
params=A2AMessageSendParams(
329-
message=a2a_message,
330-
),
331-
)
348+
return a2a_message
332349

333350
def _construct_message_parts_from_session(
334351
self, ctx: InvocationContext
@@ -371,7 +388,7 @@ def _construct_message_parts_from_session(
371388
return message_parts[::-1], context_id
372389

373390
async def _handle_a2a_response(
374-
self, a2a_response: Any, ctx: InvocationContext
391+
self, a2a_response: A2AClientEvent | A2AMessage, ctx: InvocationContext
375392
) -> Event:
376393
"""Handle A2A response and convert to Event.
377394
@@ -383,63 +400,36 @@ async def _handle_a2a_response(
383400
Event object representing the response
384401
"""
385402
try:
386-
if isinstance(a2a_response.root, SendMessageSuccessResponse):
387-
if a2a_response.root.result:
388-
if isinstance(a2a_response.root.result, A2ATask):
389-
event = convert_a2a_task_to_event(
390-
a2a_response.root.result,
391-
self.name,
392-
ctx,
393-
self._a2a_part_converter,
394-
)
395-
event.custom_metadata = event.custom_metadata or {}
396-
event.custom_metadata[A2A_METADATA_PREFIX + "task_id"] = (
397-
a2a_response.root.result.id
398-
)
403+
if isinstance(a2a_response, tuple):
404+
# ClientEvent is a tuple of the absolute Task state and the last update.
405+
# We only need the Task state.
406+
task = a2a_response[0]
407+
event = convert_a2a_task_to_event(task, self.name, ctx)
408+
event.custom_metadata = event.custom_metadata or {}
409+
event.custom_metadata[A2A_METADATA_PREFIX + "task_id"] = task.id
410+
if task.context_id:
411+
event.custom_metadata[A2A_METADATA_PREFIX + "context_id"] = (
412+
task.context_id
413+
)
399414

400-
else:
401-
event = convert_a2a_message_to_event(
402-
a2a_response.root.result,
403-
self.name,
404-
ctx,
405-
self._a2a_part_converter,
406-
)
407-
event.custom_metadata = event.custom_metadata or {}
408-
if a2a_response.root.result.task_id:
409-
event.custom_metadata[A2A_METADATA_PREFIX + "task_id"] = (
410-
a2a_response.root.result.task_id
411-
)
412-
413-
if a2a_response.root.result.context_id:
414-
event.custom_metadata[A2A_METADATA_PREFIX + "context_id"] = (
415-
a2a_response.root.result.context_id
416-
)
415+
# Otherwise, it's a regular A2AMessage.
416+
elif isinstance(a2a_response, A2AMessage):
417+
event = convert_a2a_message_to_event(a2a_response, self.name, ctx)
418+
event.custom_metadata = event.custom_metadata or {}
417419

418-
else:
419-
logger.warning("A2A response has no result: %s", a2a_response.root)
420-
event = Event(
421-
author=self.name,
422-
invocation_id=ctx.invocation_id,
423-
branch=ctx.branch,
420+
if a2a_response.context_id:
421+
event.custom_metadata[A2A_METADATA_PREFIX + "context_id"] = (
422+
a2a_response.context_id
424423
)
425424
else:
426-
# Handle error response
427-
error_response = a2a_response.root
428-
logger.error(
429-
"A2A request failed with error: %s, data: %s",
430-
error_response.error.message,
431-
error_response.error.data,
432-
)
433425
event = Event(
434426
author=self.name,
435-
error_message=error_response.error.message,
436-
error_code=str(error_response.error.code),
427+
error_message="Unknown A2A response type",
437428
invocation_id=ctx.invocation_id,
438429
branch=ctx.branch,
439430
)
440-
441431
return event
442-
except Exception as e:
432+
except A2AClientError as e:
443433
logger.error("Failed to handle A2A response: %s", e)
444434
return Event(
445435
author=self.name,
@@ -482,36 +472,33 @@ async def _run_async_impl(
482472
)
483473
return
484474

485-
a2a_request = SendMessageRequest(
486-
id=str(uuid.uuid4()),
487-
params=A2AMessageSendParams(
488-
message=A2AMessage(
489-
message_id=str(uuid.uuid4()),
490-
parts=message_parts,
491-
role="user",
492-
context_id=context_id,
493-
)
494-
),
475+
a2a_request = A2AMessage(
476+
message_id=str(uuid.uuid4()),
477+
parts=message_parts,
478+
role="user",
479+
context_id=context_id,
495480
)
496481

497482
logger.debug(build_a2a_request_log(a2a_request))
498483

499484
try:
500-
a2a_response = await self._a2a_client.send_message(request=a2a_request)
501-
logger.debug(build_a2a_response_log(a2a_response))
485+
async for a2a_response in self._a2a_client.send_message(
486+
request=a2a_request
487+
):
488+
logger.debug(build_a2a_response_log(a2a_response))
502489

503-
event = await self._handle_a2a_response(a2a_response, ctx)
490+
event = await self._handle_a2a_response(a2a_response, ctx)
504491

505-
# Add metadata about the request and response
506-
event.custom_metadata = event.custom_metadata or {}
507-
event.custom_metadata[A2A_METADATA_PREFIX + "request"] = (
508-
a2a_request.model_dump(exclude_none=True, by_alias=True)
509-
)
510-
event.custom_metadata[A2A_METADATA_PREFIX + "response"] = (
511-
a2a_response.root.model_dump(exclude_none=True, by_alias=True)
512-
)
492+
# Add metadata about the request and response
493+
event.custom_metadata = event.custom_metadata or {}
494+
event.custom_metadata[A2A_METADATA_PREFIX + "request"] = (
495+
a2a_request.model_dump(exclude_none=True, by_alias=True)
496+
)
497+
event.custom_metadata[A2A_METADATA_PREFIX + "response"] = (
498+
a2a_response.model_dump(exclude_none=True, by_alias=True)
499+
)
513500

514-
yield event
501+
yield event
515502

516503
except Exception as e:
517504
error_message = f"A2A request failed: {e}"

0 commit comments

Comments
 (0)