molecule-core/workspace/a2a_executor.py
rabbitblood 6c618c9c3f feat: add tool_trace to activity_logs for platform-level agent observability
Every A2A response now includes a tool_trace — the list of tools/commands
the agent actually invoked during execution. This enables verifying agent
claims against what they actually did, catches hallucinated "I checked X"
responses, and provides an audit trail for the CEO to control hundreds of
agents by checking the top-level PM's trace.

Changes:
- Python runtime: collect tool name/input/output_preview on every
  on_tool_start/on_tool_end event, embed in Message.metadata.tool_trace
- Go platform: extract tool_trace from A2A response metadata, store in
  new activity_logs.tool_trace JSONB column with GIN index
- Activity API: expose tool_trace in List and broadcast endpoints
- Migration 039: adds tool_trace column + GIN index

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-22 15:17:14 -07:00

434 lines
20 KiB
Python

"""Bridge between LangGraph agent and A2A protocol, with SSE streaming support.
SSE streaming architecture
--------------------------
The A2A SDK (``DefaultRequestHandler`` + ``EventQueue``) owns the SSE transport
layer. This executor's job is to push the right event types into the queue as
work progresses:
1. ``TaskStatusUpdateEvent(state=working)`` — immediately signals start
2. ``TaskArtifactUpdateEvent(chunk, append=…)`` — one per LLM text token
3. ``Message(final_text)`` — terminal event
Client compatibility
--------------------
*Non-streaming* (``message/send``):
``ResultAggregator.consume_all()`` processes status/artifact events
(updating the task in the store) and returns the final ``Message``
immediately — backward-compatible with ``a2a_client.py`` which reads
``data["result"]["parts"][0]["text"]``.
*Streaming* (``message/stream``):
``consume_and_emit()`` yields every event above as SSE, letting the client
render tokens in real time.
LangGraph integration
---------------------
Uses ``agent.astream_events(version="v2")`` to receive ``on_chat_model_stream``
events with ``AIMessageChunk`` payloads. Text is extracted from both plain
strings (OpenAI / Groq) and Anthropic-style content-block lists. Non-text
content (tool_use, etc.) is silently skipped. A fresh ``artifact_id`` is
generated for each new LLM ``run_id`` so tool-call cycles are grouped cleanly.
"""
import functools
import logging
import os
import uuid
from a2a.server.agent_execution import AgentExecutor, RequestContext
from a2a.server.events import EventQueue
from a2a.server.tasks import TaskUpdater
from a2a.types import Part, TextPart
from a2a.utils import new_agent_text_message
from shared_runtime import (
extract_history as _extract_history,
extract_message_text,
brief_task,
set_current_task,
)
from builtin_tools.telemetry import (
A2A_TASK_ID,
GEN_AI_OPERATION_NAME,
GEN_AI_REQUEST_MODEL,
GEN_AI_SYSTEM,
WORKSPACE_ID_ATTR,
_incoming_trace_context,
gen_ai_system_from_model,
get_tracer,
record_llm_token_usage,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Workspace identifier, read once at import time from the WORKSPACE_ID
# environment variable ("unknown" when unset). It is attached to the OTEL
# spans created by the executor.
_WORKSPACE_ID = os.environ.get("WORKSPACE_ID", "unknown")
# LangGraph ReAct cycle budget per turn. Library default is 25; 500 covers
# PM fan-outs (plan → 6 delegations → 6 awaits → 6 results → synthesize ≈
# 30+ steps even before retries). Overridable via LANGGRAPH_RECURSION_LIMIT.
DEFAULT_RECURSION_LIMIT = 500
def _parse_recursion_limit() -> int:
    """Resolve the LangGraph recursion limit for the current request.

    The value comes from the ``LANGGRAPH_RECURSION_LIMIT`` environment
    variable. An unset or empty variable silently yields
    ``DEFAULT_RECURSION_LIMIT``; a value that is not an integer, or an
    integer that is zero or negative, yields the default as well but also
    emits a WARNING log line.

    Returns:
        A strictly positive recursion limit.
    """
    env_value = os.environ.get("LANGGRAPH_RECURSION_LIMIT", "")
    if env_value == "":
        # Not configured — nothing to warn about.
        return DEFAULT_RECURSION_LIMIT

    try:
        limit = int(env_value)
    except ValueError:
        logger.warning(
            "LANGGRAPH_RECURSION_LIMIT=%r is not an integer; using default %d",
            env_value, DEFAULT_RECURSION_LIMIT,
        )
    else:
        if limit > 0:
            return limit
        logger.warning(
            "LANGGRAPH_RECURSION_LIMIT=%d is not positive; using default %d",
            limit, DEFAULT_RECURSION_LIMIT,
        )

    # Every rejected value ends up here.
    return DEFAULT_RECURSION_LIMIT
# ---------------------------------------------------------------------------
# Compliance helpers (OWASP Top 10 for Agentic Apps) — optional dependency.
# The import is attempted once when this module loads; _COMPLIANCE_AVAILABLE
# records whether it succeeded so call sites can skip the checks (and avoid
# touching the names below) when it did not.
# ---------------------------------------------------------------------------
try:
    from builtin_tools.compliance import (
        AgencyTracker,
        ExcessiveAgencyError,
        PromptInjectionError,
        redact_pii as _redact_pii,
        sanitize_input as _sanitize_input,
    )
except ImportError:  # pragma: no cover
    _COMPLIANCE_AVAILABLE = False
else:
    _COMPLIANCE_AVAILABLE = True
@functools.lru_cache(maxsize=1)
def _get_compliance_cfg():
    """Return the ``compliance`` section of the loaded config, or ``None``.

    The result — including a ``None`` produced by a failed load — is cached
    for the lifetime of the process, so the config is read at most once.

    Returns:
        ``load_config().compliance`` on success, ``None`` if importing or
        loading the config raised.
    """
    try:
        from config import load_config
        return load_config().compliance
    except Exception:
        # A failed load turns every compliance check off for the rest of the
        # process (the None is cached), so make the failure visible instead
        # of swallowing it. Logged once thanks to the cache.
        logger.warning(
            "Compliance config could not be loaded; compliance checks are disabled",
            exc_info=True,
        )
        return None
def _extract_chunk_text(content) -> list[str]:
    """Extract text strings from an LLM streaming chunk's content field.

    Handles both provider content styles:

    - OpenAI / Groq: ``content`` is a plain ``str`` (empty for tool-call chunks).
    - Anthropic: ``content`` is a list of typed blocks, e.g.
      ``[{"type": "text", "text": "Hello"}, {"type": "tool_use", ...}]``

    Only ``"text"`` blocks (and bare strings inside the list) are returned;
    ``tool_use``, ``tool_result``, and other non-text blocks are filtered out
    so raw tool JSON never appears in the SSE stream.

    Args:
        content: ``chunk.content`` value from an ``on_chat_model_stream`` event.

    Returns:
        List of non-empty text strings, in the order they appear. Any other
        ``content`` type (including ``None``) yields an empty list.
    """
    if isinstance(content, str):
        return [content] if content else []
    if not isinstance(content, list):
        return []

    texts: list[str] = []
    for block in content:
        if isinstance(block, str):
            candidate = block
        elif isinstance(block, dict) and block.get("type") == "text":
            candidate = block.get("text", "")
        else:
            continue
        # Only genuine, non-empty strings are kept: a malformed block whose
        # "text" is some other truthy value would otherwise break the
        # "".join() and TextPart construction done by the caller.
        if isinstance(candidate, str) and candidate:
            texts.append(candidate)
    return texts
class LangGraphA2AExecutor(AgentExecutor):
    """Bridges LangGraph agent to A2A event model with SSE streaming support.

    Always uses ``agent.astream_events()`` so that:

    - Streaming clients (``message/stream``) receive token-level SSE events.
    - Non-streaming clients (``message/send``) receive the final ``Message``
      collected from the same stream — no duplicate LLM call, full compat.
    """

    def __init__(self, agent, heartbeat=None, model: str = "unknown"):
        """Store the collaborators used while executing a task.

        Args:
            agent: Compiled LangGraph graph (create_react_agent output).
            heartbeat: Optional object handed to ``set_current_task`` to
                publish what the agent is currently working on.
            model: Model identifier, e.g. ``"anthropic:claude-sonnet-4-6"``;
                recorded on the ``llm_call`` span.
        """
        self.agent = agent
        self._heartbeat = heartbeat
        self._model = model

    async def execute(self, context: RequestContext, event_queue: EventQueue) -> None:
        """Execute a task from an A2A request with SSE streaming.

        Routes through the Temporal durable workflow when a global
        ``TemporalWorkflowWrapper`` is initialised and connected to Temporal;
        otherwise falls back to ``_core_execute()`` (direct path).

        Event emission sequence:

        1. TaskStatusUpdateEvent(working) — immediate start signal
        2. TaskArtifactUpdateEvent chunks — token-by-token via astream_events
        3. Message(final_text) — terminal; non-streaming clients return on
           this; streaming clients also receive it as the last SSE event.
        """
        # ── Optional Temporal durable execution wrapper ──────────────────────
        # When a TemporalWorkflowWrapper is active this routes execution through
        # a MoleculeAIAgentWorkflow (task_receive → llm_call → task_complete).
        # Any failure on this path falls back to _core_execute() so the client
        # always receives a response — but the failure is logged rather than
        # discarded, so a broken Temporal setup is diagnosable.
        # NOTE(review): if _tw.run() fails after it already emitted events, the
        # fallback below runs the task a second time — confirm that is intended.
        try:
            from builtin_tools.temporal_workflow import get_wrapper as _get_temporal_wrapper
            _tw = _get_temporal_wrapper()
            if _tw is not None and _tw.is_available():
                return await _tw.run(self, context, event_queue)
        except ImportError:
            # The Temporal integration is optional; its absence is not an error.
            logger.debug("Temporal wrapper unavailable; using direct execution path")
        except Exception:
            # Never let the wrapper path crash the executor.
            logger.warning(
                "Temporal wrapper path failed; falling back to direct execution",
                exc_info=True,
            )
        await self._core_execute(context, event_queue)

    async def _core_execute(self, context: RequestContext, event_queue: EventQueue) -> str:
        """Core execution pipeline — called directly or from a Temporal activity.

        This is the original ``execute()`` body, extracted so that the Temporal
        ``llm_call`` activity can invoke it without re-entering the wrapper
        check and causing infinite recursion.

        Event emission sequence:

        1. TaskStatusUpdateEvent(working) — immediate start signal
        2. TaskArtifactUpdateEvent chunks — token-by-token via astream_events
        3. Message(final_text) — terminal event; carries
           ``metadata["tool_trace"]`` when at least one tool was invoked,
           on both the success and the error path.

        Returns:
            The final response text (empty string on empty input, a blocked
            request, or an execution error).
        """
        user_input = extract_message_text(context)
        if not user_input:
            parts = getattr(getattr(context, "message", None), "parts", None)
            logger.warning("A2A execute: no text content in message parts: %s", parts)
            await event_queue.enqueue_event(
                new_agent_text_message("Error: message contained no text content.")
            )
            return ""

        # ── OA-01: Prompt injection check (OWASP Agentic Top 10) ────────────
        _compliance_cfg = _get_compliance_cfg() if _COMPLIANCE_AVAILABLE else None
        if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic":
            try:
                user_input = _sanitize_input(
                    user_input,
                    prompt_injection_mode=_compliance_cfg.prompt_injection,
                    context_id=context.context_id or "",
                )
            except PromptInjectionError as exc:
                await event_queue.enqueue_event(
                    new_agent_text_message(f"Request blocked: {exc}")
                )
                return ""

        logger.info("A2A execute: user_input=%s", user_input[:200])

        # ── OTEL: task_receive span ──────────────────────────────────────────
        parent_ctx = _incoming_trace_context.get()
        tracer = get_tracer()
        _result: str = ""  # captured inside the span for return after it closes
        with tracer.start_as_current_span("task_receive", context=parent_ctx) as task_span:
            task_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID)
            task_span.set_attribute(A2A_TASK_ID, context.context_id or "")
            task_span.set_attribute("a2a.input_preview", user_input[:256])
            await set_current_task(self._heartbeat, brief_task(user_input))

            # Resolve IDs — the RequestContextBuilder always sets them, but
            # we generate fallbacks for safety (e.g. in unit tests).
            task_id = context.task_id or str(uuid.uuid4())
            context_id = context.context_id or str(uuid.uuid4())
            updater = TaskUpdater(event_queue, task_id, context_id)

            # ── Tool trace: collect every tool invocation for platform-level
            # observability. Declared before the try block so the error
            # handler can still report what ran before the failure.
            # NOTE(review): previews are stored as captured; the PII redaction
            # further down only covers the final response text.
            tool_trace: list[dict] = []
            # Trace entries whose on_tool_end has not arrived yet, keyed by
            # the event's run_id, so overlapping tool calls are paired with
            # their own output rather than with whichever started last.
            pending_tools: dict[str, dict] = {}

            try:
                messages = _extract_history(context)
                if messages:
                    logger.info("A2A execute: injecting %d history messages", len(messages))
                messages.append(("human", user_input))

                # Recursion limit: see DEFAULT_RECURSION_LIMIT and
                # _parse_recursion_limit() at module top. Re-read on every
                # call so the env var can be hot-changed between requests.
                recursion_limit = _parse_recursion_limit()
                run_config = {
                    "configurable": {"thread_id": context_id},
                    "run_name": f"a2a-{context_id[:8]}",
                    "recursion_limit": recursion_limit,
                }

                # ── OTEL: llm_call span ──────────────────────────────────────
                with tracer.start_as_current_span("llm_call") as llm_span:
                    llm_span.set_attribute(GEN_AI_OPERATION_NAME, "chat")
                    llm_span.set_attribute(GEN_AI_SYSTEM, gen_ai_system_from_model(self._model))
                    llm_span.set_attribute(GEN_AI_REQUEST_MODEL, self._model)
                    llm_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID)

                    # ── Step 1: signal "working" to streaming clients ─────────
                    await updater.start_work()

                    # ── Step 2: stream tokens via LangGraph astream_events ────
                    # Each "on_chat_model_stream" event carries an AIMessageChunk.
                    # We emit one TaskArtifactUpdateEvent per text chunk so SSE
                    # clients can render tokens in real time.
                    # artifact_id resets on each new LLM run_id so agent→tool→agent
                    # cycles each get their own artifact slot.
                    artifact_id = str(uuid.uuid4())
                    has_streamed = False  # True after first chunk for current artifact
                    current_run_id = None  # Detects new LLM call in a ReAct cycle
                    accumulated: list[str] = []  # All text for the final Message
                    last_ai_message = None  # Saved for token-usage telemetry

                    # ── OA-03: Excessive agency tracker ──────────────────────
                    _agency = (
                        AgencyTracker(
                            max_tool_calls=_compliance_cfg.max_tool_calls_per_task,
                            max_duration_seconds=float(_compliance_cfg.max_task_duration_seconds),
                        )
                        if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic"
                        else None
                    )

                    async for event in self.agent.astream_events(
                        {"messages": messages},
                        config=run_config,
                        version="v2",
                    ):
                        kind = event.get("event", "")

                        if kind == "on_chat_model_stream":
                            run_id = event.get("run_id", "")
                            if run_id and run_id != current_run_id:
                                # New LLM run started — fresh artifact slot
                                current_run_id = run_id
                                artifact_id = str(uuid.uuid4())
                                has_streamed = False
                            chunk = event.get("data", {}).get("chunk")
                            if chunk is not None:
                                texts = _extract_chunk_text(chunk.content)
                                for text in texts:
                                    await updater.add_artifact(
                                        parts=[Part(root=TextPart(text=text))],
                                        artifact_id=artifact_id,
                                        append=has_streamed,  # False=first, True=append
                                        last_chunk=False,
                                    )
                                    has_streamed = True
                                    accumulated.append(text)

                        elif kind == "on_tool_start":
                            tool_name = event.get("name", "?")
                            tool_input = event.get("data", {}).get("input", "")
                            logger.debug("SSE: tool start — %s", tool_name)
                            entry = {
                                "tool": tool_name,
                                "input": str(tool_input)[:500] if tool_input else "",
                            }
                            tool_trace.append(entry)
                            tool_run_id = event.get("run_id")
                            if tool_run_id:
                                pending_tools[tool_run_id] = entry
                            # Checked after recording, so a call that trips
                            # the limit still appears in the trace.
                            if _agency is not None:
                                _agency.on_tool_call(
                                    tool_name=tool_name,
                                    context_id=context_id,
                                )

                        elif kind == "on_tool_end":
                            tool_end_name = event.get("name", "?")
                            tool_output = event.get("data", {}).get("output", "")
                            logger.debug("SSE: tool end — %s", tool_end_name)
                            # Prefer the entry opened under the same run_id;
                            # fall back to the most recent entry with the same
                            # tool name when the event carries no usable id.
                            entry = pending_tools.pop(event.get("run_id"), None)
                            if entry is None and tool_trace and tool_trace[-1]["tool"] == tool_end_name:
                                entry = tool_trace[-1]
                            if entry is not None:
                                entry["output_preview"] = str(tool_output)[:300] if tool_output else ""

                        elif kind == "on_chat_model_end":
                            # Capture the last completed AIMessage for token telemetry
                            output = event.get("data", {}).get("output")
                            if output is not None:
                                last_ai_message = output

                    # Record token usage from the last completed LLM call
                    if last_ai_message is not None:
                        record_llm_token_usage(llm_span, {"messages": [last_ai_message]})

                # Build final text from all accumulated streaming tokens
                final_text = "".join(accumulated).strip() or "(no response generated)"
                logger.info("A2A execute: response length=%d chars", len(final_text))

                # ── OA-02 / OA-06: Output PII redaction ──────────────────────
                if _COMPLIANCE_AVAILABLE and _compliance_cfg and _compliance_cfg.mode == "owasp_agentic":
                    final_text, _pii_types = _redact_pii(final_text)
                    if _pii_types:
                        from builtin_tools.audit import log_event as _audit_log
                        _audit_log(
                            event_type="compliance",
                            action="pii.redact",
                            resource="task_output",
                            outcome="redacted",
                            pii_types=_pii_types,
                            context_id=context_id,
                        )

                # ── OTEL: task_complete span ─────────────────────────────────
                with tracer.start_as_current_span("task_complete") as done_span:
                    done_span.set_attribute(WORKSPACE_ID_ATTR, _WORKSPACE_ID)
                    done_span.set_attribute(A2A_TASK_ID, context_id)
                    done_span.set_attribute("task.has_response", bool(accumulated))
                    done_span.set_attribute("task.response_length", len(final_text))

                # ── Step 3: emit final Message ────────────────────────────────
                # Non-streaming: ResultAggregator.consume_all() returns this
                # immediately as the response (a2a_client.py reads .parts[0].text).
                # Streaming: yielded as the last SSE event in the stream.
                msg = new_agent_text_message(final_text, task_id=task_id, context_id=context_id)
                if tool_trace:
                    msg.metadata = {"tool_trace": tool_trace}
                await event_queue.enqueue_event(msg)
                _result = final_text

            except Exception as e:
                logger.error("A2A execute error: %s", e, exc_info=True)
                try:
                    task_span.record_exception(e)
                    from opentelemetry.trace import StatusCode
                    task_span.set_status(StatusCode.ERROR, str(e))
                except Exception:
                    # Telemetry is best-effort; the client still needs the
                    # error response emitted below.
                    pass
                # Emit a Message so both streaming and non-streaming clients
                # receive an error response rather than hanging. The tools
                # that ran before the failure are attached as well, so the
                # trace is not lost on the error path.
                error_msg = new_agent_text_message(
                    f"Agent error: {e}", task_id=task_id, context_id=context_id
                )
                if tool_trace:
                    error_msg.metadata = {"tool_trace": tool_trace}
                await event_queue.enqueue_event(error_msg)
            finally:
                # Clear the published task whether the run succeeded or not.
                await set_current_task(self._heartbeat, "")

        return _result

    async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None:
        """Cancel a running task — emits canceled state to comply with A2A protocol.

        The status update is sent through ``TaskUpdater`` so that it carries
        the task and context IDs of the request, resolved the same way as in
        ``_core_execute`` (random fallbacks when the context has none).
        """
        from a2a.types import TaskState

        task_id = context.task_id or str(uuid.uuid4())
        context_id = context.context_id or str(uuid.uuid4())
        updater = TaskUpdater(event_queue, task_id, context_id)
        await updater.update_status(TaskState.canceled, final=True)