fix(agent): add tool-call loop guardrails
This commit is contained in:
parent
8d7500d80d
commit
58b89965c8
@ -14,6 +14,7 @@ from difflib import unified_diff
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from utils import safe_json_loads
|
from utils import safe_json_loads
|
||||||
|
from agent.tool_guardrails import classify_tool_failure
|
||||||
|
|
||||||
# ANSI escape codes for coloring tool failure indicators
|
# ANSI escape codes for coloring tool failure indicators
|
||||||
_RED = "\033[31m"
|
_RED = "\033[31m"
|
||||||
@ -808,30 +809,7 @@ def _detect_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]
|
|||||||
like ``" [exit 1]"`` for terminal failures, or ``" [error]"`` for generic
|
like ``" [exit 1]"`` for terminal failures, or ``" [error]"`` for generic
|
||||||
failures. On success, returns ``(False, "")``.
|
failures. On success, returns ``(False, "")``.
|
||||||
"""
|
"""
|
||||||
if result is None:
|
return classify_tool_failure(tool_name, result)
|
||||||
return False, ""
|
|
||||||
|
|
||||||
if tool_name == "terminal":
|
|
||||||
data = safe_json_loads(result)
|
|
||||||
if isinstance(data, dict):
|
|
||||||
exit_code = data.get("exit_code")
|
|
||||||
if exit_code is not None and exit_code != 0:
|
|
||||||
return True, f" [exit {exit_code}]"
|
|
||||||
return False, ""
|
|
||||||
|
|
||||||
# Memory-specific: distinguish "full" from real errors
|
|
||||||
if tool_name == "memory":
|
|
||||||
data = safe_json_loads(result)
|
|
||||||
if isinstance(data, dict):
|
|
||||||
if data.get("success") is False and "exceed the limit" in data.get("error", ""):
|
|
||||||
return True, " [full]"
|
|
||||||
|
|
||||||
# Generic heuristic for non-terminal tools
|
|
||||||
lower = result[:500].lower()
|
|
||||||
if '"error"' in lower or '"failed"' in lower or result.startswith("Error"):
|
|
||||||
return True, " [error]"
|
|
||||||
|
|
||||||
return False, ""
|
|
||||||
|
|
||||||
|
|
||||||
def get_cute_tool_message(
|
def get_cute_tool_message(
|
||||||
|
|||||||
381
agent/tool_guardrails.py
Normal file
381
agent/tool_guardrails.py
Normal file
@ -0,0 +1,381 @@
|
|||||||
|
"""Pure tool-call loop guardrail primitives.
|
||||||
|
|
||||||
|
The controller in this module is intentionally side-effect free: it tracks
|
||||||
|
per-turn tool-call observations and returns decisions. Runtime code owns whether
|
||||||
|
those decisions become synthetic tool results or controlled turn halts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, Mapping
|
||||||
|
|
||||||
|
from utils import safe_json_loads
|
||||||
|
|
||||||
|
|
||||||
|
IDEMPOTENT_TOOL_NAMES = frozenset(
|
||||||
|
{
|
||||||
|
"read_file",
|
||||||
|
"search_files",
|
||||||
|
"web_search",
|
||||||
|
"web_extract",
|
||||||
|
"session_search",
|
||||||
|
"browser_snapshot",
|
||||||
|
"browser_console",
|
||||||
|
"browser_get_images",
|
||||||
|
"mcp_filesystem_read_file",
|
||||||
|
"mcp_filesystem_read_text_file",
|
||||||
|
"mcp_filesystem_read_multiple_files",
|
||||||
|
"mcp_filesystem_list_directory",
|
||||||
|
"mcp_filesystem_list_directory_with_sizes",
|
||||||
|
"mcp_filesystem_directory_tree",
|
||||||
|
"mcp_filesystem_get_file_info",
|
||||||
|
"mcp_filesystem_search_files",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
MUTATING_TOOL_NAMES = frozenset(
|
||||||
|
{
|
||||||
|
"terminal",
|
||||||
|
"execute_code",
|
||||||
|
"write_file",
|
||||||
|
"patch",
|
||||||
|
"todo",
|
||||||
|
"memory",
|
||||||
|
"skill_manage",
|
||||||
|
"browser_click",
|
||||||
|
"browser_type",
|
||||||
|
"browser_press",
|
||||||
|
"browser_scroll",
|
||||||
|
"browser_navigate",
|
||||||
|
"send_message",
|
||||||
|
"cronjob",
|
||||||
|
"delegate_task",
|
||||||
|
"process",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ToolCallGuardrailConfig:
|
||||||
|
"""Thresholds for per-turn tool-call loop detection."""
|
||||||
|
|
||||||
|
exact_failure_warn_after: int = 2
|
||||||
|
exact_failure_block_after: int = 2
|
||||||
|
same_tool_failure_warn_after: int = 3
|
||||||
|
same_tool_failure_halt_after: int = 5
|
||||||
|
no_progress_warn_after: int = 2
|
||||||
|
no_progress_block_after: int = 2
|
||||||
|
idempotent_tools: frozenset[str] = field(default_factory=lambda: IDEMPOTENT_TOOL_NAMES)
|
||||||
|
mutating_tools: frozenset[str] = field(default_factory=lambda: MUTATING_TOOL_NAMES)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ToolCallSignature:
|
||||||
|
"""Stable, non-reversible identity for a tool name plus canonical args."""
|
||||||
|
|
||||||
|
tool_name: str
|
||||||
|
args_hash: str
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_call(cls, tool_name: str, args: Mapping[str, Any] | None) -> "ToolCallSignature":
|
||||||
|
canonical = canonical_tool_args(args or {})
|
||||||
|
return cls(tool_name=tool_name, args_hash=_sha256(canonical))
|
||||||
|
|
||||||
|
def to_metadata(self) -> dict[str, str]:
|
||||||
|
"""Return public metadata without raw argument values."""
|
||||||
|
return {"tool_name": self.tool_name, "args_hash": self.args_hash}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class ToolGuardrailDecision:
|
||||||
|
"""Decision returned by the tool-call guardrail controller."""
|
||||||
|
|
||||||
|
action: str = "allow" # allow | warn | block | halt
|
||||||
|
code: str = "allow"
|
||||||
|
message: str = ""
|
||||||
|
tool_name: str = ""
|
||||||
|
count: int = 0
|
||||||
|
signature: ToolCallSignature | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allows_execution(self) -> bool:
|
||||||
|
return self.action in {"allow", "warn"}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def should_halt(self) -> bool:
|
||||||
|
return self.action in {"block", "halt"}
|
||||||
|
|
||||||
|
def to_metadata(self) -> dict[str, Any]:
|
||||||
|
data: dict[str, Any] = {
|
||||||
|
"action": self.action,
|
||||||
|
"code": self.code,
|
||||||
|
"message": self.message,
|
||||||
|
"tool_name": self.tool_name,
|
||||||
|
"count": self.count,
|
||||||
|
}
|
||||||
|
if self.signature is not None:
|
||||||
|
data["signature"] = self.signature.to_metadata()
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def canonical_tool_args(args: Mapping[str, Any]) -> str:
|
||||||
|
"""Return sorted compact JSON for parsed tool arguments."""
|
||||||
|
if not isinstance(args, Mapping):
|
||||||
|
raise TypeError(f"tool args must be a mapping, got {type(args).__name__}")
|
||||||
|
return json.dumps(
|
||||||
|
args,
|
||||||
|
ensure_ascii=False,
|
||||||
|
sort_keys=True,
|
||||||
|
separators=(",", ":"),
|
||||||
|
default=str,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def classify_tool_failure(tool_name: str, result: str | None) -> tuple[bool, str]:
|
||||||
|
"""Classify a tool result using shared display/runtime semantics."""
|
||||||
|
if result is None:
|
||||||
|
return False, ""
|
||||||
|
|
||||||
|
if tool_name == "terminal":
|
||||||
|
data = safe_json_loads(result)
|
||||||
|
if isinstance(data, dict):
|
||||||
|
exit_code = data.get("exit_code")
|
||||||
|
if exit_code is not None and exit_code != 0:
|
||||||
|
return True, f" [exit {exit_code}]"
|
||||||
|
if data.get("success") is False or data.get("failed") is True:
|
||||||
|
return True, " [error]"
|
||||||
|
error = data.get("error")
|
||||||
|
if error is not None and error != "":
|
||||||
|
return True, " [error]"
|
||||||
|
return False, ""
|
||||||
|
|
||||||
|
data = safe_json_loads(result)
|
||||||
|
if isinstance(data, dict):
|
||||||
|
if tool_name == "memory":
|
||||||
|
error = data.get("error", "")
|
||||||
|
if data.get("success") is False and isinstance(error, str) and "exceed the limit" in error:
|
||||||
|
return True, " [full]"
|
||||||
|
if data.get("success") is False or data.get("failed") is True:
|
||||||
|
return True, " [error]"
|
||||||
|
error = data.get("error")
|
||||||
|
if error is not None and error != "":
|
||||||
|
return True, " [error]"
|
||||||
|
return False, ""
|
||||||
|
|
||||||
|
lower = result[:500].lower()
|
||||||
|
if "traceback" in lower or lower.startswith("error:"):
|
||||||
|
return True, " [error]"
|
||||||
|
if '"error"' in lower or '"failed"' in lower or result.startswith("Error"):
|
||||||
|
return True, " [error]"
|
||||||
|
return False, ""
|
||||||
|
|
||||||
|
|
||||||
|
class ToolCallGuardrailController:
|
||||||
|
"""Per-turn controller for repeated failed/non-progressing tool calls."""
|
||||||
|
|
||||||
|
def __init__(self, config: ToolCallGuardrailConfig | None = None):
|
||||||
|
self.config = config or ToolCallGuardrailConfig()
|
||||||
|
self.reset_for_turn()
|
||||||
|
|
||||||
|
def reset_for_turn(self) -> None:
|
||||||
|
self._exact_failure_counts: dict[ToolCallSignature, int] = {}
|
||||||
|
self._same_tool_failure_counts: dict[str, int] = {}
|
||||||
|
self._no_progress: dict[ToolCallSignature, tuple[str, int]] = {}
|
||||||
|
self._halt_decision: ToolGuardrailDecision | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def halt_decision(self) -> ToolGuardrailDecision | None:
|
||||||
|
return self._halt_decision
|
||||||
|
|
||||||
|
def before_call(self, tool_name: str, args: Mapping[str, Any] | None) -> ToolGuardrailDecision:
|
||||||
|
signature = ToolCallSignature.from_call(tool_name, _coerce_args(args))
|
||||||
|
|
||||||
|
exact_count = self._exact_failure_counts.get(signature, 0)
|
||||||
|
if exact_count >= self.config.exact_failure_block_after:
|
||||||
|
decision = ToolGuardrailDecision(
|
||||||
|
action="block",
|
||||||
|
code="repeated_exact_failure_block",
|
||||||
|
message=(
|
||||||
|
f"Blocked {tool_name}: the same tool call failed {exact_count} "
|
||||||
|
"times with identical arguments. Stop retrying it unchanged; "
|
||||||
|
"change strategy or explain the blocker."
|
||||||
|
),
|
||||||
|
tool_name=tool_name,
|
||||||
|
count=exact_count,
|
||||||
|
signature=signature,
|
||||||
|
)
|
||||||
|
self._halt_decision = decision
|
||||||
|
return decision
|
||||||
|
|
||||||
|
if self._is_idempotent(tool_name):
|
||||||
|
record = self._no_progress.get(signature)
|
||||||
|
if record is not None:
|
||||||
|
_result_hash, repeat_count = record
|
||||||
|
if repeat_count >= self.config.no_progress_block_after:
|
||||||
|
decision = ToolGuardrailDecision(
|
||||||
|
action="block",
|
||||||
|
code="idempotent_no_progress_block",
|
||||||
|
message=(
|
||||||
|
f"Blocked {tool_name}: this read-only call returned the same "
|
||||||
|
f"result {repeat_count} times. Stop repeating it unchanged; "
|
||||||
|
"use the result already provided or try a different query."
|
||||||
|
),
|
||||||
|
tool_name=tool_name,
|
||||||
|
count=repeat_count,
|
||||||
|
signature=signature,
|
||||||
|
)
|
||||||
|
self._halt_decision = decision
|
||||||
|
return decision
|
||||||
|
|
||||||
|
return ToolGuardrailDecision(tool_name=tool_name, signature=signature)
|
||||||
|
|
||||||
|
def after_call(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
args: Mapping[str, Any] | None,
|
||||||
|
result: str | None,
|
||||||
|
*,
|
||||||
|
failed: bool | None = None,
|
||||||
|
) -> ToolGuardrailDecision:
|
||||||
|
args = _coerce_args(args)
|
||||||
|
signature = ToolCallSignature.from_call(tool_name, args)
|
||||||
|
if failed is None:
|
||||||
|
failed, _ = classify_tool_failure(tool_name, result)
|
||||||
|
|
||||||
|
if failed:
|
||||||
|
exact_count = self._exact_failure_counts.get(signature, 0) + 1
|
||||||
|
self._exact_failure_counts[signature] = exact_count
|
||||||
|
self._no_progress.pop(signature, None)
|
||||||
|
|
||||||
|
same_count = self._same_tool_failure_counts.get(tool_name, 0) + 1
|
||||||
|
self._same_tool_failure_counts[tool_name] = same_count
|
||||||
|
|
||||||
|
if same_count >= self.config.same_tool_failure_halt_after:
|
||||||
|
decision = ToolGuardrailDecision(
|
||||||
|
action="halt",
|
||||||
|
code="same_tool_failure_halt",
|
||||||
|
message=(
|
||||||
|
f"Stopped {tool_name}: it failed {same_count} times this turn. "
|
||||||
|
"Stop retrying the same failing tool path and choose a different approach."
|
||||||
|
),
|
||||||
|
tool_name=tool_name,
|
||||||
|
count=same_count,
|
||||||
|
signature=signature,
|
||||||
|
)
|
||||||
|
self._halt_decision = decision
|
||||||
|
return decision
|
||||||
|
|
||||||
|
if exact_count >= self.config.exact_failure_warn_after:
|
||||||
|
return ToolGuardrailDecision(
|
||||||
|
action="warn",
|
||||||
|
code="repeated_exact_failure_warning",
|
||||||
|
message=(
|
||||||
|
f"Tool guardrail: {tool_name} has failed {exact_count} times "
|
||||||
|
"with identical arguments. Do not retry it unchanged; inspect the "
|
||||||
|
"error and change strategy."
|
||||||
|
),
|
||||||
|
tool_name=tool_name,
|
||||||
|
count=exact_count,
|
||||||
|
signature=signature,
|
||||||
|
)
|
||||||
|
|
||||||
|
if same_count >= self.config.same_tool_failure_warn_after:
|
||||||
|
return ToolGuardrailDecision(
|
||||||
|
action="warn",
|
||||||
|
code="same_tool_failure_warning",
|
||||||
|
message=(
|
||||||
|
f"Tool guardrail: {tool_name} has failed {same_count} times "
|
||||||
|
"this turn. Change approach before retrying."
|
||||||
|
),
|
||||||
|
tool_name=tool_name,
|
||||||
|
count=same_count,
|
||||||
|
signature=signature,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ToolGuardrailDecision(tool_name=tool_name, count=exact_count, signature=signature)
|
||||||
|
|
||||||
|
self._exact_failure_counts.pop(signature, None)
|
||||||
|
self._same_tool_failure_counts.pop(tool_name, None)
|
||||||
|
|
||||||
|
if not self._is_idempotent(tool_name):
|
||||||
|
self._no_progress.pop(signature, None)
|
||||||
|
return ToolGuardrailDecision(tool_name=tool_name, signature=signature)
|
||||||
|
|
||||||
|
result_hash = _result_hash(result)
|
||||||
|
previous = self._no_progress.get(signature)
|
||||||
|
repeat_count = 1
|
||||||
|
if previous is not None and previous[0] == result_hash:
|
||||||
|
repeat_count = previous[1] + 1
|
||||||
|
self._no_progress[signature] = (result_hash, repeat_count)
|
||||||
|
|
||||||
|
if repeat_count >= self.config.no_progress_warn_after:
|
||||||
|
return ToolGuardrailDecision(
|
||||||
|
action="warn",
|
||||||
|
code="idempotent_no_progress_warning",
|
||||||
|
message=(
|
||||||
|
f"Tool guardrail: {tool_name} returned the same result "
|
||||||
|
f"{repeat_count} times. Use the result or change the query instead "
|
||||||
|
"of repeating it unchanged."
|
||||||
|
),
|
||||||
|
tool_name=tool_name,
|
||||||
|
count=repeat_count,
|
||||||
|
signature=signature,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ToolGuardrailDecision(tool_name=tool_name, count=repeat_count, signature=signature)
|
||||||
|
|
||||||
|
def _is_idempotent(self, tool_name: str) -> bool:
|
||||||
|
if tool_name in self.config.mutating_tools:
|
||||||
|
return False
|
||||||
|
return tool_name in self.config.idempotent_tools
|
||||||
|
|
||||||
|
|
||||||
|
def toolguard_synthetic_result(decision: ToolGuardrailDecision) -> str:
|
||||||
|
"""Build a synthetic role=tool content string for a blocked tool call."""
|
||||||
|
return json.dumps(
|
||||||
|
{
|
||||||
|
"error": decision.message,
|
||||||
|
"guardrail": decision.to_metadata(),
|
||||||
|
},
|
||||||
|
ensure_ascii=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def append_toolguard_guidance(result: str, decision: ToolGuardrailDecision) -> str:
|
||||||
|
"""Append runtime guidance to the current tool result content."""
|
||||||
|
if decision.action not in {"warn", "halt"} or not decision.message:
|
||||||
|
return result
|
||||||
|
suffix = (
|
||||||
|
"\n\n[Tool guardrail: "
|
||||||
|
f"{decision.code}; count={decision.count}; {decision.message}]"
|
||||||
|
)
|
||||||
|
return (result or "") + suffix
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_args(args: Mapping[str, Any] | None) -> Mapping[str, Any]:
|
||||||
|
return args if isinstance(args, Mapping) else {}
|
||||||
|
|
||||||
|
|
||||||
|
def _result_hash(result: str | None) -> str:
|
||||||
|
parsed = safe_json_loads(result or "")
|
||||||
|
if parsed is not None:
|
||||||
|
try:
|
||||||
|
canonical = json.dumps(
|
||||||
|
parsed,
|
||||||
|
ensure_ascii=False,
|
||||||
|
sort_keys=True,
|
||||||
|
separators=(",", ":"),
|
||||||
|
default=str,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
canonical = str(parsed)
|
||||||
|
else:
|
||||||
|
canonical = result or ""
|
||||||
|
return _sha256(canonical)
|
||||||
|
|
||||||
|
|
||||||
|
def _sha256(value: str) -> str:
|
||||||
|
return hashlib.sha256(value.encode("utf-8")).hexdigest()
|
||||||
301
run_agent.py
301
run_agent.py
@ -162,6 +162,12 @@ from agent.display import (
|
|||||||
_detect_tool_failure,
|
_detect_tool_failure,
|
||||||
get_tool_emoji as _get_tool_emoji,
|
get_tool_emoji as _get_tool_emoji,
|
||||||
)
|
)
|
||||||
|
from agent.tool_guardrails import (
|
||||||
|
ToolCallGuardrailController,
|
||||||
|
ToolGuardrailDecision,
|
||||||
|
append_toolguard_guidance,
|
||||||
|
toolguard_synthetic_result,
|
||||||
|
)
|
||||||
from agent.trajectory import (
|
from agent.trajectory import (
|
||||||
convert_scratchpad_to_think, has_incomplete_scratchpad,
|
convert_scratchpad_to_think, has_incomplete_scratchpad,
|
||||||
save_trajectory as _save_trajectory_to_file,
|
save_trajectory as _save_trajectory_to_file,
|
||||||
@ -1150,6 +1156,8 @@ class AIAgent:
|
|||||||
# Tool execution state — allows _vprint during tool execution
|
# Tool execution state — allows _vprint during tool execution
|
||||||
# even when stream consumers are registered (no tokens streaming then)
|
# even when stream consumers are registered (no tokens streaming then)
|
||||||
self._executing_tools = False
|
self._executing_tools = False
|
||||||
|
self._tool_guardrails = ToolCallGuardrailController()
|
||||||
|
self._tool_guardrail_halt_decision: ToolGuardrailDecision | None = None
|
||||||
|
|
||||||
# Interrupt mechanism for breaking out of tool loops
|
# Interrupt mechanism for breaking out of tool loops
|
||||||
self._interrupt_requested = False
|
self._interrupt_requested = False
|
||||||
@ -9107,6 +9115,44 @@ class AIAgent:
|
|||||||
)
|
)
|
||||||
return compressed, new_system_prompt
|
return compressed, new_system_prompt
|
||||||
|
|
||||||
|
def _set_tool_guardrail_halt(self, decision: ToolGuardrailDecision) -> None:
|
||||||
|
"""Record the first guardrail decision that should stop this turn."""
|
||||||
|
if decision.should_halt and self._tool_guardrail_halt_decision is None:
|
||||||
|
self._tool_guardrail_halt_decision = decision
|
||||||
|
|
||||||
|
def _toolguard_controlled_halt_response(self, decision: ToolGuardrailDecision) -> str:
|
||||||
|
tool = decision.tool_name or "a tool"
|
||||||
|
return (
|
||||||
|
f"I stopped retrying {tool} because it hit the tool-call guardrail "
|
||||||
|
f"({decision.code}) after {decision.count} repeated non-progressing "
|
||||||
|
"attempts. The last tool result explains the blocker; the next step is "
|
||||||
|
"to change strategy instead of repeating the same call."
|
||||||
|
)
|
||||||
|
|
||||||
|
def _append_guardrail_observation(
|
||||||
|
self,
|
||||||
|
tool_name: str,
|
||||||
|
function_args: dict,
|
||||||
|
function_result: str,
|
||||||
|
*,
|
||||||
|
failed: bool,
|
||||||
|
) -> str:
|
||||||
|
decision = self._tool_guardrails.after_call(
|
||||||
|
tool_name,
|
||||||
|
function_args,
|
||||||
|
function_result,
|
||||||
|
failed=failed,
|
||||||
|
)
|
||||||
|
if decision.action in {"warn", "halt"}:
|
||||||
|
function_result = append_toolguard_guidance(function_result, decision)
|
||||||
|
if decision.should_halt:
|
||||||
|
self._set_tool_guardrail_halt(decision)
|
||||||
|
return function_result
|
||||||
|
|
||||||
|
def _guardrail_block_result(self, decision: ToolGuardrailDecision) -> str:
|
||||||
|
self._set_tool_guardrail_halt(decision)
|
||||||
|
return toolguard_synthetic_result(decision)
|
||||||
|
|
||||||
def _execute_tool_calls(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
|
def _execute_tool_calls(self, assistant_message, messages: list, effective_task_id: str, api_call_count: int = 0) -> None:
|
||||||
"""Execute tool calls from the assistant message and append results to messages.
|
"""Execute tool calls from the assistant message and append results to messages.
|
||||||
|
|
||||||
@ -9150,7 +9196,8 @@ class AIAgent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str,
|
def _invoke_tool(self, function_name: str, function_args: dict, effective_task_id: str,
|
||||||
tool_call_id: Optional[str] = None, messages: list = None) -> str:
|
tool_call_id: Optional[str] = None, messages: list = None,
|
||||||
|
pre_tool_block_checked: bool = False) -> str:
|
||||||
"""Invoke a single tool and return the result string. No display logic.
|
"""Invoke a single tool and return the result string. No display logic.
|
||||||
|
|
||||||
Handles both agent-level tools (todo, memory, etc.) and registry-dispatched
|
Handles both agent-level tools (todo, memory, etc.) and registry-dispatched
|
||||||
@ -9159,13 +9206,14 @@ class AIAgent:
|
|||||||
"""
|
"""
|
||||||
# Check plugin hooks for a block directive before executing anything.
|
# Check plugin hooks for a block directive before executing anything.
|
||||||
block_message: Optional[str] = None
|
block_message: Optional[str] = None
|
||||||
try:
|
if not pre_tool_block_checked:
|
||||||
from hermes_cli.plugins import get_pre_tool_call_block_message
|
try:
|
||||||
block_message = get_pre_tool_call_block_message(
|
from hermes_cli.plugins import get_pre_tool_call_block_message
|
||||||
function_name, function_args, task_id=effective_task_id or "",
|
block_message = get_pre_tool_call_block_message(
|
||||||
)
|
function_name, function_args, task_id=effective_task_id or "",
|
||||||
except Exception:
|
)
|
||||||
pass
|
except Exception:
|
||||||
|
pass
|
||||||
if block_message is not None:
|
if block_message is not None:
|
||||||
return json.dumps({"error": block_message}, ensure_ascii=False)
|
return json.dumps({"error": block_message}, ensure_ascii=False)
|
||||||
|
|
||||||
@ -9317,13 +9365,31 @@ class AIAgent:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
parsed_calls.append((tool_call, function_name, function_args))
|
block_result = None
|
||||||
|
blocked_by_guardrail = False
|
||||||
|
try:
|
||||||
|
from hermes_cli.plugins import get_pre_tool_call_block_message
|
||||||
|
block_message = get_pre_tool_call_block_message(
|
||||||
|
function_name, function_args, task_id=effective_task_id or "",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
block_message = None
|
||||||
|
|
||||||
|
if block_message is not None:
|
||||||
|
block_result = json.dumps({"error": block_message}, ensure_ascii=False)
|
||||||
|
else:
|
||||||
|
guardrail_decision = self._tool_guardrails.before_call(function_name, function_args)
|
||||||
|
if not guardrail_decision.allows_execution:
|
||||||
|
block_result = self._guardrail_block_result(guardrail_decision)
|
||||||
|
blocked_by_guardrail = True
|
||||||
|
|
||||||
|
parsed_calls.append((tool_call, function_name, function_args, block_result, blocked_by_guardrail))
|
||||||
|
|
||||||
# ── Logging / callbacks ──────────────────────────────────────────
|
# ── Logging / callbacks ──────────────────────────────────────────
|
||||||
tool_names_str = ", ".join(name for _, name, _ in parsed_calls)
|
tool_names_str = ", ".join(name for _, name, _, _, _ in parsed_calls)
|
||||||
if not self.quiet_mode:
|
if not self.quiet_mode:
|
||||||
print(f" ⚡ Concurrent: {num_tools} tool calls — {tool_names_str}")
|
print(f" ⚡ Concurrent: {num_tools} tool calls — {tool_names_str}")
|
||||||
for i, (tc, name, args) in enumerate(parsed_calls, 1):
|
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls, 1):
|
||||||
args_str = json.dumps(args, ensure_ascii=False)
|
args_str = json.dumps(args, ensure_ascii=False)
|
||||||
if self.verbose_logging:
|
if self.verbose_logging:
|
||||||
print(f" 📞 Tool {i}: {name}({list(args.keys())})")
|
print(f" 📞 Tool {i}: {name}({list(args.keys())})")
|
||||||
@ -9332,7 +9398,9 @@ class AIAgent:
|
|||||||
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
||||||
print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}")
|
print(f" 📞 Tool {i}: {name}({list(args.keys())}) - {args_preview}")
|
||||||
|
|
||||||
for tc, name, args in parsed_calls:
|
for tc, name, args, block_result, blocked_by_guardrail in parsed_calls:
|
||||||
|
if block_result is not None:
|
||||||
|
continue
|
||||||
if self.tool_progress_callback:
|
if self.tool_progress_callback:
|
||||||
try:
|
try:
|
||||||
preview = _build_tool_preview(name, args)
|
preview = _build_tool_preview(name, args)
|
||||||
@ -9340,7 +9408,9 @@ class AIAgent:
|
|||||||
except Exception as cb_err:
|
except Exception as cb_err:
|
||||||
logging.debug(f"Tool progress callback error: {cb_err}")
|
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||||
|
|
||||||
for tc, name, args in parsed_calls:
|
for tc, name, args, block_result, blocked_by_guardrail in parsed_calls:
|
||||||
|
if block_result is not None:
|
||||||
|
continue
|
||||||
if self.tool_start_callback:
|
if self.tool_start_callback:
|
||||||
try:
|
try:
|
||||||
self.tool_start_callback(tc.id, name, args)
|
self.tool_start_callback(tc.id, name, args)
|
||||||
@ -9348,8 +9418,11 @@ class AIAgent:
|
|||||||
logging.debug(f"Tool start callback error: {cb_err}")
|
logging.debug(f"Tool start callback error: {cb_err}")
|
||||||
|
|
||||||
# ── Concurrent execution ─────────────────────────────────────────
|
# ── Concurrent execution ─────────────────────────────────────────
|
||||||
# Each slot holds (function_name, function_args, function_result, duration, error_flag)
|
# Each slot holds (function_name, function_args, function_result, duration, error_flag, blocked_flag)
|
||||||
results = [None] * num_tools
|
results = [None] * num_tools
|
||||||
|
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls):
|
||||||
|
if block_result is not None:
|
||||||
|
results[i] = (name, args, block_result, 0.0, True, True)
|
||||||
|
|
||||||
# Touch activity before launching workers so the gateway knows
|
# Touch activity before launching workers so the gateway knows
|
||||||
# we're executing tools (not stuck).
|
# we're executing tools (not stuck).
|
||||||
@ -9404,7 +9477,14 @@ class AIAgent:
|
|||||||
pass
|
pass
|
||||||
start = time.time()
|
start = time.time()
|
||||||
try:
|
try:
|
||||||
result = self._invoke_tool(function_name, function_args, effective_task_id, tool_call.id, messages=messages)
|
result = self._invoke_tool(
|
||||||
|
function_name,
|
||||||
|
function_args,
|
||||||
|
effective_task_id,
|
||||||
|
tool_call.id,
|
||||||
|
messages=messages,
|
||||||
|
pre_tool_block_checked=True,
|
||||||
|
)
|
||||||
except Exception as tool_error:
|
except Exception as tool_error:
|
||||||
result = f"Error executing tool '{function_name}': {tool_error}"
|
result = f"Error executing tool '{function_name}': {tool_error}"
|
||||||
logger.error("_invoke_tool raised for %s: %s", function_name, tool_error, exc_info=True)
|
logger.error("_invoke_tool raised for %s: %s", function_name, tool_error, exc_info=True)
|
||||||
@ -9414,7 +9494,7 @@ class AIAgent:
|
|||||||
logger.info("tool %s failed (%.2fs): %s", function_name, duration, result[:200])
|
logger.info("tool %s failed (%.2fs): %s", function_name, duration, result[:200])
|
||||||
else:
|
else:
|
||||||
logger.info("tool %s completed (%.2fs, %d chars)", function_name, duration, len(result))
|
logger.info("tool %s completed (%.2fs, %d chars)", function_name, duration, len(result))
|
||||||
results[index] = (function_name, function_args, result, duration, is_error)
|
results[index] = (function_name, function_args, result, duration, is_error, False)
|
||||||
# Tear down worker-tid tracking. Clear any interrupt bit we may
|
# Tear down worker-tid tracking. Clear any interrupt bit we may
|
||||||
# have set so the next task scheduled onto this recycled tid
|
# have set so the next task scheduled onto this recycled tid
|
||||||
# starts with a clean slate.
|
# starts with a clean slate.
|
||||||
@ -9440,61 +9520,67 @@ class AIAgent:
|
|||||||
spinner.start()
|
spinner.start()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
max_workers = min(num_tools, _MAX_TOOL_WORKERS)
|
runnable_calls = [
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
(i, tc, name, args)
|
||||||
futures = []
|
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls)
|
||||||
for i, (tc, name, args) in enumerate(parsed_calls):
|
if block_result is None
|
||||||
# Propagate ContextVars (e.g. _approval_session_key); mirrors asyncio.to_thread.
|
]
|
||||||
ctx = contextvars.copy_context()
|
futures = []
|
||||||
f = executor.submit(ctx.run, _run_tool, i, tc, name, args)
|
if runnable_calls:
|
||||||
futures.append(f)
|
max_workers = min(len(runnable_calls), _MAX_TOOL_WORKERS)
|
||||||
|
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||||
|
for i, tc, name, args in runnable_calls:
|
||||||
|
# Propagate ContextVars (e.g. _approval_session_key); mirrors asyncio.to_thread.
|
||||||
|
ctx = contextvars.copy_context()
|
||||||
|
f = executor.submit(ctx.run, _run_tool, i, tc, name, args)
|
||||||
|
futures.append(f)
|
||||||
|
|
||||||
# Wait for all to complete with periodic heartbeats so the
|
# Wait for all to complete with periodic heartbeats so the
|
||||||
# gateway's inactivity monitor doesn't kill us during long
|
# gateway's inactivity monitor doesn't kill us during long
|
||||||
# concurrent tool batches. Also check for user interrupts
|
# concurrent tool batches. Also check for user interrupts
|
||||||
# so we don't block indefinitely when the user sends /stop
|
# so we don't block indefinitely when the user sends /stop
|
||||||
# or a new message during concurrent tool execution.
|
# or a new message during concurrent tool execution.
|
||||||
_conc_start = time.time()
|
_conc_start = time.time()
|
||||||
_interrupt_logged = False
|
_interrupt_logged = False
|
||||||
while True:
|
while True:
|
||||||
done, not_done = concurrent.futures.wait(
|
done, not_done = concurrent.futures.wait(
|
||||||
futures, timeout=5.0,
|
futures, timeout=5.0,
|
||||||
)
|
|
||||||
if not not_done:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Check for interrupt — the per-thread interrupt signal
|
|
||||||
# already causes individual tools (terminal, execute_code)
|
|
||||||
# to abort, but tools without interrupt checks (web_search,
|
|
||||||
# read_file) will run to completion. Cancel any futures
|
|
||||||
# that haven't started yet so we don't block on them.
|
|
||||||
if self._interrupt_requested:
|
|
||||||
if not _interrupt_logged:
|
|
||||||
_interrupt_logged = True
|
|
||||||
self._vprint(
|
|
||||||
f"{self.log_prefix}⚡ Interrupt: cancelling "
|
|
||||||
f"{len(not_done)} pending concurrent tool(s)",
|
|
||||||
force=True,
|
|
||||||
)
|
|
||||||
for f in not_done:
|
|
||||||
f.cancel()
|
|
||||||
# Give already-running tools a moment to notice the
|
|
||||||
# per-thread interrupt signal and exit gracefully.
|
|
||||||
concurrent.futures.wait(not_done, timeout=3.0)
|
|
||||||
break
|
|
||||||
|
|
||||||
_conc_elapsed = int(time.time() - _conc_start)
|
|
||||||
# Heartbeat every ~30s (6 × 5s poll intervals)
|
|
||||||
if _conc_elapsed > 0 and _conc_elapsed % 30 < 6:
|
|
||||||
_still_running = [
|
|
||||||
parsed_calls[futures.index(f)][1]
|
|
||||||
for f in not_done
|
|
||||||
if f in futures
|
|
||||||
]
|
|
||||||
self._touch_activity(
|
|
||||||
f"concurrent tools running ({_conc_elapsed}s, "
|
|
||||||
f"{len(not_done)} remaining: {', '.join(_still_running[:3])})"
|
|
||||||
)
|
)
|
||||||
|
if not not_done:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Check for interrupt — the per-thread interrupt signal
|
||||||
|
# already causes individual tools (terminal, execute_code)
|
||||||
|
# to abort, but tools without interrupt checks (web_search,
|
||||||
|
# read_file) will run to completion. Cancel any futures
|
||||||
|
# that haven't started yet so we don't block on them.
|
||||||
|
if self._interrupt_requested:
|
||||||
|
if not _interrupt_logged:
|
||||||
|
_interrupt_logged = True
|
||||||
|
self._vprint(
|
||||||
|
f"{self.log_prefix}⚡ Interrupt: cancelling "
|
||||||
|
f"{len(not_done)} pending concurrent tool(s)",
|
||||||
|
force=True,
|
||||||
|
)
|
||||||
|
for f in not_done:
|
||||||
|
f.cancel()
|
||||||
|
# Give already-running tools a moment to notice the
|
||||||
|
# per-thread interrupt signal and exit gracefully.
|
||||||
|
concurrent.futures.wait(not_done, timeout=3.0)
|
||||||
|
break
|
||||||
|
|
||||||
|
_conc_elapsed = int(time.time() - _conc_start)
|
||||||
|
# Heartbeat every ~30s (6 × 5s poll intervals)
|
||||||
|
if _conc_elapsed > 0 and _conc_elapsed % 30 < 6:
|
||||||
|
_still_running = [
|
||||||
|
parsed_calls[futures.index(f)][1]
|
||||||
|
for f in not_done
|
||||||
|
if f in futures
|
||||||
|
]
|
||||||
|
self._touch_activity(
|
||||||
|
f"concurrent tools running ({_conc_elapsed}s, "
|
||||||
|
f"{len(not_done)} remaining: {', '.join(_still_running[:3])})"
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
if spinner:
|
if spinner:
|
||||||
# Build a summary message for the spinner stop
|
# Build a summary message for the spinner stop
|
||||||
@ -9503,8 +9589,9 @@ class AIAgent:
|
|||||||
spinner.stop(f"⚡ {completed}/{num_tools} tools completed in {total_dur:.1f}s total")
|
spinner.stop(f"⚡ {completed}/{num_tools} tools completed in {total_dur:.1f}s total")
|
||||||
|
|
||||||
# ── Post-execution: display per-tool results ─────────────────────
|
# ── Post-execution: display per-tool results ─────────────────────
|
||||||
for i, (tc, name, args) in enumerate(parsed_calls):
|
for i, (tc, name, args, block_result, blocked_by_guardrail) in enumerate(parsed_calls):
|
||||||
r = results[i]
|
r = results[i]
|
||||||
|
blocked = False
|
||||||
if r is None:
|
if r is None:
|
||||||
# Tool was cancelled (interrupt) or thread didn't return
|
# Tool was cancelled (interrupt) or thread didn't return
|
||||||
if self._interrupt_requested:
|
if self._interrupt_requested:
|
||||||
@ -9513,13 +9600,21 @@ class AIAgent:
|
|||||||
function_result = f"Error executing tool '{name}': thread did not return a result"
|
function_result = f"Error executing tool '{name}': thread did not return a result"
|
||||||
tool_duration = 0.0
|
tool_duration = 0.0
|
||||||
else:
|
else:
|
||||||
function_name, function_args, function_result, tool_duration, is_error = r
|
function_name, function_args, function_result, tool_duration, is_error, blocked = r
|
||||||
|
|
||||||
|
if not blocked:
|
||||||
|
function_result = self._append_guardrail_observation(
|
||||||
|
function_name,
|
||||||
|
function_args,
|
||||||
|
function_result,
|
||||||
|
failed=is_error,
|
||||||
|
)
|
||||||
|
|
||||||
if is_error:
|
if is_error:
|
||||||
result_preview = function_result[:200] if len(function_result) > 200 else function_result
|
result_preview = function_result[:200] if len(function_result) > 200 else function_result
|
||||||
logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview)
|
logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview)
|
||||||
|
|
||||||
if self.tool_progress_callback:
|
if not blocked and self.tool_progress_callback:
|
||||||
try:
|
try:
|
||||||
self.tool_progress_callback(
|
self.tool_progress_callback(
|
||||||
"tool.completed", function_name, None, None,
|
"tool.completed", function_name, None, None,
|
||||||
@ -9547,7 +9642,7 @@ class AIAgent:
|
|||||||
self._current_tool = None
|
self._current_tool = None
|
||||||
self._touch_activity(f"tool completed: {name} ({tool_duration:.1f}s)")
|
self._touch_activity(f"tool completed: {name} ({tool_duration:.1f}s)")
|
||||||
|
|
||||||
if self.tool_complete_callback:
|
if not blocked and self.tool_complete_callback:
|
||||||
try:
|
try:
|
||||||
self.tool_complete_callback(tc.id, name, args, function_result)
|
self.tool_complete_callback(tc.id, name, args, function_result)
|
||||||
except Exception as cb_err:
|
except Exception as cb_err:
|
||||||
@ -9629,9 +9724,17 @@ class AIAgent:
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if _block_msg is not None:
|
_guardrail_block_decision: ToolGuardrailDecision | None = None
|
||||||
# Tool blocked by plugin policy — skip counter resets.
|
if _block_msg is None:
|
||||||
# Execution is handled below in the tool dispatch chain.
|
guardrail_decision = self._tool_guardrails.before_call(function_name, function_args)
|
||||||
|
if not guardrail_decision.allows_execution:
|
||||||
|
_guardrail_block_decision = guardrail_decision
|
||||||
|
|
||||||
|
_execution_blocked = _block_msg is not None or _guardrail_block_decision is not None
|
||||||
|
|
||||||
|
if _execution_blocked:
|
||||||
|
# Tool blocked by plugin or guardrail policy — skip counters,
|
||||||
|
# callbacks, checkpointing, activity mutation, and real execution.
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
# Reset nudge counters when the relevant tool is actually used
|
# Reset nudge counters when the relevant tool is actually used
|
||||||
@ -9649,35 +9752,35 @@ class AIAgent:
|
|||||||
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
args_preview = args_str[:self.log_prefix_chars] + "..." if len(args_str) > self.log_prefix_chars else args_str
|
||||||
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}")
|
print(f" 📞 Tool {i}: {function_name}({list(function_args.keys())}) - {args_preview}")
|
||||||
|
|
||||||
if _block_msg is None:
|
if not _execution_blocked:
|
||||||
self._current_tool = function_name
|
self._current_tool = function_name
|
||||||
self._touch_activity(f"executing tool: {function_name}")
|
self._touch_activity(f"executing tool: {function_name}")
|
||||||
|
|
||||||
# Set activity callback for long-running tool execution (terminal
|
# Set activity callback for long-running tool execution (terminal
|
||||||
# commands, etc.) so the gateway's inactivity monitor doesn't kill
|
# commands, etc.) so the gateway's inactivity monitor doesn't kill
|
||||||
# the agent while a command is running.
|
# the agent while a command is running.
|
||||||
if _block_msg is None:
|
if not _execution_blocked:
|
||||||
try:
|
try:
|
||||||
from tools.environments.base import set_activity_callback
|
from tools.environments.base import set_activity_callback
|
||||||
set_activity_callback(self._touch_activity)
|
set_activity_callback(self._touch_activity)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if _block_msg is None and self.tool_progress_callback:
|
if not _execution_blocked and self.tool_progress_callback:
|
||||||
try:
|
try:
|
||||||
preview = _build_tool_preview(function_name, function_args)
|
preview = _build_tool_preview(function_name, function_args)
|
||||||
self.tool_progress_callback("tool.started", function_name, preview, function_args)
|
self.tool_progress_callback("tool.started", function_name, preview, function_args)
|
||||||
except Exception as cb_err:
|
except Exception as cb_err:
|
||||||
logging.debug(f"Tool progress callback error: {cb_err}")
|
logging.debug(f"Tool progress callback error: {cb_err}")
|
||||||
|
|
||||||
if _block_msg is None and self.tool_start_callback:
|
if not _execution_blocked and self.tool_start_callback:
|
||||||
try:
|
try:
|
||||||
self.tool_start_callback(tool_call.id, function_name, function_args)
|
self.tool_start_callback(tool_call.id, function_name, function_args)
|
||||||
except Exception as cb_err:
|
except Exception as cb_err:
|
||||||
logging.debug(f"Tool start callback error: {cb_err}")
|
logging.debug(f"Tool start callback error: {cb_err}")
|
||||||
|
|
||||||
# Checkpoint: snapshot working dir before file-mutating tools
|
# Checkpoint: snapshot working dir before file-mutating tools
|
||||||
if _block_msg is None and function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled:
|
if not _execution_blocked and function_name in ("write_file", "patch") and self._checkpoint_mgr.enabled:
|
||||||
try:
|
try:
|
||||||
file_path = function_args.get("path", "")
|
file_path = function_args.get("path", "")
|
||||||
if file_path:
|
if file_path:
|
||||||
@ -9689,7 +9792,7 @@ class AIAgent:
|
|||||||
pass # never block tool execution
|
pass # never block tool execution
|
||||||
|
|
||||||
# Checkpoint before destructive terminal commands
|
# Checkpoint before destructive terminal commands
|
||||||
if _block_msg is None and function_name == "terminal" and self._checkpoint_mgr.enabled:
|
if not _execution_blocked and function_name == "terminal" and self._checkpoint_mgr.enabled:
|
||||||
try:
|
try:
|
||||||
cmd = function_args.get("command", "")
|
cmd = function_args.get("command", "")
|
||||||
if _is_destructive_command(cmd):
|
if _is_destructive_command(cmd):
|
||||||
@ -9706,6 +9809,11 @@ class AIAgent:
|
|||||||
# Tool blocked by plugin policy — return error without executing.
|
# Tool blocked by plugin policy — return error without executing.
|
||||||
function_result = json.dumps({"error": _block_msg}, ensure_ascii=False)
|
function_result = json.dumps({"error": _block_msg}, ensure_ascii=False)
|
||||||
tool_duration = 0.0
|
tool_duration = 0.0
|
||||||
|
elif _guardrail_block_decision is not None:
|
||||||
|
# Tool blocked by tool-loop guardrail — synthesize exactly one
|
||||||
|
# tool result for the original tool_call_id without executing.
|
||||||
|
function_result = self._guardrail_block_result(_guardrail_block_decision)
|
||||||
|
tool_duration = 0.0
|
||||||
elif function_name == "todo":
|
elif function_name == "todo":
|
||||||
from tools.todo_tool import todo_tool as _todo_tool
|
from tools.todo_tool import todo_tool as _todo_tool
|
||||||
function_result = _todo_tool(
|
function_result = _todo_tool(
|
||||||
@ -9889,12 +9997,22 @@ class AIAgent:
|
|||||||
# Log tool errors to the persistent error log so [error] tags
|
# Log tool errors to the persistent error log so [error] tags
|
||||||
# in the UI always have a corresponding detailed entry on disk.
|
# in the UI always have a corresponding detailed entry on disk.
|
||||||
_is_error_result, _ = _detect_tool_failure(function_name, function_result)
|
_is_error_result, _ = _detect_tool_failure(function_name, function_result)
|
||||||
|
if not _execution_blocked:
|
||||||
|
function_result = self._append_guardrail_observation(
|
||||||
|
function_name,
|
||||||
|
function_args,
|
||||||
|
function_result,
|
||||||
|
failed=_is_error_result,
|
||||||
|
)
|
||||||
|
result_preview = function_result if self.verbose_logging else (
|
||||||
|
function_result[:200] if len(function_result) > 200 else function_result
|
||||||
|
)
|
||||||
if _is_error_result:
|
if _is_error_result:
|
||||||
logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview)
|
logger.warning("Tool %s returned error (%.2fs): %s", function_name, tool_duration, result_preview)
|
||||||
else:
|
else:
|
||||||
logger.info("tool %s completed (%.2fs, %d chars)", function_name, tool_duration, len(function_result))
|
logger.info("tool %s completed (%.2fs, %d chars)", function_name, tool_duration, len(function_result))
|
||||||
|
|
||||||
if self.tool_progress_callback:
|
if not _execution_blocked and self.tool_progress_callback:
|
||||||
try:
|
try:
|
||||||
self.tool_progress_callback(
|
self.tool_progress_callback(
|
||||||
"tool.completed", function_name, None, None,
|
"tool.completed", function_name, None, None,
|
||||||
@ -9910,7 +10028,7 @@ class AIAgent:
|
|||||||
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
|
logging.debug(f"Tool {function_name} completed in {tool_duration:.2f}s")
|
||||||
logging.debug(f"Tool result ({len(function_result)} chars): {function_result}")
|
logging.debug(f"Tool result ({len(function_result)} chars): {function_result}")
|
||||||
|
|
||||||
if self.tool_complete_callback:
|
if not _execution_blocked and self.tool_complete_callback:
|
||||||
try:
|
try:
|
||||||
self.tool_complete_callback(tool_call.id, function_name, function_args, function_result)
|
self.tool_complete_callback(tool_call.id, function_name, function_args, function_result)
|
||||||
except Exception as cb_err:
|
except Exception as cb_err:
|
||||||
@ -10244,6 +10362,8 @@ class AIAgent:
|
|||||||
self._last_content_tools_all_housekeeping = False
|
self._last_content_tools_all_housekeeping = False
|
||||||
self._mute_post_response = False
|
self._mute_post_response = False
|
||||||
self._unicode_sanitization_passes = 0
|
self._unicode_sanitization_passes = 0
|
||||||
|
self._tool_guardrails.reset_for_turn()
|
||||||
|
self._tool_guardrail_halt_decision = None
|
||||||
|
|
||||||
# Pre-turn connection health check: detect and clean up dead TCP
|
# Pre-turn connection health check: detect and clean up dead TCP
|
||||||
# connections left over from provider outages or dropped streams.
|
# connections left over from provider outages or dropped streams.
|
||||||
@ -13041,6 +13161,16 @@ class AIAgent:
|
|||||||
|
|
||||||
self._execute_tool_calls(assistant_message, messages, effective_task_id, api_call_count)
|
self._execute_tool_calls(assistant_message, messages, effective_task_id, api_call_count)
|
||||||
|
|
||||||
|
if self._tool_guardrail_halt_decision is not None:
|
||||||
|
decision = self._tool_guardrail_halt_decision
|
||||||
|
_turn_exit_reason = "guardrail_halt"
|
||||||
|
final_response = self._toolguard_controlled_halt_response(decision)
|
||||||
|
self._emit_status(
|
||||||
|
f"⚠️ Tool guardrail halted {decision.tool_name}: {decision.code}"
|
||||||
|
)
|
||||||
|
messages.append({"role": "assistant", "content": final_response})
|
||||||
|
break
|
||||||
|
|
||||||
# Reset per-turn retry counters after successful tool
|
# Reset per-turn retry counters after successful tool
|
||||||
# execution so a single truncation doesn't poison the
|
# execution so a single truncation doesn't poison the
|
||||||
# entire conversation.
|
# entire conversation.
|
||||||
@ -13567,6 +13697,7 @@ class AIAgent:
|
|||||||
"messages": messages,
|
"messages": messages,
|
||||||
"api_calls": api_call_count,
|
"api_calls": api_call_count,
|
||||||
"completed": completed,
|
"completed": completed,
|
||||||
|
"turn_exit_reason": _turn_exit_reason,
|
||||||
"partial": False, # True only when stopped due to invalid tool calls
|
"partial": False, # True only when stopped due to invalid tool calls
|
||||||
"interrupted": interrupted,
|
"interrupted": interrupted,
|
||||||
"response_previewed": getattr(self, "_response_was_previewed", False),
|
"response_previewed": getattr(self, "_response_was_previewed", False),
|
||||||
@ -13586,6 +13717,8 @@ class AIAgent:
|
|||||||
"cost_status": self.session_cost_status,
|
"cost_status": self.session_cost_status,
|
||||||
"cost_source": self.session_cost_source,
|
"cost_source": self.session_cost_source,
|
||||||
}
|
}
|
||||||
|
if self._tool_guardrail_halt_decision is not None:
|
||||||
|
result["guardrail"] = self._tool_guardrail_halt_decision.to_metadata()
|
||||||
# If a /steer landed after the final assistant turn (no more tool
|
# If a /steer landed after the final assistant turn (no more tool
|
||||||
# batches to drain into), hand it back to the caller so it can be
|
# batches to drain into), hand it back to the caller so it can be
|
||||||
# delivered as the next user turn instead of being silently lost.
|
# delivered as the next user turn instead of being silently lost.
|
||||||
|
|||||||
142
tests/agent/test_tool_guardrails.py
Normal file
142
tests/agent/test_tool_guardrails.py
Normal file
@ -0,0 +1,142 @@
|
|||||||
|
"""Pure tool-call guardrail primitive tests."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
from agent.tool_guardrails import (
|
||||||
|
ToolCallGuardrailConfig,
|
||||||
|
ToolCallGuardrailController,
|
||||||
|
ToolCallSignature,
|
||||||
|
canonical_tool_args,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_call_signature_hashes_canonical_nested_unicode_args_without_exposing_raw_args():
|
||||||
|
args_a = {
|
||||||
|
"z": [{"β": "☤", "a": 1}],
|
||||||
|
"a": {"y": 2, "x": "secret-token-value"},
|
||||||
|
}
|
||||||
|
args_b = {
|
||||||
|
"a": {"x": "secret-token-value", "y": 2},
|
||||||
|
"z": [{"a": 1, "β": "☤"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
assert canonical_tool_args(args_a) == canonical_tool_args(args_b)
|
||||||
|
sig_a = ToolCallSignature.from_call("web_search", args_a)
|
||||||
|
sig_b = ToolCallSignature.from_call("web_search", args_b)
|
||||||
|
|
||||||
|
assert sig_a == sig_b
|
||||||
|
assert len(sig_a.args_hash) == 64
|
||||||
|
metadata = sig_a.to_metadata()
|
||||||
|
assert metadata == {"tool_name": "web_search", "args_hash": sig_a.args_hash}
|
||||||
|
assert "secret-token-value" not in json.dumps(metadata)
|
||||||
|
assert "☤" not in json.dumps(metadata)
|
||||||
|
|
||||||
|
|
||||||
|
def test_repeated_identical_failed_call_warns_then_blocks_before_third_execution():
|
||||||
|
controller = ToolCallGuardrailController(
|
||||||
|
ToolCallGuardrailConfig(
|
||||||
|
exact_failure_warn_after=2,
|
||||||
|
exact_failure_block_after=2,
|
||||||
|
same_tool_failure_halt_after=99,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
args = {"query": "same"}
|
||||||
|
|
||||||
|
assert controller.before_call("web_search", args).action == "allow"
|
||||||
|
first = controller.after_call("web_search", args, '{"error":"boom"}', failed=True)
|
||||||
|
assert first.action == "allow"
|
||||||
|
|
||||||
|
assert controller.before_call("web_search", args).action == "allow"
|
||||||
|
second = controller.after_call("web_search", args, '{"error":"boom"}', failed=True)
|
||||||
|
assert second.action == "warn"
|
||||||
|
assert second.code == "repeated_exact_failure_warning"
|
||||||
|
assert second.count == 2
|
||||||
|
|
||||||
|
blocked = controller.before_call("web_search", args)
|
||||||
|
assert blocked.action == "block"
|
||||||
|
assert blocked.code == "repeated_exact_failure_block"
|
||||||
|
assert blocked.tool_name == "web_search"
|
||||||
|
assert blocked.count == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_success_resets_exact_signature_failure_streak():
|
||||||
|
controller = ToolCallGuardrailController(
|
||||||
|
ToolCallGuardrailConfig(exact_failure_block_after=2, same_tool_failure_halt_after=99)
|
||||||
|
)
|
||||||
|
args = {"query": "same"}
|
||||||
|
|
||||||
|
controller.after_call("web_search", args, '{"error":"boom"}', failed=True)
|
||||||
|
controller.after_call("web_search", args, '{"ok":true}', failed=False)
|
||||||
|
|
||||||
|
assert controller.before_call("web_search", args).action == "allow"
|
||||||
|
controller.after_call("web_search", args, '{"error":"boom"}', failed=True)
|
||||||
|
assert controller.before_call("web_search", args).action == "allow"
|
||||||
|
|
||||||
|
|
||||||
|
def test_same_tool_varying_args_failure_streak_warns_then_halts_independent_of_exact_streak():
|
||||||
|
controller = ToolCallGuardrailController(
|
||||||
|
ToolCallGuardrailConfig(
|
||||||
|
exact_failure_block_after=99,
|
||||||
|
same_tool_failure_warn_after=2,
|
||||||
|
same_tool_failure_halt_after=3,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
first = controller.after_call("terminal", {"command": "cmd-1"}, '{"exit_code":1}', failed=True)
|
||||||
|
assert first.action == "allow"
|
||||||
|
second = controller.after_call("terminal", {"command": "cmd-2"}, '{"exit_code":1}', failed=True)
|
||||||
|
assert second.action == "warn"
|
||||||
|
assert second.code == "same_tool_failure_warning"
|
||||||
|
third = controller.after_call("terminal", {"command": "cmd-3"}, '{"exit_code":1}', failed=True)
|
||||||
|
assert third.action == "halt"
|
||||||
|
assert third.code == "same_tool_failure_halt"
|
||||||
|
assert third.count == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_idempotent_no_progress_repeated_result_warns_then_blocks_future_repeat():
|
||||||
|
controller = ToolCallGuardrailController(
|
||||||
|
ToolCallGuardrailConfig(no_progress_warn_after=2, no_progress_block_after=2)
|
||||||
|
)
|
||||||
|
args = {"path": "/tmp/same.txt"}
|
||||||
|
result = "same file contents"
|
||||||
|
|
||||||
|
assert controller.before_call("read_file", args).action == "allow"
|
||||||
|
assert controller.after_call("read_file", args, result, failed=False).action == "allow"
|
||||||
|
assert controller.before_call("read_file", args).action == "allow"
|
||||||
|
warn = controller.after_call("read_file", args, result, failed=False)
|
||||||
|
assert warn.action == "warn"
|
||||||
|
assert warn.code == "idempotent_no_progress_warning"
|
||||||
|
|
||||||
|
blocked = controller.before_call("read_file", args)
|
||||||
|
assert blocked.action == "block"
|
||||||
|
assert blocked.code == "idempotent_no_progress_block"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mutating_or_unknown_tools_are_not_blocked_for_repeated_identical_success_output_by_default():
|
||||||
|
controller = ToolCallGuardrailController(
|
||||||
|
ToolCallGuardrailConfig(no_progress_warn_after=2, no_progress_block_after=2)
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(3):
|
||||||
|
assert controller.before_call("write_file", {"path": "/tmp/x", "content": "x"}).action == "allow"
|
||||||
|
assert controller.after_call("write_file", {"path": "/tmp/x", "content": "x"}, "ok", failed=False).action == "allow"
|
||||||
|
assert controller.before_call("custom_tool", {"x": 1}).action == "allow"
|
||||||
|
assert controller.after_call("custom_tool", {"x": 1}, "ok", failed=False).action == "allow"
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_for_turn_clears_bounded_guardrail_state():
|
||||||
|
controller = ToolCallGuardrailController(
|
||||||
|
ToolCallGuardrailConfig(exact_failure_block_after=2, no_progress_block_after=2)
|
||||||
|
)
|
||||||
|
controller.after_call("web_search", {"query": "same"}, '{"error":"boom"}', failed=True)
|
||||||
|
controller.after_call("web_search", {"query": "same"}, '{"error":"boom"}', failed=True)
|
||||||
|
controller.after_call("read_file", {"path": "/tmp/x"}, "same", failed=False)
|
||||||
|
controller.after_call("read_file", {"path": "/tmp/x"}, "same", failed=False)
|
||||||
|
|
||||||
|
assert controller.before_call("web_search", {"query": "same"}).action == "block"
|
||||||
|
assert controller.before_call("read_file", {"path": "/tmp/x"}).action == "block"
|
||||||
|
|
||||||
|
controller.reset_for_turn()
|
||||||
|
|
||||||
|
assert controller.before_call("web_search", {"query": "same"}).action == "allow"
|
||||||
|
assert controller.before_call("read_file", {"path": "/tmp/x"}).action == "allow"
|
||||||
202
tests/run_agent/test_tool_call_guardrail_runtime.py
Normal file
202
tests/run_agent/test_tool_call_guardrail_runtime.py
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
"""Runtime tests for tool-call loop guardrails."""
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from run_agent import AIAgent
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tool_defs(*names: str) -> list[dict]:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": name,
|
||||||
|
"description": f"{name} tool",
|
||||||
|
"parameters": {"type": "object", "properties": {}},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for name in names
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_tool_call(name="web_search", arguments="{}", call_id=None):
|
||||||
|
return SimpleNamespace(
|
||||||
|
id=call_id or f"call_{uuid.uuid4().hex[:8]}",
|
||||||
|
type="function",
|
||||||
|
function=SimpleNamespace(name=name, arguments=arguments),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_response(content="Hello", finish_reason="stop", tool_calls=None):
|
||||||
|
msg = SimpleNamespace(content=content, tool_calls=tool_calls)
|
||||||
|
choice = SimpleNamespace(message=msg, finish_reason=finish_reason)
|
||||||
|
return SimpleNamespace(choices=[choice], model="test/model", usage=None)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_agent(*tool_names: str, max_iterations: int = 10) -> AIAgent:
|
||||||
|
with (
|
||||||
|
patch("run_agent.get_tool_definitions", return_value=_make_tool_defs(*tool_names)),
|
||||||
|
patch("run_agent.check_toolset_requirements", return_value={}),
|
||||||
|
patch("run_agent.OpenAI"),
|
||||||
|
):
|
||||||
|
agent = AIAgent(
|
||||||
|
api_key="test-key-1234567890",
|
||||||
|
base_url="https://openrouter.ai/api/v1",
|
||||||
|
max_iterations=max_iterations,
|
||||||
|
quiet_mode=True,
|
||||||
|
skip_context_files=True,
|
||||||
|
skip_memory=True,
|
||||||
|
)
|
||||||
|
agent.client = MagicMock()
|
||||||
|
agent._cached_system_prompt = "You are helpful."
|
||||||
|
agent._use_prompt_caching = False
|
||||||
|
agent.tool_delay = 0
|
||||||
|
agent.compression_enabled = False
|
||||||
|
agent.save_trajectories = False
|
||||||
|
return agent
|
||||||
|
|
||||||
|
|
||||||
|
def _seed_exact_failures(agent: AIAgent, tool_name: str, args: dict, count: int = 2) -> None:
|
||||||
|
for _ in range(count):
|
||||||
|
agent._tool_guardrails.after_call(
|
||||||
|
tool_name,
|
||||||
|
args,
|
||||||
|
json.dumps({"error": "boom"}),
|
||||||
|
failed=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_path_blocks_repeated_exact_failure_before_execution():
|
||||||
|
agent = _make_agent("web_search")
|
||||||
|
args = {"query": "same"}
|
||||||
|
_seed_exact_failures(agent, "web_search", args)
|
||||||
|
starts = []
|
||||||
|
progress = []
|
||||||
|
agent.tool_start_callback = lambda *a, **k: starts.append((a, k))
|
||||||
|
agent.tool_progress_callback = lambda *a, **k: progress.append((a, k))
|
||||||
|
tc = _mock_tool_call("web_search", json.dumps(args), "c-block")
|
||||||
|
msg = SimpleNamespace(content="", tool_calls=[tc])
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
with patch("run_agent.handle_function_call", return_value="SHOULD_NOT_RUN") as mock_hfc:
|
||||||
|
agent._execute_tool_calls_sequential(msg, messages, "task-1")
|
||||||
|
|
||||||
|
mock_hfc.assert_not_called()
|
||||||
|
assert starts == []
|
||||||
|
assert progress == []
|
||||||
|
assert len(messages) == 1
|
||||||
|
assert messages[0]["role"] == "tool"
|
||||||
|
assert messages[0]["tool_call_id"] == "c-block"
|
||||||
|
assert "repeated_exact_failure_block" in messages[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_sequential_after_call_appends_guidance_to_tool_result_without_extra_messages():
|
||||||
|
agent = _make_agent("web_search")
|
||||||
|
args = {"query": "same"}
|
||||||
|
_seed_exact_failures(agent, "web_search", args, count=1)
|
||||||
|
tc = _mock_tool_call("web_search", json.dumps(args), "c-warn")
|
||||||
|
msg = SimpleNamespace(content="", tool_calls=[tc])
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
with patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})):
|
||||||
|
agent._execute_tool_calls_sequential(msg, messages, "task-1")
|
||||||
|
|
||||||
|
assert [m["role"] for m in messages] == ["tool"]
|
||||||
|
assert messages[0]["tool_call_id"] == "c-warn"
|
||||||
|
assert "Tool guardrail" in messages[0]["content"]
|
||||||
|
assert "repeated_exact_failure_warning" in messages[0]["content"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_concurrent_path_does_not_submit_blocked_calls_and_preserves_result_order():
|
||||||
|
agent = _make_agent("web_search")
|
||||||
|
blocked_args = {"query": "blocked"}
|
||||||
|
allowed_args = {"query": "allowed"}
|
||||||
|
_seed_exact_failures(agent, "web_search", blocked_args)
|
||||||
|
starts = []
|
||||||
|
progress_events = []
|
||||||
|
agent.tool_start_callback = lambda tool_call_id, name, args: starts.append((tool_call_id, name, args))
|
||||||
|
agent.tool_progress_callback = lambda event, name, preview, args, **kw: progress_events.append((event, name, args, kw))
|
||||||
|
calls = [
|
||||||
|
_mock_tool_call("web_search", json.dumps(blocked_args), "c-block"),
|
||||||
|
_mock_tool_call("web_search", json.dumps(allowed_args), "c-allow"),
|
||||||
|
]
|
||||||
|
msg = SimpleNamespace(content="", tool_calls=calls)
|
||||||
|
messages = []
|
||||||
|
executed = []
|
||||||
|
|
||||||
|
def fake_handle(name, args, task_id, **kwargs):
|
||||||
|
executed.append((name, args, kwargs["tool_call_id"]))
|
||||||
|
return json.dumps({"ok": args["query"]})
|
||||||
|
|
||||||
|
with patch("run_agent.handle_function_call", side_effect=fake_handle):
|
||||||
|
agent._execute_tool_calls_concurrent(msg, messages, "task-1")
|
||||||
|
|
||||||
|
assert executed == [("web_search", allowed_args, "c-allow")]
|
||||||
|
assert [m["tool_call_id"] for m in messages] == ["c-block", "c-allow"]
|
||||||
|
assert "repeated_exact_failure_block" in messages[0]["content"]
|
||||||
|
assert json.loads(messages[1]["content"]) == {"ok": "allowed"}
|
||||||
|
assert starts == [("c-allow", "web_search", allowed_args)]
|
||||||
|
started_events = [event for event in progress_events if event[0] == "tool.started"]
|
||||||
|
completed_events = [event for event in progress_events if event[0] == "tool.completed"]
|
||||||
|
assert started_events == [("tool.started", "web_search", allowed_args, {})]
|
||||||
|
assert len(completed_events) == 1
|
||||||
|
assert completed_events[0][1] == "web_search"
|
||||||
|
|
||||||
|
|
||||||
|
def test_plugin_pre_tool_block_wins_without_counting_as_toolguard_block():
|
||||||
|
agent = _make_agent("web_search")
|
||||||
|
args = {"query": "same"}
|
||||||
|
tc = _mock_tool_call("web_search", json.dumps(args), "c-plugin")
|
||||||
|
msg = SimpleNamespace(content="", tool_calls=[tc])
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("hermes_cli.plugins.get_pre_tool_call_block_message", return_value="plugin policy"),
|
||||||
|
patch("run_agent.handle_function_call", return_value="SHOULD_NOT_RUN") as mock_hfc,
|
||||||
|
):
|
||||||
|
agent._execute_tool_calls_sequential(msg, messages, "task-1")
|
||||||
|
|
||||||
|
mock_hfc.assert_not_called()
|
||||||
|
assert "plugin policy" in messages[0]["content"]
|
||||||
|
assert agent._tool_guardrails.before_call("web_search", args).action == "allow"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_conversation_returns_controlled_guardrail_halt_without_top_level_error():
|
||||||
|
agent = _make_agent("web_search", max_iterations=10)
|
||||||
|
same_args = {"query": "same"}
|
||||||
|
responses = [
|
||||||
|
_mock_response(
|
||||||
|
content="",
|
||||||
|
finish_reason="tool_calls",
|
||||||
|
tool_calls=[_mock_tool_call("web_search", json.dumps(same_args), f"c{i}")],
|
||||||
|
)
|
||||||
|
for i in range(1, 10)
|
||||||
|
]
|
||||||
|
agent.client.chat.completions.create.side_effect = responses
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("run_agent.handle_function_call", return_value=json.dumps({"error": "boom"})) as mock_hfc,
|
||||||
|
patch.object(agent, "_persist_session"),
|
||||||
|
patch.object(agent, "_save_trajectory"),
|
||||||
|
patch.object(agent, "_cleanup_task_resources"),
|
||||||
|
):
|
||||||
|
result = agent.run_conversation("search repeatedly")
|
||||||
|
|
||||||
|
assert mock_hfc.call_count == 2
|
||||||
|
assert result["api_calls"] == 3
|
||||||
|
assert result["api_calls"] < agent.max_iterations
|
||||||
|
assert result["turn_exit_reason"] == "guardrail_halt"
|
||||||
|
assert "error" not in result
|
||||||
|
assert result["completed"] is True
|
||||||
|
assert "stopped retrying" in result["final_response"]
|
||||||
|
assert result["guardrail"]["code"] == "repeated_exact_failure_block"
|
||||||
|
assert result["guardrail"]["tool_name"] == "web_search"
|
||||||
|
|
||||||
|
assistant_tool_calls = [m for m in result["messages"] if m.get("role") == "assistant" and m.get("tool_calls")]
|
||||||
|
for assistant_msg in assistant_tool_calls:
|
||||||
|
call_ids = [tc["id"] for tc in assistant_msg["tool_calls"]]
|
||||||
|
following_results = [m for m in result["messages"] if m.get("role") == "tool" and m.get("tool_call_id") in call_ids]
|
||||||
|
assert len(following_results) == len(call_ids)
|
||||||
Loading…
Reference in New Issue
Block a user