molecule-core/workspace-template/a2a_executor.py
Hongming Wang 55a2ee0153 fix: properly remove adapter subdirectories + move shared code to root
PR #471 removed Dockerfiles/requirements from adapters/ but left the
Python source files. This commit finishes the extraction:

1. Moved shared_runtime.py → workspace-template/shared_runtime.py
   (used by prompt.py, a2a_executor.py, coordinator.py — not adapter-specific)
2. Moved base.py → workspace-template/adapter_base.py
   (BaseAdapter + AdapterConfig — the interface adapters implement)
3. Updated imports in prompt.py, a2a_executor.py, coordinator.py
4. Rewrote adapters/__init__.py as a thin shim (sketch below) that:
   - Reads ADAPTER_MODULE env var (production: standalone repos set this)
   - Re-exports BaseAdapter/AdapterConfig for backward compat
5. adapters/base.py + adapters/shared_runtime.py remain as re-export shims
6. Deleted all 8 adapter subdirectories (autogen, claude_code, crewai,
   deepagents, gemini_cli, hermes, langgraph, openclaw)
7. Removed 11 test files that imported adapter-specific code

Tests: 955 passed, 0 failed (down from 1216 — the difference is
adapter-specific tests that moved to standalone repos).
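
Rough sketch of the adapters/__init__.py shim from item 4 (illustrative only;
everything beyond ADAPTER_MODULE, BaseAdapter and AdapterConfig is assumed):

    import importlib
    import os

    # Backward-compat re-exports; canonical definitions now live at the
    # workspace root in adapter_base.py.
    from adapter_base import AdapterConfig, BaseAdapter  # noqa: F401

    # Production: standalone adapter repos set ADAPTER_MODULE to their own
    # package, which is imported here so the adapter is loaded at startup.
    _module = os.environ.get("ADAPTER_MODULE")
    if _module:
        importlib.import_module(_module)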
2026-04-16 04:59:13 -07:00

"""Bridge between LangGraph agent and A2A protocol, with SSE streaming support.
SSE streaming architecture
--------------------------
The A2A SDK (``DefaultRequestHandler`` + ``EventQueue``) owns the SSE transport
layer. This executor's job is to push the right event types into the queue as
work progresses:
1. ``TaskStatusUpdateEvent(state=working)`` — immediately signals start
2. ``TaskArtifactUpdateEvent(chunk, append=…)`` — one per LLM text token
3. ``Message(final_text)`` — terminal event
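
In code terms (mirroring ``_core_execute`` below):

    await updater.start_work()                        # 1. working
    await updater.add_artifact(                       # 2. one per text chunk
        parts=[Part(root=TextPart(text=text))],
        artifact_id=artifact_id,
        append=has_streamed,
        last_chunk=False,
    )
    await event_queue.enqueue_event(                  # 3. terminal Message
        new_agent_text_message(final_text, task_id=task_id, context_id=context_id)
    )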

Client compatibility
--------------------
*Non-streaming* (``message/send``):
    ``ResultAggregator.consume_all()`` processes status/artifact events
    (updating the task in the store) and returns the final ``Message``
    immediately — backward-compatible with ``a2a_client.py`` which reads
    ``data["result"]["parts"][0]["text"]``.

*Streaming* (``message/stream``):
    ``consume_and_emit()`` yields every event above as SSE, letting the client
    render tokens in real time.
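
A non-streaming caller can therefore read the reply roughly like this (sketch;
assumes an httpx-style JSON-RPC POST and elides the request params):

    resp = httpx.post(agent_url, json={
        "jsonrpc": "2.0", "id": 1,
        "method": "message/send",
        "params": {...},
    })
    text = resp.json()["result"]["parts"][0]["text"]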

LangGraph integration
---------------------
Uses ``agent.astream_events(version="v2")`` to receive ``on_chat_model_stream``
events with ``AIMessageChunk`` payloads. Text is extracted from both plain
strings (OpenAI / Groq) and Anthropic-style content-block lists. Non-text
content (tool_use, etc.) is silently skipped. A fresh ``artifact_id`` is
generated for each new LLM ``run_id`` so tool-call cycles are grouped cleanly.
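
A typical streamed event, as consumed in ``_core_execute`` (shape illustrative):

    {
        "event": "on_chat_model_stream",
        "run_id": "2f1c…",
        "data": {"chunk": AIMessageChunk(content=[{"type": "text", "text": "Hel"}])},
    }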
"""
import functools
import logging
import os
import uuid
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.tasks import TaskUpdater
from a2a.types import Part, TextPart
from a2a.utils import new_agent_text_message
from shared_runtime import (
extract_history as _extract_history,
extract_message_text,
brief_task,
set_current_task,
)
from builtin_tools.telemetry import (
A2A_TASK_ID,
GEN_AI_OPERATION_NAME,
GEN_AI_REQUEST_MODEL,
GEN_AI_SYSTEM,
WORKSPACE_ID_ATTR,
_incoming_trace_context,
gen_ai_system_from_model,
get_tracer,
record_llm_token_usage,
)
logger = logging.getLogger(__name__)
_WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "unknown")
# LangGraph ReAct cycle budget per turn. Library default is 25; 500 covers
# PM fan-outs (plan → 6 delegations → 6 awaits → 6 results → synthesize ≈
# 30+ steps even before retries). Overridable via LANGGRAPH_RECURSION_LIMIT.
DEFAULT_RECURSION_LIMIT = 500
def _parse_recursion_limit() -> int:
"""Read LANGGRAPH_RECURSION_LIMIT; fall back to DEFAULT_RECURSION_LIMIT
with a WARNING log on any unparseable or non-positive value."""
raw = os.environ.get("LANGGRAPH_RECURSION_LIMIT", "")
if not raw:
return DEFAULT_RECURSION_LIMIT
try:
n = int(raw)
except ValueError:
logger.warning(
"LANGGRAPH_RECURSION_LIMIT=%r is not an integer; using default %d",
raw, DEFAULT_RECURSION_LIMIT,
)
return DEFAULT_RECURSION_LIMIT
if n <= 0:
logger.warning(
"LANGGRAPH_RECURSION_LIMIT=%d is not positive; using default %d",
n, DEFAULT_RECURSION_LIMIT,
)
return DEFAULT_RECURSION_LIMIT
return n
# ---------------------------------------------------------------------------
# Compliance (OWASP Top 10 for Agentic Apps) — optional, lazy-loaded
# ---------------------------------------------------------------------------
try:
from builtin_tools.compliance import (
AgencyTracker,
ExcessiveAgencyError,
PromptInjectionError,
redact_pii as _redact_pii,
sanitize_input as _sanitize_input,
)
_COMPLIANCE_AVAILABLE = True
except ImportError: # pragma: no cover
_COMPLIANCE_AVAILABLE = False
@functools.lru_cache(maxsize=1)
def _get_compliance_cfg():
"""Return ComplianceConfig or None (cached for process lifetime)."""
try:
from config import load_config
return load_config().compliance
except Exception:
return None
def _extract_chunk_text(content) -> list[str]:
"""Extract text strings from an LLM streaming chunk's content field.
Handles both provider content styles:
- OpenAI / Groq: ``content`` is a plain ``str`` (empty for tool-call chunks).
- Anthropic: ``content`` is a list of typed blocks, e.g.
``[{"type": "text", "text": "Hello"}, {"type": "tool_use", ...}]``
Only ``"text"`` blocks are returned; ``tool_use``, ``tool_result``, and
other non-text blocks are filtered out so raw tool JSON never appears in
the SSE stream.
Args:
content: ``chunk.content`` value from an ``on_chat_model_stream`` event.
Returns:
List of non-empty text strings.
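
    Examples (illustrative):
        _extract_chunk_text("Hello")  -> ["Hello"]
        _extract_chunk_text("")       -> []
        _extract_chunk_text([{"type": "text", "text": "Hi"},
                             {"type": "tool_use", "name": "search"}])  -> ["Hi"]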
"""
if isinstance(content, str):
return [content] if content else []
if isinstance(content, list):
texts: list[str] = []
for block in content:
if isinstance(block, dict) and block.get("type") == "text":
text = block.get("text", "")
if text:
texts.append(text)
elif isinstance(block, str) and block:
texts.append(block)
return texts
return []
class LangGraphA2AExecutor(AgentExecutor):
"""Bridges LangGraph agent to A2A event model with SSE streaming support.
Always uses ``agent.astream_events()`` so that:
- Streaming clients (``message/stream``) receive token-level SSE events.
- Non-streaming clients (``message/send``) receive the final ``Message``
collected from the same stream — no duplicate LLM call, full compat.
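
    Typical server wiring (sketch; the handler, store, and app names are
    assumptions about the installed A2A SDK, and ``graph`` / ``card`` are
    hypothetical objects built elsewhere):

        from a2a.server.apps import A2AStarletteApplication
        from a2a.server.request_handlers import DefaultRequestHandler
        from a2a.server.tasks import InMemoryTaskStore

        executor = LangGraphA2AExecutor(agent=graph, model="anthropic:claude-sonnet-4-6")
        handler = DefaultRequestHandler(
            agent_executor=executor, task_store=InMemoryTaskStore()
        )
        app = A2AStarletteApplication(agent_card=card, http_handler=handler).build()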
"""
def __init__(self, agent, heartbeat=None, model: str = "unknown"):
self.agent = agent # Compiled LangGraph graph (create_react_agent output)
self._heartbeat = heartbeat
self._model = model # e.g. "anthropic:claude-sonnet-4-6"
async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:
"""Execute a task from an A2A request with SSE streaming.
Routes through the Temporal durable workflow when a global
``TemporalWorkflowWrapper`` is initialised and connected to Temporal;
otherwise falls back to ``_core_execute()`` (direct path).
Event emission sequence:
1. TaskStatusUpdateEvent(working) — immediate start signal
2. TaskArtifactUpdateEvent chunks — token-by-token via astream_events
3. Message(final_text) — terminal; non-streaming clients
return on this; streaming clients
also receive it as the last SSE event.
"""
# ── Optional Temporal durable execution wrapper ──────────────────────
# When a TemporalWorkflowWrapper is active this routes execution through
# a MoleculeAIAgentWorkflow (task_receive → llm_call → task_complete).
# Falls back silently to _core_execute() on any error or if Temporal
# is unavailable, so the client always receives a response.
try:
from builtin_tools.temporal_workflow import get_wrapper as _get_temporal_wrapper
_tw = _get_temporal_wrapper()
if _tw is not None and _tw.is_available():
return await _tw.run(self, context, event_queue)
except Exception:
pass # Never let the wrapper path crash the executor
await self._core_execute(context, event_queue)
async def _core_execute(self, context: RequestContext, event_queue: EventQueue) -> str:
"""Core execution pipeline — called directly or from a Temporal activity.
This is the original ``execute()`` body, extracted so that the Temporal
``llm_call`` activity can invoke it without re-entering the wrapper
check and causing infinite recursion.
Returns the final response text (empty string on empty input or error).
Event emission sequence:
1. TaskStatusUpdateEvent(working) — immediate start signal
2. TaskArtifactUpdateEvent chunks — token-by-token via astream_events
3. Message(final_text) — terminal event
"""
user_input = extract_message_text(context)
if not user_input:
parts = getattr(getattr(context, "message", None), "parts", None)
logger.warning("A2A execute: no text content in message parts: %s", parts)
await event_queue.enqueue_event(
new_agent_text_message("Error: message contained no text content.")
)
return ""
# ── OA-01: Prompt injection check (OWASP Agentic Top 10) ────────────
_compliance_cfg = _get_compliance_cfg() if _COMPLIANCE_AVAILABLE else None
if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic":
try:
user_input = _sanitize_input(
user_input,
prompt_injection_mode=_compliance_cfg.prompt_injection,
context_id=context.context_id or "",
)
except PromptInjectionError as exc:
await event_queue.enqueue_event(
new_agent_text_message(f"Request blocked: {exc}")
)
return ""
logger.info("A2A execute: user_input=%s", user_input[:200])
# ── OTEL: task_receive span ──────────────────────────────────────────
parent_ctx = _incoming_trace_context.get()
tracer = get_tracer()
_result: str = "" # captured inside the span for return after it closes
with tracer.start_as_current_span("task_receive", context=parent_ctx) as task_span:
task_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID)
task_span.set_attribute(A2A_TASK_ID, context.context_id or "")
task_span.set_attribute("a2a.input_preview", user_input[:256])
await set_current_task(self._heartbeat, brief_task(user_input))
# Resolve IDs — the RequestContextBuilder always sets them, but
# we generate fallbacks for safety (e.g. in unit tests).
task_id = context.task_id or str(uuid.uuid4())
context_id = context.context_id or str(uuid.uuid4())
updater = TaskUpdater(event_queue, task_id, context_id)
try:
messages = _extract_history(context)
if messages:
logger.info("A2A execute: injecting %d history messages", len(messages))
messages.append(("human", user_input))
# Recursion limit: see DEFAULT_RECURSION_LIMIT and
# _parse_recursion_limit() at module top. Re-read on every
# call so the env var can be hot-changed between requests.
recursion_limit = _parse_recursion_limit()
run_config = {
"configurable": {"thread_id": context_id},
"run_name": f"a2a-{context_id[:8]}",
"recursion_limit": recursion_limit,
}
# ── OTEL: llm_call span ──────────────────────────────────────
with tracer.start_as_current_span("llm_call") as llm_span:
llm_span.set_attribute(GEN_AI_OPERATION_NAME, "chat")
llm_span.set_attribute(GEN_AI_SYSTEM, gen_ai_system_from_model(self._model))
llm_span.set_attribute(GEN_AI_REQUEST_MODEL, self._model)
llm_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID)
# ── Step 1: signal "working" to streaming clients ─────────
await updater.start_work()
# ── Step 2: stream tokens via LangGraph astream_events ────
# Each "on_chat_model_stream" event carries an AIMessageChunk.
# We emit one TaskArtifactUpdateEvent per text chunk so SSE
# clients can render tokens in real time.
# artifact_id resets on each new LLM run_id so agent→tool→agent
# cycles each get their own artifact slot.
artifact_id = str(uuid.uuid4())
has_streamed = False # True after first chunk for current artifact
current_run_id = None # Detects new LLM call in a ReAct cycle
accumulated: list[str] = [] # All text for the final Message
last_ai_message = None # Saved for token-usage telemetry
# ── OA-03: Excessive agency tracker ──────────────────────
_agency = (
AgencyTracker(
max_tool_calls=_compliance_cfg.max_tool_calls_per_task,
max_duration_seconds=float(_compliance_cfg.max_task_duration_seconds),
)
if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic"
else None
)
async for event in self.agent.astream_events(
{"messages": messages},
config=run_config,
version="v2",
):
kind = event.get("event", "")
if kind == "on_chat_model_stream":
run_id = event.get("run_id", "")
if run_id and run_id != current_run_id:
# New LLM run started — fresh artifact slot
current_run_id = run_id
artifact_id = str(uuid.uuid4())
has_streamed = False
chunk = event.get("data", {}).get("chunk")
if chunk is not None:
texts = _extract_chunk_text(chunk.content)
for text in texts:
await updater.add_artifact(
parts=[Part(root=TextPart(text=text))],
artifact_id=artifact_id,
append=has_streamed, # False=first, True=append
last_chunk=False,
)
has_streamed = True
accumulated.append(text)
elif kind == "on_tool_start":
tool_name = event.get("name", "?")
logger.debug("SSE: tool start — %s", tool_name)
if _agency is not None:
_agency.on_tool_call(
tool_name=tool_name,
context_id=context_id,
)
elif kind == "on_tool_end":
logger.debug("SSE: tool end — %s", event.get("name", "?"))
elif kind == "on_chat_model_end":
# Capture the last completed AIMessage for token telemetry
output = event.get("data", {}).get("output")
if output is not None:
last_ai_message = output
# Record token usage from the last completed LLM call
if last_ai_message is not None:
record_llm_token_usage(llm_span, {"messages": [last_ai_message]})
# Build final text from all accumulated streaming tokens
final_text = "".join(accumulated).strip() or "(no response generated)"
logger.info("A2A execute: response length=%d chars", len(final_text))
# ── OA-02 / OA-06: Output PII redaction ──────────────────────
if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic":
final_text, _pii_types = _redact_pii(final_text)
if _pii_types:
from builtin_tools.audit import log_event as _audit_log
_audit_log(
event_type="compliance",
action="pii.redact",
resource="task_output",
outcome="redacted",
pii_types=_pii_types,
context_id=context_id,
)
# ── OTEL: task_complete span ─────────────────────────────────
with tracer.start_as_current_span("task_complete") as done_span:
done_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID)
done_span.set_attribute(A2A_TASK_ID, context_id)
done_span.set_attribute("task.has_response", bool(accumulated))
done_span.set_attribute("task.response_length", len(final_text))
# ── Step 3: emit final Message ────────────────────────────────
# Non-streaming: ResultAggregator.consume_all() returns this
# immediately as the response (a2a_client.py reads .parts[0].text).
# Streaming: yielded as the last SSE event in the stream.
await event_queue.enqueue_event(
new_agent_text_message(final_text, task_id=task_id, context_id=context_id)
)
_result = final_text
except Exception as e:
logger.error("A2A execute error: %s", e, exc_info=True)
try:
task_span.record_exception(e)
from opentelemetry.trace import StatusCode
task_span.set_status(StatusCode.ERROR, str(e))
except Exception:
pass
# Emit a Message so both streaming and non-streaming clients
# receive an error response rather than hanging.
await event_queue.enqueue_event(
new_agent_text_message(
f"Agent error: {e}", task_id=task_id, context_id=context_id
)
)
finally:
await set_current_task(self._heartbeat, "")
return _result
async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
"""Cancel a running task — emits canceled state to comply with A2A protocol."""
from a2a.types import TaskStatus, TaskState, TaskStatusUpdateEvent
        # TaskStatusUpdateEvent carries task/context ids in the A2A protocol;
        # fall back to empty strings for safety (e.g. in unit tests).
        await event_queue.enqueue_event(
            TaskStatusUpdateEvent(
                task_id=context.task_id or "",
                context_id=context.context_id or "",
                status=TaskStatus(state=TaskState.canceled),
                final=True,
            )
        )