forked from molecule-ai/molecule-core
573 lines
21 KiB
Python
573 lines
21 KiB
Python
"""Shared helpers for AgentExecutor implementations.
|
|
|
|
Used by both CLIAgentExecutor (codex, ollama) and ClaudeSDKExecutor (claude-code).
|
|
Provides:
|
|
- Memory recall/commit (HTTP to platform /memories endpoints)
|
|
- Delegation results consumption (atomic file rename)
|
|
- Current task heartbeat updates
|
|
- System prompt loading from /configs
|
|
- A2A instructions text for system prompt injection (MCP and CLI variants)
|
|
- Brief task summary extraction (markdown-aware)
|
|
- Error message sanitization (exception classes and subprocess categories)
|
|
- Shared workspace path constants and the MCP server path resolver
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import os
|
|
import re
|
|
import subprocess
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any
|
|
|
|
import httpx
|
|
|
|
from builtin_tools.security import _redact_secrets
|
|
|
|
if TYPE_CHECKING:
|
|
from heartbeat import HeartbeatLoop
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# ========================================================================
# Constants — workspace container layout
# ========================================================================

# Bind-mount of the workspace repo inside the agent container.
WORKSPACE_MOUNT = "/workspace"
# Read-only config mount; system-prompt.md is loaded from here.
CONFIG_MOUNT = "/configs"
# Default location of the stdio MCP server script (see get_mcp_server_path).
DEFAULT_MCP_SERVER_PATH = "/app/a2a_mcp_server.py"
# JSONL handoff file written by the heartbeat loop, consumed by executors.
DEFAULT_DELEGATION_RESULTS_FILE = "/tmp/delegation_results.jsonl"
# Timeout (seconds) for all platform API calls made via get_http_client().
PLATFORM_HTTP_TIMEOUT_S = 5.0
# Max number of recent memories injected by recall_memories().
MEMORY_RECALL_LIMIT = 10
# NOTE(review): declared but not referenced anywhere in this module —
# confirm whether memory content was meant to be truncated to this length.
MEMORY_CONTENT_MAX_CHARS = 200
# Default truncation length for brief_summary().
BRIEF_SUMMARY_MAX_LEN = 80
|
def get_mcp_server_path() -> str:
    """Resolve the location of the stdio MCP server script.

    The A2A_MCP_SERVER_PATH environment variable takes precedence (used by
    tests and non-default layouts); otherwise the baked-in default applies.
    """
    override = os.environ.get("A2A_MCP_SERVER_PATH")
    return override if override is not None else DEFAULT_MCP_SERVER_PATH
|
|
|
|
|
|
# ========================================================================
# HTTP client (shared, lazily initialised)
# ========================================================================

# Shared AsyncClient singleton; created on first use by get_http_client()
# and rebuilt automatically if a previous instance was closed.
_http_client: httpx.AsyncClient | None = None
|
|
|
|
|
|
def get_http_client() -> httpx.AsyncClient:
    """Return the shared httpx client, constructing it on first use.

    If a previous client was closed (e.g. by a test teardown), a fresh one
    is built, so callers can always treat the return value as live.
    """
    global _http_client
    client = _http_client
    if client is None or client.is_closed:
        client = httpx.AsyncClient(timeout=PLATFORM_HTTP_TIMEOUT_S)
        _http_client = client
    return client
|
|
|
|
|
|
def reset_http_client_for_tests() -> None:
    """Discard the module-level httpx client (test-only helper).

    The next get_http_client() call will construct a fresh client. Exists
    so test suites can guarantee a clean slate between cases without
    reaching into module internals; never needed on production code paths.
    """
    global _http_client
    _http_client = None
|
|
|
|
|
|
# ========================================================================
|
|
# Memory recall + commit
|
|
# ========================================================================
|
|
|
|
async def recall_memories() -> str:
    """Recall recent memories from the platform API.

    Returns a newline-joined bullet list of up to MEMORY_RECALL_LIMIT most
    recent memories, or empty string when the platform is unreachable / not
    configured / returns a non-2xx status / returns an unexpected payload.

    Best-effort by design: all transport and payload errors are DEBUG-logged
    and swallowed so a platform outage never breaks the agent.
    """
    workspace_id = os.environ.get("WORKSPACE_ID", "")
    platform_url = os.environ.get("PLATFORM_URL", "")
    if not workspace_id or not platform_url:
        return ""
    # Fix E (Cycle 5): send auth headers so the WorkspaceAuth middleware
    # (Fix A) allows access once the workspace has a live token on file.
    try:
        from platform_auth import auth_headers as _platform_auth

        _auth = _platform_auth()
    except Exception:
        _auth = {}
    try:
        resp = await get_http_client().get(
            f"{platform_url}/workspaces/{workspace_id}/memories",
            headers=_auth,
        )
        if not 200 <= resp.status_code < 300:
            logger.debug(
                "recall_memories: non-2xx response %s from platform",
                resp.status_code,
            )
            return ""
        data = resp.json()
    except Exception as exc:
        logger.debug("recall_memories: request failed: %s", exc)
        return ""
    if not isinstance(data, list) or not data:
        return ""
    # Bug fix: skip malformed (non-dict) entries instead of raising
    # AttributeError on .get() outside the try block above.
    # NOTE(review): MEMORY_CONTENT_MAX_CHARS exists but is not applied here —
    # confirm whether memory content should be truncated for the prompt.
    lines = [
        f"- [{m.get('scope', '?')}] {m.get('content', '')}"
        for m in data[-MEMORY_RECALL_LIMIT:]
        if isinstance(m, dict)
    ]
    return "\n".join(lines)
|
|
|
|
|
|
async def commit_memory(content: str) -> None:
    """Persist a memory to the platform API.

    Best-effort: no-ops when the platform is not configured or the content
    is empty, and DEBUG-logs (never propagates) any transport failure.
    Secrets are redacted from the content before it leaves the container.
    """
    workspace_id = os.environ.get("WORKSPACE_ID", "")
    platform_url = os.environ.get("PLATFORM_URL", "")
    if not (workspace_id and platform_url and content):
        return
    content = _redact_secrets(content)
    # Fix E (Cycle 5): include auth header so WorkspaceAuth middleware allows access.
    try:
        from platform_auth import auth_headers as _platform_auth

        headers = _platform_auth()
    except Exception:
        headers = {}
    try:
        await get_http_client().post(
            f"{platform_url}/workspaces/{workspace_id}/memories",
            json={"content": content, "scope": "LOCAL"},
            headers=headers,
        )
    except Exception as exc:
        logger.debug("commit_memory: request failed: %s", exc)
|
|
|
|
|
|
# ========================================================================
|
|
# Delegation results — written by heartbeat loop, consumed atomically
|
|
# ========================================================================
|
|
|
|
def read_delegation_results() -> str:
    """Read and consume delegation results written by the heartbeat loop.

    Claims the results file via atomic rename before reading, so a
    concurrent heartbeat writer can never hand us a half-written file.
    Returns text formatted for prompt injection, or "" when there is
    nothing to consume.
    """
    results_file = Path(
        os.environ.get("DELEGATION_RESULTS_FILE", DEFAULT_DELEGATION_RESULTS_FILE)
    )
    if not results_file.exists():
        return ""
    claimed = results_file.with_suffix(".consumed")
    try:
        results_file.rename(claimed)
    except OSError:
        # Lost the race: the file vanished between exists() and rename().
        return ""
    try:
        raw = claimed.read_text(encoding="utf-8", errors="replace")
    except OSError:
        return ""
    finally:
        claimed.unlink(missing_ok=True)

    formatted: list[str] = []
    for line in raw.strip().split("\n"):
        if not line.strip():
            continue
        try:
            record = json.loads(line)
        except json.JSONDecodeError:
            # Skip corrupt lines; remaining records are still useful.
            continue
        formatted.append(
            f"- [{record.get('status', '?')}] {record.get('summary', '')}"
        )
        preview = record.get("response_preview", "")
        if preview:
            formatted.append(f"  Response: {preview[:200]}")
    return "\n".join(formatted)
|
|
|
|
|
|
# ========================================================================
|
|
# Current task heartbeat update
|
|
# ========================================================================
|
|
|
|
async def set_current_task(heartbeat: "HeartbeatLoop | None", task: str) -> None:
    """Record the current task on the heartbeat loop and push it upstream.

    Mutates the in-process heartbeat state first, then fires a best-effort
    POST to the platform registry so the change is visible without waiting
    for the next scheduled heartbeat. Failures are DEBUG-logged and never
    propagated.
    """
    active = 1 if task else 0
    if heartbeat is not None:
        heartbeat.current_task = task
        heartbeat.active_tasks = active
    workspace_id = os.environ.get("WORKSPACE_ID", "")
    platform_url = os.environ.get("PLATFORM_URL", "")
    if not workspace_id or not platform_url:
        return
    try:
        try:
            from platform_auth import auth_headers as _auth

            headers = _auth()
        except Exception:
            headers = {}
        payload = {
            "workspace_id": workspace_id,
            "current_task": task,
            "active_tasks": active,
            "error_rate": 0,
            "sample_error": "",
            "uptime_seconds": 0,
        }
        await get_http_client().post(
            f"{platform_url}/registry/heartbeat",
            json=payload,
            headers=headers,
        )
    except Exception as exc:
        logger.debug("set_current_task: heartbeat push failed: %s", exc)
|
|
|
|
|
|
# ========================================================================
|
|
# System prompt loading
|
|
# ========================================================================
|
|
|
|
def get_system_prompt(config_path: str, fallback: str | None = None) -> str | None:
    """Load system-prompt.md from `config_path`, re-reading on every call.

    Re-reading (rather than caching) means edits under the config dir take
    effect on the next agent turn — cheap hot-reload. Returns `fallback`
    when the file is absent.
    """
    candidate = Path(config_path) / "system-prompt.md"
    if not candidate.exists():
        return fallback
    return candidate.read_text(encoding="utf-8", errors="replace").strip()
|
|
|
|
|
|
_A2A_INSTRUCTIONS_MCP = """## Inter-Agent Communication
|
|
You have MCP tools for communicating with other workspaces:
|
|
- list_peers: discover available peer workspaces (name, ID, status, role)
|
|
- delegate_task: send a task and WAIT for the response (for quick tasks)
|
|
- delegate_task_async: send a task and return immediately with a task_id (for long tasks)
|
|
- check_task_status: poll an async task's status and get results when done
|
|
- get_workspace_info: get your own workspace info
|
|
|
|
For quick questions, use delegate_task (synchronous).
|
|
For long-running work (building pages, running audits), use delegate_task_async + check_task_status.
|
|
Always use list_peers first to discover available workspace IDs.
|
|
Access control is enforced — you can only reach siblings and parent/children.
|
|
|
|
PROACTIVE MESSAGING: Use send_message_to_user to push messages to the user's chat at ANY time:
|
|
- Acknowledge tasks immediately: "Got it, delegating to the team now..."
|
|
- Send progress updates during long work: "Research Lead finished, waiting on Dev Lead..."
|
|
- Deliver follow-up results: "All teams reported back. Here's the synthesis: ..."
|
|
This lets you respond quickly ("I'll work on this") and come back later with results.
|
|
|
|
If delegate_task returns a DELEGATION FAILED message, do NOT forward the raw error to the user.
|
|
Instead: (1) try delegating to a different peer, (2) handle the task yourself, or
|
|
(3) tell the user which peer is unavailable and provide your own best answer."""
|
|
|
|
|
|
_A2A_INSTRUCTIONS_CLI = """## Inter-Agent Communication
|
|
You can delegate tasks to other workspaces using the a2a command:
|
|
python3 /app/a2a_cli.py peers # List available peers
|
|
python3 /app/a2a_cli.py delegate <workspace_id> <task> # Sync: wait for response
|
|
python3 /app/a2a_cli.py delegate --async <workspace_id> <task> # Async: return task_id
|
|
python3 /app/a2a_cli.py status <workspace_id> <task_id> # Check async task
|
|
python3 /app/a2a_cli.py info # Your workspace info
|
|
|
|
For quick questions, use sync delegate. For long tasks, use --async + status.
|
|
Only delegate to peers listed by the peers command (access control enforced)."""
|
|
|
|
|
|
def get_a2a_instructions(mcp: bool = True) -> str:
|
|
"""Return inter-agent communication instructions for system-prompt injection.
|
|
|
|
Pass `mcp=True` (default) for MCP-capable runtimes (Claude Code via SDK,
|
|
Codex). Pass `mcp=False` for CLI-only runtimes (Ollama, custom) that have
|
|
to call a2a_cli.py as a subprocess.
|
|
"""
|
|
return _A2A_INSTRUCTIONS_MCP if mcp else _A2A_INSTRUCTIONS_CLI
|
|
|
|
|
|
_HMA_INSTRUCTIONS = """## Hierarchical Memory (HMA)
|
|
You have persistent memory tools that survive across sessions and restarts:
|
|
|
|
- **commit_memory(content, scope)**: Save important information.
|
|
- LOCAL: private to you only (default)
|
|
- TEAM: shared with your parent workspace and siblings (same team)
|
|
- GLOBAL: shared with the entire org (only root workspaces can write)
|
|
|
|
- **recall_memory(query)**: Search your accessible memories. Returns LOCAL + TEAM + GLOBAL matches.
|
|
|
|
**When to use memory:**
|
|
- After making a decision or learning something non-obvious → commit_memory("decision X because Y", scope="TEAM")
|
|
- Before starting work → recall_memory("what did the team decide about X")
|
|
- When you discover org-wide knowledge (repo locations, API patterns, conventions) → commit_memory(fact, scope="GLOBAL") if you are a root workspace, or scope="TEAM" to share with your team
|
|
- After completing a task → commit_memory("completed task X, PR #N opened", scope="TEAM") so your lead and teammates know
|
|
|
|
**Memory is automatically recalled** at the start of each new session. Use it proactively during work to share context.
|
|
"""
|
|
|
|
|
|
def get_hma_instructions() -> str:
|
|
"""Return HMA memory instructions for system-prompt injection."""
|
|
return _HMA_INSTRUCTIONS
|
|
|
|
|
|
# ========================================================================
|
|
# Misc text helpers
|
|
# ========================================================================
|
|
|
|
_MARKDOWN_FENCE = "```"
_MARKDOWN_HR = "---"


_BRIEF_SUMMARY_MIN_LEN = 4  # 1 char + 3-char ellipsis


def brief_summary(text: str, max_len: int = BRIEF_SUMMARY_MAX_LEN) -> str:
    """Extract a one-line task summary for the canvas card display.

    Markdown headers (#, ##, ###) and bold/italic markers (**, __) are
    stripped; fence lines (```) and horizontal rules (---) are skipped.
    The first meaningful line is returned, truncated with an ellipsis when
    it exceeds `max_len`. If no line qualifies, the raw text prefix is
    returned instead.

    `max_len` is clamped to at least 4 (one real character plus a 3-char
    ellipsis) so degenerate callers can't produce negative slice indices.
    """
    limit = max(max_len, _BRIEF_SUMMARY_MIN_LEN)
    for raw_line in text.split("\n"):
        candidate = raw_line.strip().lstrip("#").strip()
        if not candidate:
            continue
        if candidate.startswith(_MARKDOWN_FENCE) or candidate == _MARKDOWN_HR:
            continue
        candidate = candidate.replace("**", "").replace("__", "")
        if len(candidate) <= limit:
            return candidate
        return candidate[: limit - 3] + "..."
    return text[:limit]
|
|
|
|
|
|
def extract_message_text(message: Any) -> str:
    """Extract text from an A2A message (handles both .text and .root.text patterns)."""
    collected: list[str] = []
    for part in getattr(message, "parts", None) or []:
        direct = getattr(part, "text", None)
        if direct:
            collected.append(direct)
            continue
        # Some SDK versions wrap the payload in a .root container.
        wrapped = getattr(part, "root", None)
        if wrapped is None:
            continue
        nested = getattr(wrapped, "text", None)
        if nested:
            collected.append(nested)
    return " ".join(collected).strip()
|
|
|
|
|
|
# Word-boundary patterns for subprocess stderr classification. Word
# boundaries prevent false positives such as "author" matching "auth" or
# "generate" matching "rate".
_RATE_LIMIT_RE = re.compile(r"\brate\b|\b429\b|\boverloaded\b", re.IGNORECASE)
_AUTH_RE = re.compile(r"\bauth(?:entication|orization)?\b|\bapi[_-]?key\b", re.IGNORECASE)
_SESSION_RE = re.compile(r"\bsession\b|\bno conversation found\b", re.IGNORECASE)


def classify_subprocess_error(stderr_text: str, exit_code: int | None) -> str:
    """Map a subprocess stderr blob to a short, user-safe category tag.

    Only the category tag reaches the user; the raw stderr (which may hold
    tokens, internal paths, or stack traces) stays in the workspace logs via
    `logger.error`. Pair with `sanitize_agent_error` to build the final
    user-facing message for subprocess failures.
    """
    categorised = (
        ("rate_limited", _RATE_LIMIT_RE),
        ("auth_failed", _AUTH_RE),
        ("session_error", _SESSION_RE),
    )
    for tag, pattern in categorised:
        if pattern.search(stderr_text):
            return tag
    if exit_code not in (None, 0):
        return f"exit_{exit_code}"
    return "subprocess_error"
|
|
|
|
|
|
def sanitize_agent_error(
    exc: BaseException | None = None,
    category: str | None = None,
) -> str:
    """Render an agent-side failure into a user-safe error message.

    Pass an exception (its class name becomes the tag) or an explicit
    category string (e.g. from `classify_subprocess_error`); `category`
    takes precedence when both are supplied, and the tag falls back to
    "unknown" when neither is.

    The message body is deliberately dropped — exception text and subprocess
    stderr frequently leak stack traces, paths, tokens, and API keys. Full
    detail lives in the workspace logs via `logger.exception()` /
    `logger.error()`.
    """
    if category:
        label = category
    else:
        label = type(exc).__name__ if exc is not None else "unknown"
    return f"Agent error ({label}) — see workspace logs for details."
|
|
|
|
|
|
# ========================================================================
# Auto-push hook — push unpushed commits and open PR after task completion
# ========================================================================

# Git/gh wrappers at /usr/local/bin have GH_TOKEN baked in.
_GIT = "/usr/local/bin/git"
_GH = "/usr/local/bin/gh"
# Long-lived branches that must never be auto-pushed from; PRs are only
# opened from feature branches against staging.
_PROTECTED_BRANCHES = frozenset({"staging", "main", "master"})
|
|
|
def _run_git(args: list[str], cwd: str, timeout: int = 30) -> subprocess.CompletedProcess:
|
|
"""Run a git/gh command with bounded timeout. Never raises on failure."""
|
|
return subprocess.run(
|
|
args,
|
|
cwd=cwd,
|
|
capture_output=True,
|
|
text=True,
|
|
timeout=timeout,
|
|
)
|
|
|
|
|
|
def _auto_push_and_pr_sync(cwd: str) -> None:
    """Synchronous implementation of the auto-push hook.

    1. Check if we're in a git repo with unpushed commits on a feature branch.
    2. Push the branch.
    3. Open a PR against staging if one doesn't already exist.

    Designed to be called from a background thread — never raises, logs all
    errors. Uses the git/gh wrappers at /usr/local/bin/ which have GH_TOKEN
    baked in.

    # NOTE(review): 'origin/staging..HEAD' hard-codes 'staging' as the base
    # branch — confirm all workspace repos actually use a staging branch.
    """
    try:
        # --- Guard: is this a git repo? ---
        probe = _run_git([_GIT, "rev-parse", "--is-inside-work-tree"], cwd)
        if probe.returncode != 0:
            return

        # --- Guard: get current branch ---
        branch_result = _run_git(
            [_GIT, "rev-parse", "--abbrev-ref", "HEAD"], cwd
        )
        if branch_result.returncode != 0:
            return
        branch = branch_result.stdout.strip()
        # "HEAD" means detached HEAD — nothing sensible to push from there;
        # protected branches (staging/main/master) are never auto-pushed.
        if not branch or branch in _PROTECTED_BRANCHES or branch == "HEAD":
            return

        # --- Guard: any unpushed commits? ---
        log_result = _run_git(
            [_GIT, "log", "origin/staging..HEAD", "--oneline"], cwd
        )
        if log_result.returncode != 0 or not log_result.stdout.strip():
            # No unpushed commits (or origin/staging doesn't exist).
            return

        unpushed_lines = log_result.stdout.strip().splitlines()
        logger.info(
            "auto-push: %d unpushed commit(s) on branch '%s', pushing...",
            len(unpushed_lines),
            branch,
        )

        # --- Push ---
        # Longer timeout: network transfer can be slow on big diffs.
        push_result = _run_git(
            [_GIT, "push", "origin", branch], cwd, timeout=60
        )
        if push_result.returncode != 0:
            # Output is truncated to 500 chars to keep log lines bounded.
            logger.warning(
                "auto-push: git push failed (exit %d): %s",
                push_result.returncode,
                (push_result.stderr or push_result.stdout)[:500],
            )
            return

        logger.info("auto-push: pushed branch '%s' successfully", branch)

        # --- Check if PR already exists ---
        pr_list = _run_git(
            [_GH, "pr", "list", "--head", branch, "--json", "number"], cwd
        )
        if pr_list.returncode != 0:
            logger.warning(
                "auto-push: gh pr list failed (exit %d): %s",
                pr_list.returncode,
                (pr_list.stderr or pr_list.stdout)[:500],
            )
            return

        # A JSONDecodeError here is absorbed by the outer except Exception.
        existing_prs = json.loads(pr_list.stdout.strip() or "[]")
        if existing_prs:
            logger.info(
                "auto-push: PR already exists for branch '%s' (#%s), skipping create",
                branch,
                existing_prs[0].get("number", "?"),
            )
            return

        # --- Get first commit message for PR title ---
        # --reverse + -1 yields the oldest unpushed commit's subject line.
        first_commit = _run_git(
            [_GIT, "log", "origin/staging..HEAD", "--reverse",
             "--format=%s", "-1"],
            cwd,
        )
        pr_title = first_commit.stdout.strip() if first_commit.returncode == 0 else branch
        # Truncate to 256 chars (GitHub limit)
        if len(pr_title) > 256:
            pr_title = pr_title[:253] + "..."

        # --- Create PR ---
        pr_create = _run_git(
            [
                _GH, "pr", "create",
                "--base", "staging",
                "--title", pr_title,
                "--body", "Auto-created by workspace agent",
            ],
            cwd,
            timeout=60,
        )
        if pr_create.returncode != 0:
            logger.warning(
                "auto-push: gh pr create failed (exit %d): %s",
                pr_create.returncode,
                (pr_create.stderr or pr_create.stdout)[:500],
            )
        else:
            pr_url = pr_create.stdout.strip()
            logger.info("auto-push: created PR %s", pr_url)

    except subprocess.TimeoutExpired:
        logger.warning("auto-push: command timed out, skipping")
    except Exception:
        logger.exception("auto-push: unexpected error (non-fatal)")
|
|
|
|
|
|
async def auto_push_hook(cwd: str | None = None) -> None:
    """Post-execution hook: push unpushed commits and open a PR.

    Offloads the blocking git/gh subprocess work to a worker thread via
    asyncio.to_thread so the agent's event loop stays responsive. Every
    exception is logged and swallowed — the agent must never crash because
    of this hook.
    """
    target = cwd if cwd is not None else WORKSPACE_MOUNT
    if not os.path.isdir(target):
        return
    try:
        await asyncio.to_thread(_auto_push_and_pr_sync, target)
    except Exception:
        logger.exception("auto_push_hook: failed (non-fatal)")
|