forked from molecule-ai/molecule-core
Every A2A response now includes a tool_trace — the list of tools/commands the agent actually invoked during execution. This enables verifying agent claims against what they actually did, catches hallucinated "I checked X" responses, and provides an audit trail for the CEO to control hundreds of agents by checking the top-level PM's trace. Changes: - Python runtime: collect tool name/input/output_preview on every on_tool_start/on_tool_end event, embed in Message.metadata.tool_trace - Go platform: extract tool_trace from A2A response metadata, store in new activity_logs.tool_trace JSONB column with GIN index - Activity API: expose tool_trace in List and broadcast endpoints - Migration 039: adds tool_trace column + GIN index Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
434 lines
20 KiB
Python
434 lines
20 KiB
Python
"""Bridge between LangGraph agent and A2A protocol, with SSE streaming support.
|
|
|
|
SSE streaming architecture
|
|
--------------------------
|
|
The A2A SDK (``DefaultRequestHandler`` + ``EventQueue``) owns the SSE transport
|
|
layer. This executor's job is to push the right event types into the queue as
|
|
work progresses:
|
|
|
|
1. ``TaskStatusUpdateEvent(state=working)`` — immediately signals start
|
|
2. ``TaskArtifactUpdateEvent(chunk, append=…)`` — one per LLM text token
|
|
3. ``Message(final_text)`` — terminal event
|
|
|
|
Client compatibility
|
|
--------------------
|
|
*Non-streaming* (``message/send``):
|
|
``ResultAggregator.consume_all()`` processes status/artifact events
|
|
(updating the task in the store) and returns the final ``Message``
|
|
immediately — backward-compatible with ``a2a_client.py`` which reads
|
|
``data["result"]["parts"][0]["text"]``.
|
|
|
|
*Streaming* (``message/stream``):
|
|
``consume_and_emit()`` yields every event above as SSE, letting the client
|
|
render tokens in real time.
|
|
|
|
LangGraph integration
|
|
---------------------
|
|
Uses ``agent.astream_events(version="v2")`` to receive ``on_chat_model_stream``
|
|
events with ``AIMessageChunk`` payloads. Text is extracted from both plain
|
|
strings (OpenAI / Groq) and Anthropic-style content-block lists. Non-text
|
|
content (tool_use, etc.) is silently skipped. A fresh ``artifact_id`` is
|
|
generated for each new LLM ``run_id`` so tool-call cycles are grouped cleanly.
|
|
"""
|
|
|
|
import functools
|
|
import logging
|
|
import os
|
|
import uuid
|
|
|
|
from a2a.server.agent_execution import AgentExecutor, RequestContext
|
|
from a2a.server.events import EventQueue
|
|
from a2a.server.tasks import TaskUpdater
|
|
from a2a.types import Part, TextPart
|
|
from a2a.utils import new_agent_text_message
|
|
from shared_runtime import (
|
|
extract_history as _extract_history,
|
|
extract_message_text,
|
|
brief_task,
|
|
set_current_task,
|
|
)
|
|
from builtin_tools.telemetry import (
|
|
A2A_TASK_ID,
|
|
GEN_AI_OPERATION_NAME,
|
|
GEN_AI_REQUEST_MODEL,
|
|
GEN_AI_SYSTEM,
|
|
WORKSPACE_ID_ATTR,
|
|
_incoming_trace_context,
|
|
gen_ai_system_from_model,
|
|
get_tracer,
|
|
record_llm_token_usage,
|
|
)
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
_WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "unknown")
|
|
|
|
# LangGraph ReAct cycle budget per turn. Library default is 25; 500 covers
|
|
# PM fan-outs (plan → 6 delegations → 6 awaits → 6 results → synthesize ≈
|
|
# 30+ steps even before retries). Overridable via LANGGRAPH_RECURSION_LIMIT.
|
|
DEFAULT_RECURSION_LIMIT = 500
|
|
|
|
|
|
def _parse_recursion_limit() -> int:
|
|
"""Read LANGGRAPH_RECURSION_LIMIT; fall back to DEFAULT_RECURSION_LIMIT
|
|
with a WARNING log on any unparseable or non-positive value."""
|
|
raw = os.environ.get("LANGGRAPH_RECURSION_LIMIT", "")
|
|
if not raw:
|
|
return DEFAULT_RECURSION_LIMIT
|
|
try:
|
|
n = int(raw)
|
|
except ValueError:
|
|
logger.warning(
|
|
"LANGGRAPH_RECURSION_LIMIT=%r is not an integer; using default %d",
|
|
raw, DEFAULT_RECURSION_LIMIT,
|
|
)
|
|
return DEFAULT_RECURSION_LIMIT
|
|
if n <= 0:
|
|
logger.warning(
|
|
"LANGGRAPH_RECURSION_LIMIT=%d is not positive; using default %d",
|
|
n, DEFAULT_RECURSION_LIMIT,
|
|
)
|
|
return DEFAULT_RECURSION_LIMIT
|
|
return n
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Compliance (OWASP Top 10 for Agentic Apps) — optional, lazy-loaded
|
|
# ---------------------------------------------------------------------------
|
|
|
|
try:
|
|
from builtin_tools.compliance import (
|
|
AgencyTracker,
|
|
ExcessiveAgencyError,
|
|
PromptInjectionError,
|
|
redact_pii as _redact_pii,
|
|
sanitize_input as _sanitize_input,
|
|
)
|
|
_COMPLIANCE_AVAILABLE = True
|
|
except ImportError: # pragma: no cover
|
|
_COMPLIANCE_AVAILABLE = False
|
|
|
|
|
|
@functools.lru_cache(maxsize=1)
|
|
def _get_compliance_cfg():
|
|
"""Return ComplianceConfig or None (cached for process lifetime)."""
|
|
try:
|
|
from config import load_config
|
|
return load_config().compliance
|
|
except Exception:
|
|
return None
|
|
|
|
|
|
def _extract_chunk_text(content) -> list[str]:
|
|
"""Extract text strings from an LLM streaming chunk's content field.
|
|
|
|
Handles both provider content styles:
|
|
- OpenAI / Groq: ``content`` is a plain ``str`` (empty for tool-call chunks).
|
|
- Anthropic: ``content`` is a list of typed blocks, e.g.
|
|
``[{"type": "text", "text": "Hello"}, {"type": "tool_use", ...}]``
|
|
|
|
Only ``"text"`` blocks are returned; ``tool_use``, ``tool_result``, and
|
|
other non-text blocks are filtered out so raw tool JSON never appears in
|
|
the SSE stream.
|
|
|
|
Args:
|
|
content: ``chunk.content`` value from an ``on_chat_model_stream`` event.
|
|
|
|
Returns:
|
|
List of non-empty text strings.
|
|
"""
|
|
if isinstance(content, str):
|
|
return [content] if content else []
|
|
if isinstance(content, list):
|
|
texts: list[str] = []
|
|
for block in content:
|
|
if isinstance(block, dict) and block.get("type") == "text":
|
|
text = block.get("text", "")
|
|
if text:
|
|
texts.append(text)
|
|
elif isinstance(block, str) and block:
|
|
texts.append(block)
|
|
return texts
|
|
return []
|
|
|
|
|
|
class LangGraphA2AExecutor(AgentExecutor):
    """Bridges LangGraph agent to A2A event model with SSE streaming support.

    Always uses ``agent.astream_events()`` so that:

    - Streaming clients (``message/stream``) receive token-level SSE events.
    - Non-streaming clients (``message/send``) receive the final ``Message``
      collected from the same stream — no duplicate LLM call, full compat.
    """

    def __init__(self, agent, heartbeat=None, model: str = "unknown"):
        """Store collaborators; performs no I/O.

        Args:
            agent: Compiled LangGraph graph (``create_react_agent`` output);
                must expose ``astream_events()``.
            heartbeat: Optional handle forwarded to ``set_current_task`` for
                liveness reporting; may be ``None``.
            model: Model identifier used only for OTEL gen_ai attributes.
        """
        self.agent = agent  # Compiled LangGraph graph (create_react_agent output)
        self._heartbeat = heartbeat
        self._model = model  # e.g. "anthropic:claude-sonnet-4-6"

    async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:
        """Execute a task from an A2A request with SSE streaming.

        Routes through the Temporal durable workflow when a global
        ``TemporalWorkflowWrapper`` is initialised and connected to Temporal;
        otherwise falls back to ``_core_execute()`` (direct path).

        Event emission sequence:
          1. TaskStatusUpdateEvent(working)  — immediate start signal
          2. TaskArtifactUpdateEvent chunks  — token-by-token via astream_events
          3. Message(final_text)             — terminal; non-streaming clients
                                               return on this; streaming clients
                                               also receive it as the last SSE event.
        """
        # ── Optional Temporal durable execution wrapper ──────────────────────
        # When a TemporalWorkflowWrapper is active this routes execution through
        # a MoleculeAIAgentWorkflow (task_receive → llm_call → task_complete).
        # Falls back silently to _core_execute() on any error or if Temporal
        # is unavailable, so the client always receives a response.
        try:
            from builtin_tools.temporal_workflow import get_wrapper as _get_temporal_wrapper

            _tw = _get_temporal_wrapper()
            if _tw is not None and _tw.is_available():
                # The wrapper is expected to call back into _core_execute()
                # from its llm_call activity (see _core_execute docstring).
                return await _tw.run(self, context, event_queue)
        except Exception:
            pass  # Never let the wrapper path crash the executor

        await self._core_execute(context, event_queue)

    async def _core_execute(self, context: RequestContext, event_queue: EventQueue) -> str:
        """Core execution pipeline — called directly or from a Temporal activity.

        This is the original ``execute()`` body, extracted so that the Temporal
        ``llm_call`` activity can invoke it without re-entering the wrapper
        check and causing infinite recursion.

        Returns the final response text (empty string on empty input or error).

        Event emission sequence:
          1. TaskStatusUpdateEvent(working)  — immediate start signal
          2. TaskArtifactUpdateEvent chunks  — token-by-token via astream_events
          3. Message(final_text)             — terminal event
        """
        user_input = extract_message_text(context)
        if not user_input:
            # No text part at all — reply with an explicit error Message so
            # neither streaming nor non-streaming clients hang waiting.
            parts = getattr(getattr(context, "message", None), "parts", None)
            logger.warning("A2A execute: no text content in message parts: %s", parts)
            await event_queue.enqueue_event(
                new_agent_text_message("Error: message contained no text content.")
            )
            return ""

        # ── OA-01: Prompt injection check (OWASP Agentic Top 10) ────────────
        _compliance_cfg = _get_compliance_cfg() if _COMPLIANCE_AVAILABLE else None
        if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic":
            try:
                # _sanitize_input may rewrite the text (e.g. strip suspicious
                # spans) or raise PromptInjectionError to block the request.
                user_input = _sanitize_input(
                    user_input,
                    prompt_injection_mode=_compliance_cfg.prompt_injection,
                    context_id=context.context_id or "",
                )
            except PromptInjectionError as exc:
                await event_queue.enqueue_event(
                    new_agent_text_message(f"Request blocked: {exc}")
                )
                return ""

        logger.info("A2A execute: user_input=%s", user_input[:200])

        # ── OTEL: task_receive span ──────────────────────────────────────────
        parent_ctx = _incoming_trace_context.get()
        tracer = get_tracer()

        _result: str = ""  # captured inside the span for return after it closes

        with tracer.start_as_current_span("task_receive", context=parent_ctx) as task_span:
            task_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID)
            task_span.set_attribute(A2A_TASK_ID, context.context_id or "")
            task_span.set_attribute("a2a.input_preview", user_input[:256])

            await set_current_task(self._heartbeat, brief_task(user_input))

            # Resolve IDs — the RequestContextBuilder always sets them, but
            # we generate fallbacks for safety (e.g. in unit tests).
            task_id = context.task_id or str(uuid.uuid4())
            context_id = context.context_id or str(uuid.uuid4())

            updater = TaskUpdater(event_queue, task_id, context_id)

            try:
                messages = _extract_history(context)
                if messages:
                    logger.info("A2A execute: injecting %d history messages", len(messages))
                # Appended unconditionally — current turn always goes last.
                messages.append(("human", user_input))

                # Recursion limit: see DEFAULT_RECURSION_LIMIT and
                # _parse_recursion_limit() at module top. Re-read on every
                # call so the env var can be hot-changed between requests.
                recursion_limit = _parse_recursion_limit()
                run_config = {
                    "configurable": {"thread_id": context_id},
                    "run_name": f"a2a-{context_id[:8]}",
                    "recursion_limit": recursion_limit,
                }

                # ── OTEL: llm_call span ──────────────────────────────────────
                with tracer.start_as_current_span("llm_call") as llm_span:
                    llm_span.set_attribute(GEN_AI_OPERATION_NAME, "chat")
                    llm_span.set_attribute(GEN_AI_SYSTEM, gen_ai_system_from_model(self._model))
                    llm_span.set_attribute(GEN_AI_REQUEST_MODEL, self._model)
                    llm_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID)

                    # ── Step 1: signal "working" to streaming clients ─────────
                    await updater.start_work()

                    # ── Step 2: stream tokens via LangGraph astream_events ────
                    # Each "on_chat_model_stream" event carries an AIMessageChunk.
                    # We emit one TaskArtifactUpdateEvent per text chunk so SSE
                    # clients can render tokens in real time.
                    # artifact_id resets on each new LLM run_id so agent→tool→agent
                    # cycles each get their own artifact slot.

                    artifact_id = str(uuid.uuid4())
                    has_streamed = False  # True after first chunk for current artifact
                    current_run_id = None  # Detects new LLM call in a ReAct cycle
                    accumulated: list[str] = []  # All text for the final Message
                    last_ai_message = None  # Saved for token-usage telemetry

                    # ── OA-03: Excessive agency tracker ──────────────────────
                    # Budget guard: raises ExcessiveAgencyError via on_tool_call
                    # when limits are exceeded (caught by the outer handler).
                    _agency = (
                        AgencyTracker(
                            max_tool_calls=_compliance_cfg.max_tool_calls_per_task,
                            max_duration_seconds=float(_compliance_cfg.max_task_duration_seconds),
                        )
                        if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic"
                        else None
                    )

                    # ── Tool trace: collect every tool invocation for
                    # platform-level observability ────────────────────
                    tool_trace: list[dict] = []

                    async for event in self.agent.astream_events(
                        {"messages": messages},
                        config=run_config,
                        version="v2",
                    ):
                        kind = event.get("event", "")

                        if kind == "on_chat_model_stream":
                            run_id = event.get("run_id", "")
                            if run_id and run_id != current_run_id:
                                # New LLM run started — fresh artifact slot
                                current_run_id = run_id
                                artifact_id = str(uuid.uuid4())
                                has_streamed = False

                            chunk = event.get("data", {}).get("chunk")
                            if chunk is not None:
                                texts = _extract_chunk_text(chunk.content)
                                for text in texts:
                                    await updater.add_artifact(
                                        parts=[Part(root=TextPart(text=text))],
                                        artifact_id=artifact_id,
                                        append=has_streamed,  # False=first, True=append
                                        last_chunk=False,
                                    )
                                    has_streamed = True
                                    accumulated.append(text)

                        elif kind == "on_tool_start":
                            tool_name = event.get("name", "?")
                            tool_input = event.get("data", {}).get("input", "")
                            logger.debug("SSE: tool start — %s", tool_name)
                            # Input preview capped at 500 chars to bound
                            # Message.metadata size.
                            tool_trace.append({
                                "tool": tool_name,
                                "input": str(tool_input)[:500] if tool_input else "",
                            })
                            if _agency is not None:
                                _agency.on_tool_call(
                                    tool_name=tool_name,
                                    context_id=context_id,
                                )

                        elif kind == "on_tool_end":
                            tool_end_name = event.get("name", "?")
                            tool_output = event.get("data", {}).get("output", "")
                            logger.debug("SSE: tool end — %s", tool_end_name)
                            # NOTE(review): matches the end event against the
                            # *last* trace entry only — assumes tool runs don't
                            # interleave; a nested/parallel tool call could
                            # leave output_preview unset. TODO confirm whether
                            # LangGraph can emit interleaved tool events here.
                            if tool_trace and tool_trace[-1]["tool"] == tool_end_name:
                                tool_trace[-1]["output_preview"] = str(tool_output)[:300] if tool_output else ""

                        elif kind == "on_chat_model_end":
                            # Capture the last completed AIMessage for token telemetry
                            output = event.get("data", {}).get("output")
                            if output is not None:
                                last_ai_message = output

                    # Record token usage from the last completed LLM call
                    if last_ai_message is not None:
                        record_llm_token_usage(llm_span, {"messages": [last_ai_message]})

                # Build final text from all accumulated streaming tokens
                final_text = "".join(accumulated).strip() or "(no response generated)"
                logger.info("A2A execute: response length=%d chars", len(final_text))

                # ── OA-02 / OA-06: Output PII redaction ──────────────────────
                if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic":
                    final_text, _pii_types = _redact_pii(final_text)
                    if _pii_types:
                        # Audit only when something was actually redacted.
                        from builtin_tools.audit import log_event as _audit_log
                        _audit_log(
                            event_type="compliance",
                            action="pii.redact",
                            resource="task_output",
                            outcome="redacted",
                            pii_types=_pii_types,
                            context_id=context_id,
                        )

                # ── OTEL: task_complete span ─────────────────────────────────
                with tracer.start_as_current_span("task_complete") as done_span:
                    done_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID)
                    done_span.set_attribute(A2A_TASK_ID, context_id)
                    done_span.set_attribute("task.has_response", bool(accumulated))
                    done_span.set_attribute("task.response_length", len(final_text))

                # ── Step 3: emit final Message ────────────────────────────────
                # Non-streaming: ResultAggregator.consume_all() returns this
                # immediately as the response (a2a_client.py reads .parts[0].text).
                # Streaming: yielded as the last SSE event in the stream.
                msg = new_agent_text_message(final_text, task_id=task_id, context_id=context_id)
                if tool_trace:
                    # tool_trace rides in Message.metadata so the platform can
                    # verify agent claims against actual tool invocations.
                    msg.metadata = {"tool_trace": tool_trace}
                await event_queue.enqueue_event(msg)
                _result = final_text

            except Exception as e:
                logger.error("A2A execute error: %s", e, exc_info=True)
                try:
                    task_span.record_exception(e)
                    from opentelemetry.trace import StatusCode
                    task_span.set_status(StatusCode.ERROR, str(e))
                except Exception:
                    pass  # telemetry failures must not mask the real error
                # Emit a Message so both streaming and non-streaming clients
                # receive an error response rather than hanging.
                await event_queue.enqueue_event(
                    new_agent_text_message(
                        f"Agent error: {e}", task_id=task_id, context_id=context_id
                    )
                )
            finally:
                await set_current_task(self._heartbeat, "")

        return _result

    async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
        """Cancel a running task — emits canceled state to comply with A2A protocol."""
        # Local import keeps these A2A types off the module's hot import path.
        from a2a.types import TaskStatus, TaskState, TaskStatusUpdateEvent
        await event_queue.enqueue_event(
            TaskStatusUpdateEvent(
                status=TaskStatus(state=TaskState.canceled),
                final=True,
            )
        )
|