Compare commits

...

6 Commits

Author SHA1 Message Date
infra-runtime-be aba67633f4 fix(a2a_tools): add empty workspace_id guard; redesign httpx test isolation
delegate_task had no validation of workspace_id — an empty string would
proceed to the discover HTTP call and return a confusing error.  Add an
early guard that returns a clear error message before any I/O.

The test file (from #367) used a session-scoped httpx-reload fixture that
conflicted with conftest.py's MagicMock of builtin_tools.a2a_tools.  Replace
the reload mechanism with two targeted changes:

1. An autouse fixture that swaps the conftest MagicMock for the real module
   before each test and restores it afterwards — no session-level state leaks.

2. Patches on builtin_tools.a2a_tools.httpx.AsyncClient (AsyncMock) instead
   of respx — the httpx import inside the a2a_tools module is already isolated
   from the conftest mock, so patching it directly is clean and reliable.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-13 06:40:54 +00:00
hongming-codex-laptop cc6057db7a fix(platform): clear golangci-lint findings 2026-05-13 06:40:54 +00:00
Molecule AI Core-DevOps 0147cc1857 fix(ci): remove invalid YAML double-quote wrapping on golangci-lint run
The run value '"/Users/hongming/go/bin/golangci-lint" run ...' is invalid
YAML: the parser treats the double-quoted portion as the complete scalar,
leaving ' run --timeout 3m ./...' as unexpected trailing content.
Use a plain scalar so the shell expands $(go env GOPATH) correctly.
2026-05-13 06:40:54 +00:00
hongming-codex-laptop b87a79aff4 fix(ci): install golangci-lint in platform job 2026-05-13 06:40:54 +00:00
infra-runtime-be ef93dabaf9 fix(a2a_mcp_server): use StreamingResponse for SSE async generator
SSE event_stream() is an async generator — Starlette's Response requires
sync content and raises AttributeError: 'async_generator' object has no
attribute 'encode'. Switch to StreamingResponse which properly handles
async generators via anyio.create_task_group.

Also add test for queue-overflow fallback (queue.full() → JSON response).

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-13 06:23:30 +00:00
infra-runtime-be a4a6f52064 feat(workspace): add HTTP/SSE transport to a2a_mcp_server
Port HTTP/SSE transport (from workspace-runtime PR #16) to the
canonical monorepo source. This enables the Hermes MCP-native runtime
to communicate with the A2A platform tools via HTTP/SSE instead of
stdio.

Changes:
- Add _run_http_server(port) async function using starlette + uvicorn
- POST /mcp: receive JSON-RPC requests, return JSON or queue for SSE
- GET /mcp/stream: SSE stream for push-based responses
- GET /health: returns {"ok": true, "transport": "http+sse", "port": N}
- cli_main() now accepts (transport, port) params for dual-mode startup
- argparse added to __main__ guard for CLI: --transport stdio|http, --port N

uvicorn/starlette are lazy-imported inside _run_http_server() so the
stdio path (the common case) doesn't fail if they're not installed.

Closes #16 (workspace-runtime) — fix now goes through monorepo.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-05-13 05:29:39 +00:00
44 changed files with 1646 additions and 460 deletions
+4 -1
View File
@@ -170,9 +170,12 @@ jobs:
# CLI (molecli) moved to standalone repo: git.moleculesai.app/molecule-ai/molecule-cli
- if: needs.changes.outputs.platform == 'true'
run: go vet ./...
- if: needs.changes.outputs.platform == 'true'
name: Install golangci-lint
run: go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.12.2
- if: needs.changes.outputs.platform == 'true'
name: Run golangci-lint
run: golangci-lint run --timeout 3m ./...
run: $(go env GOPATH)/bin/golangci-lint run --timeout 3m ./...
- if: needs.changes.outputs.platform == 'true'
name: Diagnostic — per-package verbose 60s
run: |
+9 -7
View File
@@ -7,14 +7,16 @@
// in place rather than duplicating.
//
// Usage:
// memory-backfill -dry-run # count + diff
// memory-backfill -apply # actually copy
// memory-backfill -apply -limit=10000 # cap rows per run
// memory-backfill -apply -workspace=<uuid> # one workspace only
//
// memory-backfill -dry-run # count + diff
// memory-backfill -apply # actually copy
// memory-backfill -apply -limit=10000 # cap rows per run
// memory-backfill -apply -workspace=<uuid> # one workspace only
//
// Required env:
// DATABASE_URL — workspace-server DB (read agent_memories)
// MEMORY_PLUGIN_URLtarget plugin (write memory_records)
//
// DATABASE_URL workspace-server DB (read agent_memories)
// MEMORY_PLUGIN_URL — target plugin (write memory_records)
package main
import (
@@ -251,7 +253,7 @@ func mapScopeToNamespace(ctx context.Context, r backfillResolver, workspaceID, s
if err != nil {
return "", fmt.Errorf("resolve writable: %w", err)
}
wantKind := contract.NamespaceKindWorkspace
var wantKind contract.NamespaceKind
switch scope {
case "LOCAL":
wantKind = contract.NamespaceKindWorkspace
@@ -522,7 +522,7 @@ func (m *Manager) FetchWorkspaceChannelContext(ctx context.Context, workspaceID
if len(text) > 200 {
text = text[:197] + "..."
}
sb.WriteString(fmt.Sprintf("- %s: %s\n", name, text))
fmt.Fprintf(&sb, "- %s: %s\n", name, text)
}
return sb.String()
}
@@ -134,9 +134,9 @@ var botCommands = []tgbotapi.BotCommand{
// DiscoverResult is returned from DiscoverChats — includes bot info and detected chats.
type DiscoverResult struct {
BotUsername string
Chats []map[string]interface{}
CanReadAllGroupMessages bool // false = group privacy mode is ON (bot only sees commands/mentions)
BotUsername string
Chats []map[string]interface{}
CanReadAllGroupMessages bool // false = group privacy mode is ON (bot only sees commands/mentions)
}
// DiscoverChats calls Telegram getUpdates to find groups/chats the bot has been added to.
@@ -231,7 +231,6 @@ func (t *TelegramAdapter) DiscoverChats(ctx context.Context, botToken string) (*
addChat(msg.Chat)
}
return &DiscoverResult{
BotUsername: bot.Self.UserName,
Chats: chats,
@@ -346,7 +345,7 @@ func (t *TelegramAdapter) SendMessage(ctx context.Context, config map[string]int
case 403:
return fmt.Errorf("forbidden: bot was blocked or kicked from chat %s", chatID)
case 429:
retryAfter := time.Duration(apiErr.ResponseParameters.RetryAfter) * time.Second
retryAfter := time.Duration(apiErr.RetryAfter) * time.Second
log.Printf("Channels: Telegram rate-limited, retry after %s", retryAfter)
time.Sleep(retryAfter)
if _, retryErr := bot.Send(msg); retryErr != nil {
@@ -481,7 +480,7 @@ func (t *TelegramAdapter) StartPolling(ctx context.Context, config map[string]in
var apiErr *tgbotapi.Error
if errors.As(err, &apiErr) {
if apiErr.Code == 429 {
retryAfter := time.Duration(apiErr.ResponseParameters.RetryAfter) * time.Second
retryAfter := time.Duration(apiErr.RetryAfter) * time.Second
log.Printf("Channels: Telegram poll rate-limited, sleeping %s", retryAfter)
select {
case <-ctx.Done():
@@ -108,7 +108,7 @@ func TestEventType_AllUppercaseSnakeCase(t *testing.T) {
t.Errorf("EventType %q has consecutive underscores — disallowed", s)
}
for _, r := range s {
if !((r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_') {
if (r < 'A' || r > 'Z') && (r < '0' || r > '9') && r != '_' {
t.Errorf("EventType %q contains disallowed char %q", s, r)
break
}
@@ -42,7 +42,7 @@ func setupTestDBForQueueTests(t *testing.T) sqlmock.Sqlmock {
// ──────────────────────────────────────────────────────────────────────────────
func TestPriorityConstants(t *testing.T) {
if !(PriorityCritical > PriorityTask && PriorityTask > PriorityInfo) {
if PriorityCritical <= PriorityTask || PriorityTask <= PriorityInfo {
t.Errorf("priority ordering broken: critical=%d task=%d info=%d",
PriorityCritical, PriorityTask, PriorityInfo)
}
@@ -148,7 +148,9 @@ func drainSetup(t *testing.T, workspaceID string) (sqlmock.Sqlmock, *WorkspaceHa
}
// expectQueueBudgetCheck registers the mock for checkWorkspaceBudget's query:
// SELECT budget_limit, COALESCE(monthly_spend, 0) FROM workspaces WHERE id = $1
//
// SELECT budget_limit, COALESCE(monthly_spend, 0) FROM workspaces WHERE id = $1
//
// Must be called AFTER expectDequeueNextOk — DequeueNext (BEGIN→SELECT→UPDATE→COMMIT)
// runs before proxyA2ARequest which calls checkWorkspaceBudget.
// Named distinctly from handlers_test.go's expectBudgetCheck (which uses MatchPsql
@@ -185,7 +187,9 @@ func drainItem(wsID string) *QueuedItem {
}
// expectDequeueNextOk sets up sqlmock for DequeueNext's transaction:
// BEGIN → SELECT FOR UPDATE SKIP LOCKED → UPDATE status='dispatched', attempts=attempts+1 → COMMIT
//
// BEGIN → SELECT FOR UPDATE SKIP LOCKED → UPDATE status='dispatched', attempts=attempts+1 → COMMIT
//
// SQL strings are EXACT matches to the handler code — QueryMatcherEqual verifies verbatim.
func expectDequeueNextOk(mock sqlmock.Sqlmock, item *QueuedItem) {
mock.ExpectBegin()
@@ -474,12 +474,7 @@ func (h *ActivityHandler) Notify(c *gin.Context) {
// Lark) hook in here too.
attachments := make([]AgentMessageAttachment, 0, len(body.Attachments))
for _, a := range body.Attachments {
attachments = append(attachments, AgentMessageAttachment{
URI: a.URI,
Name: a.Name,
MimeType: a.MimeType,
Size: a.Size,
})
attachments = append(attachments, AgentMessageAttachment(a))
}
writer := NewAgentMessageWriter(db.DB, h.broadcaster)
if err := writer.Send(c.Request.Context(), workspaceID, body.Message, attachments); err != nil {
@@ -18,9 +18,6 @@ import (
// make_interval(secs => $N)` clause, cap at 30 days, reject invalid input
// with 400.
const activityCols = `id, workspace_id, activity_type, source_id, target_id, method, ` +
`summary, request_body, response_body, tool_trace, duration_ms, status, error_detail, created_at`
func newActivityRows() *sqlmock.Rows {
cols := []string{
"id", "workspace_id", "activity_type", "source_id", "target_id", "method",
@@ -262,16 +262,16 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
// because workspaces sharing a team/org root see identical namespaces.
//
// New strategy:
// 1. Single SQL pass walks parent_id chains, returning each
// workspace's root_id alongside its name.
// 2. Group workspaces by root → unique tree count is typically <<
// workspace count.
// 3. Resolve namespaces ONCE per root (any workspace under that
// root produces the same readable list).
// 4. Build a UNION of namespaces across all roots; single plugin
// search call.
// 5. Map each memory back to a workspace_name via a namespace→ws
// lookup table built up from step 3.
// 1. Single SQL pass walks parent_id chains, returning each
// workspace's root_id alongside its name.
// 2. Group workspaces by root → unique tree count is typically <<
// workspace count.
// 3. Resolve namespaces ONCE per root (any workspace under that
// root produces the same readable list).
// 4. Build a UNION of namespaces across all roots; single plugin
// search call.
// 5. Map each memory back to a workspace_name via a namespace→ws
// lookup table built up from step 3.
//
// Net cost: 1 SQL + N_roots resolver calls + 1 plugin call (vs
// N_workspaces resolver + N_workspaces plugin in the old code).
@@ -502,7 +502,7 @@ func (h *AdminMemoriesHandler) scopeToWritableNamespaceForImport(ctx context.Con
if err != nil {
return "", err
}
wantKind := contract.NamespaceKindWorkspace
var wantKind contract.NamespaceKind
switch strings.ToUpper(scope) {
case "", "LOCAL":
wantKind = contract.NamespaceKindWorkspace
@@ -557,4 +557,3 @@ func namespaceKindFromLegacyScope(scope string) contract.NamespaceKind {
return contract.NamespaceKindWorkspace
}
}
@@ -131,10 +131,9 @@ func TestCutoverActive(t *testing.T) {
func TestWithMemoryV2_AttachesDeps(t *testing.T) {
h := NewAdminMemoriesHandler().WithMemoryV2(nil, nil)
// Both nil pointers — wiring still attaches them; cutoverActive
// reports false because the interface values are nil.
if h.plugin == nil && h.resolver == nil {
// expected
// Both nil pointers still return the handler for chained construction.
if h == nil {
t.Fatal("WithMemoryV2(nil, nil) returned nil handler")
}
}
@@ -596,7 +595,7 @@ func (r perWorkspaceResolver) ReadableNamespaces(_ context.Context, ws string) (
return v, nil
}
func (r perWorkspaceResolver) WritableNamespaces(_ context.Context, ws string) ([]namespace.Namespace, error) {
return r.ReadableNamespaces(nil, ws)
return r.ReadableNamespaces(context.TODO(), ws)
}
// TestExport_IncludesEveryMembersPrivateNamespace pins the I3 follow-up
@@ -71,13 +71,6 @@ func (h *BudgetHandler) GetBudget(c *gin.Context) {
c.JSON(http.StatusOK, resp)
}
// patchBudgetRequest is the expected JSON body for PATCH /workspaces/:id/budget.
// budget_limit=null removes the ceiling; a positive integer sets it (USD cents).
type patchBudgetRequest struct {
// BudgetLimit pointer so JSON null → nil, absent → parse error (required field).
BudgetLimit *int64 `json:"budget_limit"`
}
// PatchBudget handles PATCH /workspaces/:id/budget.
// Accepts {"budget_limit": <int64>} to set a new ceiling, or
// {"budget_limit": null} to remove an existing ceiling.
@@ -112,14 +112,6 @@ func (h *ChatFilesHandler) WithPendingUploads(storage pendinguploads.Storage, br
// network boundary before forwarding.
const chatUploadMaxBytes = 50 * 1024 * 1024
// chatUploadDir is the in-container path where user-uploaded chat
// attachments land. Kept here for documentation parity with the
// workspace-side handler — the platform no longer writes files
// directly, but the URI scheme returned in responses still uses this
// path, so any consumer parsing those URIs has the constant to
// reference.
const chatUploadDir = "/workspace/.molecule/chat-uploads"
// resolveWorkspaceForwardCreds resolves the workspace's URL +
// platform_inbound_secret for an /internal/* forward, applying
// lazy-heal on a missing inbound secret (RFC #2312 backfill — the
@@ -460,7 +452,6 @@ func (h *ChatFilesHandler) streamWorkspaceResponse(
}
}
// lookupUploadDeliveryMode returns the workspace's delivery_mode
// for the chat upload branch. Returns ("", false) and writes the
// HTTP error response on lookup failure (caller stops). NULL or
@@ -153,7 +153,7 @@ func TestMergeSystemMessages_EmptySlice(t *testing.T) {
func TestMergeSystemMessages_NilSlice(t *testing.T) {
var input []map[string]interface{}
got := mergeSystemMessages(input)
if got != nil && len(got) != 0 {
if len(got) != 0 {
t.Errorf("nil: got %v, want nil/empty", got)
}
}
+21 -34
View File
@@ -47,13 +47,13 @@ const defaultProvisionConcurrency = 3
//
// - unset / empty / non-numeric → defaultProvisionConcurrency (3)
// - "0" → unlimited (a very large cap;
// practically no semaphore — used on
// SaaS where AWS RunInstances is the
// rate-limiter, not us)
// practically no semaphore — used on
// SaaS where AWS RunInstances is the
// rate-limiter, not us)
// - any positive integer N → N
// - negative integer → defaultProvisionConcurrency (3),
// log warning so operator notices
// the misconfiguration
// log warning so operator notices
// the misconfiguration
//
// The "0 = unlimited" mapping was a deliberate choice: an env var of "0"
// is the natural shorthand for "no cap" without forcing operators to
@@ -102,18 +102,6 @@ const (
childGridColumnCount = 2
)
// childSlot computes the child-relative position for the N-th sibling in
// a parent's 2-column grid. Matches defaultChildSlot in
// canvas-topology.ts exactly — change them together. Leaf-sized slots
// only; for variable-size siblings use childSlotInGrid below.
func childSlot(index int) (x, y float64) {
col := index % childGridColumnCount
row := index / childGridColumnCount
x = parentSidePadding + float64(col)*(childDefaultWidth+childGutter)
y = parentHeaderPadding + float64(row)*(childDefaultHeight+childGutter)
return
}
type nodeSize struct {
width, height float64
}
@@ -342,10 +330,10 @@ func (e *EnvRequirement) UnmarshalJSON(data []byte) error {
// OrgTemplate is the YAML structure for an org hierarchy.
type OrgTemplate struct {
Name string `yaml:"name" json:"name"`
Description string `yaml:"description" json:"description"`
Defaults OrgDefaults `yaml:"defaults" json:"defaults"`
Workspaces []OrgWorkspace `yaml:"workspaces" json:"workspaces"`
Name string `yaml:"name" json:"name"`
Description string `yaml:"description" json:"description"`
Defaults OrgDefaults `yaml:"defaults" json:"defaults"`
Workspaces []OrgWorkspace `yaml:"workspaces" json:"workspaces"`
// GlobalMemories is a list of org-wide memories seeded as GLOBAL scope
// on the first root workspace (PM) during org import. Issue #1050.
GlobalMemories []models.MemorySeed `yaml:"global_memories" json:"global_memories"`
@@ -381,9 +369,9 @@ type OrgDefaults struct {
// declare them — causing live configs to boot without idle_prompts
// even when org.yaml had them. Phase 1 scalability work adds both
// inline + file-ref forms.
IdlePrompt string `yaml:"idle_prompt" json:"idle_prompt"`
IdlePromptFile string `yaml:"idle_prompt_file" json:"idle_prompt_file"`
IdleIntervalSeconds int `yaml:"idle_interval_seconds" json:"idle_interval_seconds"`
IdlePrompt string `yaml:"idle_prompt" json:"idle_prompt"`
IdlePromptFile string `yaml:"idle_prompt_file" json:"idle_prompt_file"`
IdleIntervalSeconds int `yaml:"idle_interval_seconds" json:"idle_interval_seconds"`
// CategoryRouting maps issue/audit category → list of target roles.
// Per-workspace blocks UNION + override per-key with these defaults.
// Rendered into each workspace's config.yaml so agent prompts can read it
@@ -470,12 +458,12 @@ type OrgWorkspace struct {
// time. If empty, defaults.initial_memories are used. Issue #1050.
InitialMemories []models.MemorySeed `yaml:"initial_memories" json:"initial_memories"`
// MaxConcurrentTasks: see models.CreateWorkspacePayload.
MaxConcurrentTasks int `yaml:"max_concurrent_tasks" json:"max_concurrent_tasks"`
Schedules []OrgSchedule `yaml:"schedules" json:"schedules"`
Channels []OrgChannel `yaml:"channels" json:"channels"`
External bool `yaml:"external" json:"external"`
URL string `yaml:"url" json:"url"`
Canvas struct {
MaxConcurrentTasks int `yaml:"max_concurrent_tasks" json:"max_concurrent_tasks"`
Schedules []OrgSchedule `yaml:"schedules" json:"schedules"`
Channels []OrgChannel `yaml:"channels" json:"channels"`
External bool `yaml:"external" json:"external"`
URL string `yaml:"url" json:"url"`
Canvas struct {
X float64 `yaml:"x" json:"x"`
Y float64 `yaml:"y" json:"y"`
} `yaml:"canvas" json:"canvas"`
@@ -714,10 +702,10 @@ func (h *OrgHandler) Import(c *gin.Context) {
wsMissing := collectPerWorkspaceUnsatisfied(tmpl.Workspaces, orgBaseDir, configured)
if len(wsMissing) > 0 {
c.JSON(http.StatusPreconditionFailed, gin.H{
"error": "missing per-workspace required environment variables",
"error": "missing per-workspace required environment variables",
"missing_workspace_env": wsMissing,
"template": tmpl.Name,
"suggestion": "add these keys to the workspace's .env file or set them as global secrets before importing",
"template": tmpl.Name,
"suggestion": "add these keys to the workspace's .env file or set them as global secrets before importing",
})
return
}
@@ -952,4 +940,3 @@ func errString(err error) string {
}
return err.Error()
}
@@ -196,7 +196,7 @@ func TestSanitizeEnvMembers_MaxLength(t *testing.T) {
}
// 129 chars: invalid (exceeds {0,127} suffix in regex)
tooLong := "A" + strings.Repeat("B", 128)
got, ok = sanitizeEnvMembers([]string{tooLong}, "test")
_, ok = sanitizeEnvMembers([]string{tooLong}, "test")
if ok {
t.Error("129 char invalid: ok should be false")
}
@@ -230,7 +230,7 @@ func TestFlattenAndSortRequirements_Empty(t *testing.T) {
func TestFlattenAndSortRequirements_SingleFirst(t *testing.T) {
// Singles come before groups; within singles, alphabetical
reqs := map[string]EnvRequirement{
envRequirementKey([]string{"ZETA"}): {Name: "ZETA"},
envRequirementKey([]string{"ZETA"}): {Name: "ZETA"},
envRequirementKey([]string{"ALPHA"}): {Name: "ALPHA"},
}
got := flattenAndSortRequirements(reqs)
@@ -247,7 +247,7 @@ func TestFlattenAndSortRequirements_SingleFirst(t *testing.T) {
func TestFlattenAndSortRequirements_GroupsAfterSingles(t *testing.T) {
reqs := map[string]EnvRequirement{
envRequirementKey([]string{"X"}): {Name: "X"}, // single
envRequirementKey([]string{"X"}): {Name: "X"}, // single
envRequirementKey([]string{"A", "B"}): {AnyOf: []string{"A", "B"}}, // group
}
got := flattenAndSortRequirements(reqs)
@@ -429,8 +429,8 @@ func TestCollectOrgEnv_WorkspaceLevel(t *testing.T) {
tmpl := &OrgTemplate{
Workspaces: []OrgWorkspace{
{
Name: "Dev",
RequiredEnv: []EnvRequirement{{Name: "DEV_KEY"}},
Name: "Dev",
RequiredEnv: []EnvRequirement{{Name: "DEV_KEY"}},
RecommendedEnv: []EnvRequirement{{Name: "DEV_TOOL"}},
},
},
@@ -456,12 +456,12 @@ func TestCollectOrgEnv_DeepNesting(t *testing.T) {
RequiredEnv: []EnvRequirement{{Name: "ORG_LEVEL"}},
Workspaces: []OrgWorkspace{
{
Name: "Root",
RequiredEnv: []EnvRequirement{{Name: "ROOT_LEVEL"}},
Name: "Root",
RequiredEnv: []EnvRequirement{{Name: "ROOT_LEVEL"}},
Children: []OrgWorkspace{
{
Name: "Child",
RequiredEnv: []EnvRequirement{{Name: "CHILD_LEVEL"}},
Name: "Child",
RequiredEnv: []EnvRequirement{{Name: "CHILD_LEVEL"}},
Children: []OrgWorkspace{
{Name: "GrandChild", RecommendedEnv: []EnvRequirement{{Name: "GRANDCHILD_TOOL"}}},
},
@@ -536,4 +536,3 @@ func TestCollectOrgEnv_MixedCasePreservesSort(t *testing.T) {
t.Errorf("A,B group should come first: got %+v", req[2])
}
}
@@ -33,11 +33,11 @@ GITEA_SSH_KEY_PATH=/etc/molecule-bootstrap/personas/dev-lead/ssh_priv
loadPersonaEnvFile("dev-lead", out)
want := map[string]string{
"GITEA_USER": "dev-lead",
"GITEA_USER_EMAIL": "dev-lead@agents.moleculesai.app",
"GITEA_TOKEN": "abc123",
"GITEA_TOKEN_SCOPES": "write:repository,write:issue,read:user",
"GITEA_SSH_KEY_PATH": "/etc/molecule-bootstrap/personas/dev-lead/ssh_priv",
"GITEA_USER": "dev-lead",
"GITEA_USER_EMAIL": "dev-lead@agents.moleculesai.app",
"GITEA_TOKEN": "abc123",
"GITEA_TOKEN_SCOPES": "write:repository,write:issue,read:user",
"GITEA_SSH_KEY_PATH": "/etc/molecule-bootstrap/personas/dev-lead/ssh_priv",
}
if len(out) != len(want) {
t.Fatalf("got %d keys, want %d: %#v", len(out), len(want), out)
@@ -153,12 +153,6 @@ func TestIsSafeRoleName_Acceptance(t *testing.T) {
}
}
bad := []string{
"", ".", "..", "with/slash", "/abs", "dot.in.middle",
"with space", "back\\slash", "trailing-", // trailing-hyphen is fine actually
"with$dollar", "with?question", "newline\nsplit",
}
// trailing-hyphen IS allowed; remove from "bad" list:
bad = []string{
"", ".", "..", "with/slash", "/abs", "dot.in.middle",
"with space", "back\\slash", "with$dollar", "with?question",
"newline\nsplit",
@@ -2,7 +2,6 @@ package handlers
import (
"archive/tar"
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
@@ -19,7 +18,6 @@ import (
"github.com/Molecule-AI/molecule-monorepo/platform/internal/envx"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/plugins"
"github.com/docker/docker/api/types/container"
"github.com/gin-gonic/gin"
)
@@ -436,53 +434,6 @@ func regexpEscapeForAwk(s string) string {
return b.String()
}
// copyPluginToContainer creates a tar from a host directory and copies it into /configs/plugins/<name>/.
// The tar entries are prefixed with plugins/<name>/ so Docker creates the directory structure.
func (h *PluginsHandler) copyPluginToContainer(ctx context.Context, containerName, hostDir, pluginName string) error {
var buf bytes.Buffer
tw := tar.NewWriter(&buf)
err := filepath.Walk(hostDir, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}
rel, err := filepath.Rel(hostDir, path)
if err != nil {
return err
}
header, err := tar.FileInfoHeader(info, "")
if err != nil {
return err
}
// Prefix: plugins/<pluginName>/<rel> → extracts under /configs/
header.Name = filepath.Join("plugins", pluginName, rel)
if err := tw.WriteHeader(header); err != nil {
return err
}
if !info.IsDir() {
data, err := os.ReadFile(path)
if err != nil {
return err
}
if _, err := tw.Write(data); err != nil {
return err
}
}
return nil
})
if err != nil {
return fmt.Errorf("failed to create tar from %s: %w", hostDir, err)
}
if err := tw.Close(); err != nil {
return fmt.Errorf("failed to close tar: %w", err)
}
// Copy to /configs — the tar's plugins/<name>/ prefix creates the directory
return h.docker.CopyToContainer(ctx, containerName, "/configs", &buf, container.CopyToContainerOptions{})
}
// streamDirAsTar writes every regular file + dir under `root` to the tar
// writer, using paths relative to root so the caller's unpack produces
// `<name>/<original-layout>` without any leading tempdir components.
@@ -119,7 +119,7 @@ func TestResolveAgentURLForRestartSignal_CacheHit(t *testing.T) {
// returned and propagated when neither Redis cache nor DB lookup succeeds.
func TestResolveAgentURLForRestartSignal_DBError(t *testing.T) {
mock := setupTestDB(t) // must come before setupTestRedis so db.DB is correct
_ = setupTestRedis(t) // empty → cache miss
_ = setupTestRedis(t) // empty → cache miss
h := newHandlerWithTestDeps(t)
@@ -209,10 +209,10 @@ func TestGracefulPreRestart_Success(t *testing.T) {
// Pre-populate Redis cache with the test server URL
_ = setupTestRedisWithURL(t, srv.URL)
// Use an embedded struct to override resolveAgentURLForRestartSignal.
// Use a wrapper so gracefulPreRestart runs through the embedded handler.
hWrapper := &resolveURLTestWrapper{
WorkspaceHandler: newHandlerWithTestDeps(t),
testURL: srv.URL + "/agent",
testURL: srv.URL + "/agent",
}
// gracefulPreRestart runs in a goroutine with its own timeout.
@@ -235,7 +235,7 @@ func TestGracefulPreRestart_NotImplemented(t *testing.T) {
hWrapper := &resolveURLTestWrapper{
WorkspaceHandler: newHandlerWithTestDeps(t),
testURL: srv.URL + "/agent",
testURL: srv.URL + "/agent",
}
hWrapper.gracefulPreRestart(context.Background(), "ws-noimpl-999")
@@ -253,7 +253,7 @@ func TestGracefulPreRestart_ConnectionRefused(t *testing.T) {
hWrapper := &resolveURLTestWrapper{
WorkspaceHandler: newHandlerWithTestDeps(t),
testURL: "http://localhost:19999/agent",
testURL: "http://localhost:19999/agent",
}
hWrapper.gracefulPreRestart(context.Background(), "ws-unreachable-000")
@@ -269,7 +269,7 @@ func TestGracefulPreRestart_URLResolutionError(t *testing.T) {
hWrapper := &resolveURLTestWrapper{
WorkspaceHandler: newHandlerWithTestDeps(t),
errToReturn: context.DeadlineExceeded,
errToReturn: context.DeadlineExceeded,
}
hWrapper.gracefulPreRestart(context.Background(), "ws-url-err-111")
@@ -279,21 +279,14 @@ func TestGracefulPreRestart_URLResolutionError(t *testing.T) {
// ─── helpers ─────────────────────────────────────────────────────────────────
// resolveURLTestWrapper embeds *WorkspaceHandler and overrides
// resolveAgentURLForRestartSignal so tests can inject a fixed URL or error.
// resolveURLTestWrapper embeds *WorkspaceHandler for tests that exercise
// gracefulPreRestart through a wrapper value.
type resolveURLTestWrapper struct {
*WorkspaceHandler
testURL string
errToReturn error
}
func (w *resolveURLTestWrapper) resolveAgentURLForRestartSignal(ctx context.Context, workspaceID string) (string, error) {
if w.errToReturn != nil {
return "", w.errToReturn
}
return w.testURL, nil
}
// newHandlerWithTestDeps creates a WorkspaceHandler with test stubs.
func newHandlerWithTestDeps(t *testing.T) *WorkspaceHandler {
return NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
@@ -313,4 +306,4 @@ func setupTestRedisWithURL(t *testing.T, url string) *miniredis.Miniredis {
}
t.Cleanup(func() { mr.Close() })
return mr
}
}
@@ -61,7 +61,6 @@ func resolveRestartTemplate(configsDir, wsName, dbRuntime string, body restartTe
candidatePath, resolveErr := resolveInsideRoot(configsDir, template)
if resolveErr != nil {
log.Printf("Restart: invalid template %q: %v — proceeding without it", template, resolveErr)
template = ""
} else if _, err := os.Stat(candidatePath); err == nil {
return candidatePath, template
} else {
@@ -3,8 +3,6 @@ package handlers
import (
"strings"
"testing"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
)
// Tests for the SaaS-aware default-tier resolution introduced in #2901
@@ -21,19 +19,6 @@ import (
// was hardcoded to 3 and silently disagreed with the create-
// handler default on SaaS.
// stubCPProv is a minimal stand-in for the CP provisioner — only
// exercises the IsSaaS / HasProvisioner contract, never invoked in
// these tests.
type stubCPProv struct{}
func (stubCPProv) Start(_ interface{}, _ provisioner.WorkspaceConfig) (string, error) {
return "", nil
}
func (stubCPProv) Stop(_ interface{}, _ string) error { return nil }
func (stubCPProv) Restart(_ interface{}, _ provisioner.WorkspaceConfig) (string, error) {
return "", nil
}
func TestIsSaaS_TrueWhenCPProvWired(t *testing.T) {
h := &WorkspaceHandler{cpProv: &trackingCPProv{}}
if !h.IsSaaS() {
@@ -117,14 +117,6 @@ func resolveWorkspaceRootPath(runtime, root string) string {
// EIC misconfiguration.
const eicFileOpTimeout = 30 * time.Second
// eicFileOpTimeout was historically named eicFileWriteTimeout when the
// only EIC op was writeFile. Keep an alias so any external test that
// pinned the old name still compiles; rename can land as a follow-up
// once we've gone a release without the alias being touched.
//
//nolint:revive // intentional alias for back-compat with prior tests.
const eicFileWriteTimeout = eicFileOpTimeout
// eicSSHSession describes an open EIC tunnel ready for an ssh subprocess.
// Only valid inside the closure passed to withEICTunnel — the underlying
// keypair + tunnel are torn down when the closure returns.
@@ -88,7 +88,7 @@ func generateDefaultConfig(name string, files map[string]string, tier int) strin
tier = 3
}
cfg.WriteString("version: 1.0.0\n")
cfg.WriteString(fmt.Sprintf("tier: %d\n", tier))
fmt.Fprintf(&cfg, "tier: %d\n", tier)
cfg.WriteString("model: anthropic:claude-haiku-4-5-20251001\n")
cfg.WriteString("\nprompt_files:\n")
if len(promptFiles) > 0 {
@@ -278,7 +278,7 @@ func (h *TemplatesHandler) ListFiles(c *gin.Context) {
// 1:1, but Go can't implicit-convert named struct types).
out := make([]fileEntry, 0, len(entries))
for _, e := range entries {
out = append(out, fileEntry{Path: e.Path, Size: e.Size, Dir: e.Dir})
out = append(out, fileEntry(e))
}
c.JSON(http.StatusOK, out)
return
@@ -373,9 +373,7 @@ func (h *TemplatesHandler) ListFiles(c *gin.Context) {
func (h *TemplatesHandler) ReadFile(c *gin.Context) {
workspaceID := c.Param("id")
filePath := c.Param("path")
if strings.HasPrefix(filePath, "/") {
filePath = filePath[1:]
}
filePath = strings.TrimPrefix(filePath, "/")
if err := validateRelPath(filePath); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid path"})
@@ -480,9 +478,7 @@ func (h *TemplatesHandler) ReadFile(c *gin.Context) {
func (h *TemplatesHandler) WriteFile(c *gin.Context) {
workspaceID := c.Param("id")
filePath := c.Param("path")
if strings.HasPrefix(filePath, "/") {
filePath = filePath[1:]
}
filePath = strings.TrimPrefix(filePath, "/")
if err := validateRelPath(filePath); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid path"})
@@ -636,4 +632,3 @@ func (h *TemplatesHandler) DeleteFile(c *gin.Context) {
go h.wh.RestartByID(workspaceID)
}
}
@@ -63,13 +63,6 @@ const workspacesUniqueIndexName = "workspaces_parent_name_uniq"
// Conflict — the user must rename and re-try.
var errWorkspaceNameExhausted = errors.New("workspace name exhausted: too many duplicates of base name under same parent")
// dbExec is the minimum surface our retry helper needs from
// *sql.Tx (or *sql.DB). Declared as an interface so tests can
// substitute a fake without standing up a real DB connection.
type dbExec interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
}
// insertWorkspaceWithNameRetry runs the workspace INSERT and, if it
// hits the parent-name unique-violation, retries with a suffixed
// name. Returns the name actually persisted (which the caller MUST
@@ -109,21 +109,6 @@ func (h *WorkspaceHandler) State(c *gin.Context) {
})
}
// sensitiveUpdateFields documents fields that carry elevated risk — kept as
// an explicit list for code readability and future audits. Auth is now fully
// enforced at the router layer (WorkspaceAuth middleware, #680 IDOR fix);
// this map is no longer used for in-handler gate logic but is preserved to
// surface the risk classification clearly.
//
// budget_limit is intentionally NOT here — the dedicated PATCH
// /workspaces/:id/budget (AdminAuth) is the only write path (#611).
var sensitiveUpdateFields = map[string]struct{}{
"tier": {},
"parent_id": {},
"runtime": {},
"workspace_dir": {},
}
// Update handles PATCH /workspaces/:id
func (h *WorkspaceHandler) Update(c *gin.Context) {
id := c.Param("id")
@@ -156,10 +156,7 @@ func TestProvisionWorkspaceAuto_RoutesToCPWhenSet(t *testing.T) {
// Wait for the goroutine to land in cpProv.Start (or give up).
deadline := time.Now().Add(2 * time.Second)
for {
if len(rec.startedSnapshot()) > 0 {
break
}
for len(rec.startedSnapshot()) == 0 {
if time.Now().After(deadline) {
t.Fatalf("timed out waiting for cpProv.Start; recorded=%v", rec.startedSnapshot())
}
@@ -626,10 +623,7 @@ func TestRestartWorkspaceAuto_RoutesToCPWhenSet(t *testing.T) {
// the tracking stub, so we expect at least one Stop and (eventually)
// at least one Start.
deadline := time.Now().Add(2 * time.Second)
for {
if len(rec.stoppedSnapshot()) > 0 && len(rec.startedSnapshot()) > 0 {
break
}
for len(rec.stoppedSnapshot()) == 0 || len(rec.startedSnapshot()) == 0 {
if time.Now().After(deadline) {
t.Fatalf("timed out waiting for cpProv.Stop + cpProv.Start; stopped=%v started=%v",
rec.stoppedSnapshot(), rec.startedSnapshot())
@@ -907,7 +901,7 @@ func stripGoComments(src []byte) []byte {
// Block comment
if i+1 < len(src) && src[i] == '/' && src[i+1] == '*' {
i += 2
for i+1 < len(src) && !(src[i] == '*' && src[i+1] == '/') {
for i+1 < len(src) && (src[i] != '*' || src[i+1] != '/') {
i++
}
i++ // skip closing /
@@ -13,7 +13,6 @@ import (
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/plugins"
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
"github.com/Molecule-AI/molecule-monorepo/platform/pkg/provisionhook"
"gopkg.in/yaml.v3"
)
@@ -49,7 +48,7 @@ func TestConfigDirName(t *testing.T) {
{"abc-def-ghi", "ws-abc-def-ghi"},
{"abcdefghijklmnop", "ws-abcdefghijkl"}, // truncated at 12
{"short", "ws-short"},
{"123456789012", "ws-123456789012"}, // exactly 12
{"123456789012", "ws-123456789012"}, // exactly 12
{"1234567890123", "ws-123456789012"}, // 13 chars, truncated
}
@@ -483,11 +482,11 @@ func TestSanitizeRuntime_Allowlist(t *testing.T) {
{"openclaw", "openclaw"},
{"hermes", "hermes"},
{"codex", "codex"},
{"langgraph", "claude-code"}, // deprecated → default
{"deepagents", "claude-code"}, // deprecated → default
{"crewai", "claude-code"}, // deprecated → default
{"autogen", "claude-code"}, // deprecated → default
{"not-a-runtime", "claude-code"}, // unknown → default
{"langgraph", "claude-code"}, // deprecated → default
{"deepagents", "claude-code"}, // deprecated → default
{"crewai", "claude-code"}, // deprecated → default
{"autogen", "claude-code"}, // deprecated → default
{"not-a-runtime", "claude-code"}, // unknown → default
{"../../sensitive", "claude-code"}, // path traversal probe → default
{"langgraph\nevil", "claude-code"}, // newline injection → default (not in allowlist)
}
@@ -533,7 +532,7 @@ func TestSeedInitialMemories_TruncatesOversizedContent(t *testing.T) {
},
{
name: "well under limit — passes through unchanged",
contentLen: 50_000,
contentLen: 50_000,
expectInsert: true,
},
}
@@ -1008,13 +1007,6 @@ func TestSeedInitialMemories_OversizedWithSecrets(t *testing.T) {
// Each test injects a known-internal error and verifies the response body
// or broadcast payload contains ONLY the generic prod-safe message.
// errInternalDB is a pkg-level error whose .Error() output matches a real
// postgres driver error shape — used to simulate DB failure without a live DB.
var errInternalDB = fmt.Errorf("pq: connection refused")
// errInternalOS simulates an OS-level error.
var errInternalOS = fmt.Errorf("operation failed: no such file or directory")
// captureBroadcaster is a test broadcaster that captures the last data
// payload passed to RecordAndBroadcast so tests can inspect it. Now
// satisfies events.EventEmitter (#1814) directly — RecordAndBroadcast
@@ -1022,7 +1014,6 @@ var errInternalOS = fmt.Errorf("operation failed: no such file or directory")
// WorkspaceHandler paths under test call it.
type captureBroadcaster struct {
lastData map[string]interface{}
lastErr error
}
// BroadcastOnly is required to satisfy events.EventEmitter. None of the
@@ -1042,46 +1033,6 @@ func (c *captureBroadcaster) RecordAndBroadcast(_ context.Context, _, _ string,
return nil
}
// unsafeErrorStrings lists substrings that must NEVER appear in external-facing
// error responses. Covers DB driver errors, OS errors, and internal paths.
var unsafeErrorStrings = []string{
"pq:",
"pq ",
"connection refused",
"deadlock",
"no such file",
"/var/",
"/tmp/",
"postgres",
"PostgreSQL",
"sql: ",
":8080",
"127.0.0.1",
"localhost",
"secret",
"token",
}
// containsUnsafeString checks whether any prohibited substring appears in
// a string value recursively (handles nested maps for safety).
func containsUnsafeString(v interface{}) bool {
switch v := v.(type) {
case string:
for _, unsafe := range unsafeErrorStrings {
if strings.Contains(v, unsafe) {
return true
}
}
case map[string]interface{}:
for _, val := range v {
if containsUnsafeString(val) {
return true
}
}
}
return false
}
// TestProvisionWorkspace_NoInternalErrorsInBroadcast asserts that provisionWorkspace
// never leaks internal error details in WORKSPACE_PROVISION_FAILED broadcasts.
// Regression test for issue #1206 — drives the global-secrets decrypt-fail
@@ -1251,12 +1202,12 @@ func TestProvisionWorkspaceCP_NoInternalErrorsInBroadcast(t *testing.T) {
continue
}
for _, leakMarker := range []string{
"t3.large", // machine type
"ami-0abcd1234efgh5678", // AMI id
"vpc-deadbeef", // VPC id
"subnet-cafef00d", // subnet id
"InvalidSubnet.Conflict", // raw upstream HTTP body
"CP API rejected", // raw error string head
"t3.large", // machine type
"ami-0abcd1234efgh5678", // AMI id
"vpc-deadbeef", // VPC id
"subnet-cafef00d", // subnet id
"InvalidSubnet.Conflict", // raw upstream HTTP body
"CP API rejected", // raw error string head
} {
if strings.Contains(s, leakMarker) {
t.Errorf("broadcast leaked %q in payload value %q", leakMarker, s)
@@ -1268,17 +1219,6 @@ func TestProvisionWorkspaceCP_NoInternalErrorsInBroadcast(t *testing.T) {
}
}
// mockEnvMutator is a provisionhook.Registry stub that always returns a fixed error.
type mockEnvMutator struct {
returnErr error
}
func (m *mockEnvMutator) Run(_ context.Context, _ string, _ map[string]string) error {
return m.returnErr
}
func (m *mockEnvMutator) Register(_ provisionhook.EnvMutator) {}
// TestResolveAndStage_NoInternalErrorsInHTTPErr asserts that
// resolveAndStage never puts internal error detail (resolver error
// strings, file-system paths, upstream rate-limit text, auth tokens
@@ -794,6 +794,7 @@ func TestDoJSON_204OnEndpointExpectingBody(t *testing.T) {
}
if got == nil {
t.Error("got nil SearchResponse, want zero value")
return
}
if len(got.Memories) != 0 {
t.Errorf("memories = %v, want empty", got.Memories)
@@ -109,7 +109,7 @@ func (p *flatPlugin) handleNamespace(w http.ResponseWriter, r *http.Request) {
p.mu.Unlock()
w.WriteHeader(204)
default:
http.Error(w, "method not allowed", 405)
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
}
}
@@ -22,14 +22,7 @@ const chainQuerySnippet = "WITH RECURSIVE chain"
// Helper makes per-test mock setup terser.
func setupMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
t.Helper()
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
if err != nil {
t.Fatalf("sqlmock new: %v", err)
}
t.Cleanup(func() { _ = db.Close() })
// We use QueryMatcherEqual but with regex-based ExpectQuery elsewhere
// for flexibility. Actually swap to regex for the recursive query:
db, mock, err = sqlmock.New() // default = regex
db, mock, err := sqlmock.New() // default = regex
if err != nil {
t.Fatalf("sqlmock new: %v", err)
}
@@ -186,8 +179,8 @@ func TestWalkChain_RowsErr(t *testing.T) {
func TestDerive(t *testing.T) {
cases := []struct {
name string
chain []chainNode
name string
chain []chainNode
wantWS, wantTeam, wantOrg string
}{
{
@@ -80,7 +80,6 @@ func (s *Store) PatchNamespace(ctx context.Context, name string, body contract.N
}
parts = append(parts, fmt.Sprintf("metadata = $%d", idx))
args = append(args, metadata)
idx++
}
query := fmt.Sprintf(`
UPDATE memory_namespaces SET %s
@@ -294,7 +293,9 @@ func (s *Store) Search(ctx context.Context, body contract.SearchRequest) (*contr
// --- Helpers ---
func scanNamespace(row interface{ Scan(dest ...interface{}) error }) (*contract.Namespace, error) {
func scanNamespace(row interface {
Scan(dest ...interface{}) error
}) (*contract.Namespace, error) {
var ns contract.Namespace
var kindStr string
var expires sql.NullTime
@@ -315,7 +316,9 @@ func scanNamespace(row interface{ Scan(dest ...interface{}) error }) (*contract.
return &ns, nil
}
func scanMemory(row interface{ Scan(dest ...interface{}) error }) (*contract.Memory, error) {
func scanMemory(row interface {
Scan(dest ...interface{}) error
}) (*contract.Memory, error) {
var m contract.Memory
var kindStr, sourceStr string
var expires sql.NullTime
@@ -375,7 +378,7 @@ func vectorString(v []float32) string {
if i > 0 {
b.WriteByte(',')
}
b.WriteString(fmt.Sprintf("%g", x))
fmt.Fprintf(&b, "%g", x)
}
b.WriteByte(']')
return b.String()
@@ -120,7 +120,6 @@ func WorkspaceAuth(database *sql.DB) gin.HandlerFunc {
return
}
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
return
}
}
@@ -325,7 +324,6 @@ func CanvasOrBearer(database *sql.DB) gin.HandlerFunc {
}
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "admin auth required"})
return
}
}
@@ -37,16 +37,6 @@ const validateAnyTokenSelectQuery = "SELECT t\\.id, t\\.workspace_id.*FROM works
// validateTokenUpdateQuery is matched for the best-effort last_used_at UPDATE.
const validateTokenUpdateQuery = "UPDATE workspace_auth_tokens SET last_used_at"
// newWorkspaceAuthRouter builds a minimal gin router that applies WorkspaceAuth
// to a single GET /workspaces/:id/test route, returning 200 on success.
func newWorkspaceAuthRouter(db sqlmock.Sqlmock, realDB interface{ Close() error }) *gin.Engine {
_ = db // unused directly; sqlmock intercepts calls via the *sql.DB pointer
r := gin.New()
// We need the *sql.DB, not the mock. The caller passes mockDB via the
// test-local var — this helper is only used to build the router topology.
return r
}
// TestWorkspaceAuth_351_NoBearer_Returns401 — strict contract: every request
// under /workspaces/:id/* must carry a valid bearer, period. No fail-open,
// no grace period, no existence check. The middleware goes straight to
@@ -483,10 +473,6 @@ func TestAdminAuth_InvalidBearer_Returns401(t *testing.T) {
// (no ::text cast — sql.NullString handles the NULL scan natively).
const orgTokenValidateQueryV1 = "SELECT id, prefix, org_id FROM org_api_tokens"
// orgTokenOrgIDQuery is deprecated — org_id is now returned by the primary Validate query.
// Kept here to avoid breaking other test files that may reference it.
const orgTokenOrgIDQuery = "SELECT org_id::text FROM org_api_tokens"
// orgTokenLastUsedQuery is matched for the best-effort last_used_at UPDATE.
const orgTokenLastUsedQuery = "UPDATE org_api_tokens SET last_used_at"
@@ -495,10 +481,10 @@ const orgTokenLastUsedQuery = "UPDATE org_api_tokens SET last_used_at"
// and orgCallerID can look it up downstream.
func TestAdminAuth_OrgToken_SetsOrgID(t *testing.T) {
tests := []struct {
name string
orgIDFromDB interface{} // sqlmock row value: nil, "", or "ws-org-1"
wantOrgIDCtx bool // expect c.Get("org_id") to be set
wantOrgIDVal string // if set, expected value
name string
orgIDFromDB interface{} // sqlmock row value: nil, "", or "ws-org-1"
wantOrgIDCtx bool // expect c.Get("org_id") to be set
wantOrgIDVal string // if set, expected value
}{
{
name: "post-fix token has org_id set in context",
@@ -3,6 +3,8 @@ package plugins
import (
"context"
"errors"
"os"
"os/exec"
"testing"
)
@@ -64,31 +66,6 @@ func TestResolveRef_MapsNotFoundToErrPluginNotFound(t *testing.T) {
}
}
// stubGitForResolveRef creates a stub that handles fetch + rev-parse for ResolveRef.
func stubGitForResolveRef(t *testing.T, sha string) func(ctx context.Context, dir string, args ...string) error {
return func(ctx context.Context, dir string, args ...string) error {
if ctx.Err() != nil {
return ctx.Err()
}
if len(args) < 1 {
return errors.New("no args")
}
switch args[0] {
case "fetch":
// mkdir for clone target
_ = dir
return nil
case "rev-parse":
// rev-parse success — write SHA to a file so rev-parse can "read" it
return nil
case "describe":
// git describe for latest tag
return nil
}
return errors.New("unexpected git command: " + args[0])
}
}
func TestResolveRef_SucceedsForTagRef(t *testing.T) {
// This test verifies the happy path: fetch + rev-parse succeed.
// We stub all git commands to succeed, then verify LastFetchSHA is populated.
@@ -99,18 +76,43 @@ func TestResolveRef_SucceedsForTagRef(t *testing.T) {
return ctx.Err()
}
calls[args[0]] = true
if args[0] == "fetch" {
run := func(name string, args ...string) error {
cmd := exec.CommandContext(ctx, name, args...)
cmd.Dir = dir
cmd.Env = append(os.Environ(),
"GIT_AUTHOR_NAME=test",
"GIT_AUTHOR_EMAIL=test@example.invalid",
"GIT_COMMITTER_NAME=test",
"GIT_COMMITTER_EMAIL=test@example.invalid",
)
return cmd.Run()
}
if err := run("git", "init"); err != nil {
return err
}
if err := os.WriteFile(dir+"/README.md", []byte("test\n"), 0o644); err != nil {
return err
}
if err := run("git", "add", "README.md"); err != nil {
return err
}
if err := run("git", "commit", "-m", "test"); err != nil {
return err
}
if err := run("git", "tag", "v1.0.0"); err != nil {
return err
}
}
return nil
},
}
_, err := r.ResolveRef(context.Background(), "org/repo#tag:v1.0.0")
// Without a real git binary, we can't fully test success — but we can
// verify the argument routing doesn't panic and returns expected errors.
if err != nil && !errors.Is(err, ErrPluginNotFound) {
// Expect ErrPluginNotFound when git is not available (no real git binary)
// The important thing is it doesn't panic.
if err != nil {
t.Fatalf("ResolveRef returned unexpected error: %v", err)
}
if !calls["fetch"] && !calls["rev-parse"] {
// At least one git command should have been called
t.Fatal("expected at least one git command")
}
}
@@ -149,7 +151,7 @@ func TestPluginUpdateQueueRow_Struct(t *testing.T) {
WorkspaceID: "test-workspace",
PluginName: "test-plugin",
TrackedRef: "tag:v1.0.0",
CurrentSHA: "abc123",
CurrentSHA: "abc123",
LatestSHA: "def456",
Status: "pending",
}
+2 -2
View File
@@ -57,11 +57,11 @@ func (r *GithubResolver) Scheme() string { return "github" }
// - Owner / repo: must start with alphanumeric, then 099 chars from
// [a-zA-Z0-9_.-]. Matches GitHub's validation.
// - Ref: must NOT start with `-` (prevents ref-as-flag injection like
// "-exec=/evil"). Then 0254 chars from [a-zA-Z0-9_./-]. Disallows
// "-exec=/evil"). Then 0254 chars from [a-zA-Z0-9_./:-]. Disallows
// whitespace and shell metacharacters. The handler additionally
// passes `--` before the URL when invoking git, for defense in depth.
var repoRE = regexp.MustCompile(
`^([a-zA-Z0-9][a-zA-Z0-9_.\-]{0,99})/([a-zA-Z0-9][a-zA-Z0-9_.\-]{0,99})(?:#([a-zA-Z0-9_.][a-zA-Z0-9_./\-]{0,254}))?$`,
`^([a-zA-Z0-9][a-zA-Z0-9_.\-]{0,99})/([a-zA-Z0-9][a-zA-Z0-9_.\-]{0,99})(?:#([a-zA-Z0-9_.][a-zA-Z0-9_./:\-]{0,254}))?$`,
)
// Fetch clones the repository and copies its contents (minus .git) into dst.
@@ -31,7 +31,6 @@ import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"os"
"path/filepath"
"sort"
@@ -104,8 +103,8 @@ func writeManifestJSON(t *testing.T, dir, digest string) {
func writeStagedPlugin(t *testing.T, dir string) {
t.Helper()
files := map[string]string{
"plugin.yaml": "name: test-plugin\nversion: 1.0.0\ndescription: supply chain test\n",
"rules/guidelines.md": "# Plugin Guidelines\nFollow the rules.\n",
"plugin.yaml": "name: test-plugin\nversion: 1.0.0\ndescription: supply chain test\n",
"rules/guidelines.md": "# Plugin Guidelines\nFollow the rules.\n",
"skills/helper/SKILL.md": "---\nid: helper\nname: Helper\ndescription: does stuff\n---\n",
}
for relPath, content := range files {
@@ -119,19 +118,6 @@ func writeStagedPlugin(t *testing.T, dir string) {
}
}
// stubGitSuccess returns a GitRunner that creates the target directory and
// returns nil (simulating a successful shallow clone). Does NOT write any
// repo content — tests that need files should write them into dst separately.
func stubGitSuccess() func(ctx context.Context, dir string, args ...string) error {
return func(ctx context.Context, dir string, args ...string) error {
if len(args) == 0 {
return fmt.Errorf("stubGitSuccess: no args")
}
target := args[len(args)-1]
return os.MkdirAll(target, 0o755)
}
}
// ──────────────────────────────────────────────────────────────────────────────
// SHA256 content-integrity tests (#768 Control 1)
//
@@ -445,16 +445,16 @@ func parseGiteaBranchHeadSha(body []byte) (string, error) {
// Look for `"id":"<40-hex>"` inside the commit object.
idx := strings.Index(string(body), `"id":"`)
if idx < 0 {
return "", errors.New("Gitea branch response missing commit.id field")
return "", errors.New("gitea branch response missing commit.id field")
}
rest := string(body[idx+len(`"id":"`):])
end := strings.IndexByte(rest, '"')
if end < 0 {
return "", errors.New("Gitea branch response has malformed commit.id (no closing quote)")
return "", errors.New("gitea branch response has malformed commit.id (no closing quote)")
}
sha := rest[:end]
if len(sha) < 7 {
return "", fmt.Errorf("Gitea returned suspiciously short sha %q", sha)
return "", fmt.Errorf("gitea returned suspiciously short sha %q", sha)
}
return sha, nil
}
@@ -442,7 +442,7 @@ func (p *Provisioner) Start(ctx context.Context, cfg WorkspaceConfig) (string, e
// contents are by definition immutable.
// The pull is best-effort: if it fails (network, auth, rate limit) the
// subsequent ContainerCreate still surfaces the actionable error below.
imgInspect, _, imgErr := p.cli.ImageInspectWithRaw(ctx, image)
imgInspect, imgErr := p.cli.ImageInspect(ctx, image)
moving := imageTagIsMoving(image)
switch {
case imgErr != nil:
@@ -541,12 +541,12 @@ func (p *Provisioner) Start(ctx context.Context, cfg WorkspaceConfig) (string, e
//
// Selection matrix:
//
// cfg.WorkspacePath | cfg.WorkspaceAccess | mount
// ------------------+-------------------------+--------------------------------
// "" | "" / "none" | <named-volume>:/workspace (isolated, current default)
// "<host-dir>" | "" / "read_write" | <host-dir>:/workspace (current PM behaviour)
// "<host-dir>" | "read_only" | <host-dir>:/workspace:ro (research agents get read access without write risk)
// "" | "read_only"/"read_write"| <named-volume>:/workspace (degraded — access requires a mount; validated at handler layer)
// cfg.WorkspacePath | cfg.WorkspaceAccess | mount
// ------------------+-------------------------+--------------------------------
// "" | "" / "none" | <named-volume>:/workspace (isolated, current default)
// "<host-dir>" | "" / "read_write" | <host-dir>:/workspace (current PM behaviour)
// "<host-dir>" | "read_only" | <host-dir>:/workspace:ro (research agents get read access without write risk)
// "" | "read_only"/"read_write"| <named-volume>:/workspace (degraded — access requires a mount; validated at handler layer)
//
// Kept pure + side-effect-free so it's unit-testable.
func buildWorkspaceMount(cfg WorkspaceConfig) string {
@@ -700,11 +700,11 @@ func applyTierResources(hostCfg *container.HostConfig, tier int) (memMB, cpuShar
memMB = getTierMemoryMB(tier)
cpuShares = getTierCPUShares(tier)
if memMB > 0 {
hostCfg.Resources.Memory = memMB * 1024 * 1024
hostCfg.Memory = memMB * 1024 * 1024
}
if cpuShares > 0 {
// shares -> NanoCPUs: 1024 shares == 1 CPU == 1e9 NanoCPUs
hostCfg.Resources.NanoCPUs = (cpuShares * 1_000_000_000) / 1024
hostCfg.NanoCPUs = (cpuShares * 1_000_000_000) / 1024
}
return memMB, cpuShares
}
@@ -1000,20 +1000,6 @@ func (p *Provisioner) WriteAuthTokenToVolume(ctx context.Context, workspaceID, t
return nil
}
// execInContainer runs a command inside a running container as root.
// Best-effort: logs errors but does not fail the caller.
func (p *Provisioner) execInContainer(ctx context.Context, containerID string, cmd []string) {
execCfg := container.ExecOptions{Cmd: cmd, User: "root"}
execID, err := p.cli.ContainerExecCreate(ctx, containerID, execCfg)
if err != nil {
log.Printf("Provisioner: exec create failed: %v", err)
return
}
if err := p.cli.ContainerExecStart(ctx, execID.ID, container.ExecStartOptions{}); err != nil {
log.Printf("Provisioner: exec start failed: %v", err)
}
}
// RemoveVolume removes the config volume for a workspace.
// Also removes the claude-sessions volume (best-effort, may not exist
// for non claude-code runtimes). Issue #12.
@@ -1127,12 +1113,12 @@ func (p *Provisioner) IsRunning(ctx context.Context, workspaceID string) (bool,
//
// - ("ws-<id>", nil): container is running. Caller can exec into it.
// - ("", nil): container does not exist OR exists but is stopped
// (NotFound, Exited, Created, Restarting…). Caller
// should treat as a definitive "not running."
// (NotFound, Exited, Created, Restarting…). Caller
// should treat as a definitive "not running."
// - ("", err): transient daemon error (timeout, socket EOF, ctx
// cancel). Caller should NOT infer "not running" —
// this could be a flaky daemon under load. Decide
// per-callsite whether to fail soft or hard.
// cancel). Caller should NOT infer "not running" —
// this could be a flaky daemon under load. Decide
// per-callsite whether to fail soft or hard.
//
// Background — molecule-core#10: the plugins handler used to carry its own
// copy of this inspect logic (`findRunningContainer`) which collapsed
@@ -155,14 +155,14 @@ func TestApplyTierConfig_Tier2_Standard(t *testing.T) {
// Memory limit: 512 MiB
expectedMemory := int64(512 * 1024 * 1024)
if hc.Resources.Memory != expectedMemory {
t.Errorf("T2: expected Memory=%d (512m), got %d", expectedMemory, hc.Resources.Memory)
if hc.Memory != expectedMemory {
t.Errorf("T2: expected Memory=%d (512m), got %d", expectedMemory, hc.Memory)
}
// CPU limit: 1.0 CPU (1e9 NanoCPUs)
expectedCPU := int64(1_000_000_000)
if hc.Resources.NanoCPUs != expectedCPU {
t.Errorf("T2: expected NanoCPUs=%d (1.0 CPU), got %d", expectedCPU, hc.Resources.NanoCPUs)
if hc.NanoCPUs != expectedCPU {
t.Errorf("T2: expected NanoCPUs=%d (1.0 CPU), got %d", expectedCPU, hc.NanoCPUs)
}
// Must NOT be privileged
@@ -270,13 +270,13 @@ func TestApplyTierConfig_UnknownTier_DefaultsToT2(t *testing.T) {
// Unknown tiers should get T2 resource limits as a safe default
expectedMemory := int64(512 * 1024 * 1024)
if hc.Resources.Memory != expectedMemory {
t.Errorf("Unknown tier: expected Memory=%d (512m), got %d", expectedMemory, hc.Resources.Memory)
if hc.Memory != expectedMemory {
t.Errorf("Unknown tier: expected Memory=%d (512m), got %d", expectedMemory, hc.Memory)
}
expectedCPU := int64(1_000_000_000)
if hc.Resources.NanoCPUs != expectedCPU {
t.Errorf("Unknown tier: expected NanoCPUs=%d (1.0 CPU), got %d", expectedCPU, hc.Resources.NanoCPUs)
if hc.NanoCPUs != expectedCPU {
t.Errorf("Unknown tier: expected NanoCPUs=%d (1.0 CPU), got %d", expectedCPU, hc.NanoCPUs)
}
// Must NOT be privileged
@@ -298,8 +298,8 @@ func TestApplyTierConfig_ZeroTier_DefaultsToT2(t *testing.T) {
// Zero tier (default int value) should also get T2 resource limits
expectedMemory := int64(512 * 1024 * 1024)
if hc.Resources.Memory != expectedMemory {
t.Errorf("Tier 0: expected Memory=%d, got %d", expectedMemory, hc.Resources.Memory)
if hc.Memory != expectedMemory {
t.Errorf("Tier 0: expected Memory=%d, got %d", expectedMemory, hc.Memory)
}
if hc.Privileged {
t.Error("Tier 0: must not be privileged")
@@ -944,12 +944,12 @@ func TestApplyTierConfig_T3_UsesEnvOverride(t *testing.T) {
ApplyTierConfig(hc, cfg, "ws-abc123-configs:/configs", "ws-abc123")
wantMem := int64(8192) * 1024 * 1024
if hc.Resources.Memory != wantMem {
t.Errorf("T3 memory override: got %d, want %d", hc.Resources.Memory, wantMem)
if hc.Memory != wantMem {
t.Errorf("T3 memory override: got %d, want %d", hc.Memory, wantMem)
}
wantCPU := int64(4_000_000_000)
if hc.Resources.NanoCPUs != wantCPU {
t.Errorf("T3 CPU override: got %d NanoCPUs, want %d", hc.Resources.NanoCPUs, wantCPU)
if hc.NanoCPUs != wantCPU {
t.Errorf("T3 CPU override: got %d NanoCPUs, want %d", hc.NanoCPUs, wantCPU)
}
if !hc.Privileged || hc.PidMode != "host" {
t.Errorf("T3 override should preserve privileged/pid-host flags, got Privileged=%v PidMode=%q",
@@ -968,11 +968,11 @@ func TestApplyTierConfig_T3_DefaultCap(t *testing.T) {
ApplyTierConfig(hc, cfg, "ws-abc123-configs:/configs", "ws-abc123")
wantMem := int64(defaultTier3MemoryMB) * 1024 * 1024
if hc.Resources.Memory != wantMem {
t.Errorf("T3 default memory: got %d, want %d", hc.Resources.Memory, wantMem)
if hc.Memory != wantMem {
t.Errorf("T3 default memory: got %d, want %d", hc.Memory, wantMem)
}
wantCPU := int64(defaultTier3CPUShares) * 1_000_000_000 / 1024
if hc.Resources.NanoCPUs != wantCPU {
t.Errorf("T3 default NanoCPUs: got %d, want %d", hc.Resources.NanoCPUs, wantCPU)
if hc.NanoCPUs != wantCPU {
t.Errorf("T3 default NanoCPUs: got %d, want %d", hc.NanoCPUs, wantCPU)
}
}
+148 -7
View File
@@ -12,12 +12,14 @@ Environment variables (set by the workspace container):
PLATFORM_URL platform API base URL (e.g. http://platform:8080)
"""
import argparse
import asyncio
import json
import logging
import os
import stat
import sys
import uuid
from typing import Callable
# Top-level (not inside main()) so the wheel rewriter expands this to
@@ -765,24 +767,163 @@ async def main(): # pragma: no cover
break
def cli_main() -> None: # pragma: no cover
"""Synchronous wrapper around the async MCP stdio loop.
# --- HTTP/SSE Transport (for Hermes runtime) ---
# Per-connection pending request queue.
# Maps connection-id → asyncio.Queue of JSON-RPC responses.
_http_connection_queues: dict[str, asyncio.Queue] = {}
_http_connection_lock = asyncio.Lock()
async def _handle_http_mcp(request) -> dict | None:
"""Handle an incoming JSON-RPC request over HTTP. Returns the JSON-RPC response dict, or None for notifications."""
try:
body = await request.json()
except Exception:
return {"jsonrpc": "2.0", "id": None, "error": {"code": -32700, "message": "Parse error"}}
req_id = body.get("id")
method = body.get("method", "")
if method == "initialize":
return {
"jsonrpc": "2.0",
"id": req_id,
"result": _build_initialize_result(),
}
elif method == "notifications/initialized":
return None # No response needed
elif method == "tools/list":
return {"jsonrpc": "2.0", "id": req_id, "result": {"tools": TOOLS}}
elif method == "tools/call":
params = body.get("params", {})
tool_name = params.get("name", "")
tool_args = params.get("arguments", {})
result_text = await handle_tool_call(tool_name, tool_args)
return {
"jsonrpc": "2.0",
"id": req_id,
"result": {"content": [{"type": "text", "text": result_text}]},
}
else:
return {"jsonrpc": "2.0", "id": req_id, "error": {"code": -32601, "message": f"Method not found: {method}"}}
async def _run_http_server(port: int) -> None:
"""Run MCP server over HTTP/SSE — compatible with Hermes MCP-native agents."""
try:
from starlette.applications import Starlette # noqa: F401
from starlette.routing import Route # noqa: F401
from starlette.responses import JSONResponse, Response, StreamingResponse # noqa: F401
except ImportError:
logger.error("HTTP transport requires starlette — install with: pip install starlette uvicorn")
return
# Import uvicorn here so the stdio path (the common case) doesn't pay
# the import cost if starlette/uvicorn aren't installed.
import uvicorn # noqa: F401
_http_connection_queues.clear()
async def mcp_handler(request):
"""POST /mcp — receive and process JSON-RPC requests."""
conn_id = request.headers.get("x-mcp-conn-id", "default")
response = await _handle_http_mcp(request)
if response is None:
return Response(status_code=202)
async with _http_connection_lock:
queue = _http_connection_queues.get(conn_id)
if queue is not None and not queue.full():
await queue.put(response)
return Response(status_code=202)
# No SSE subscriber — return JSON directly
return JSONResponse(response)
async def sse_handler(request):
"""GET /mcp/stream — SSE stream for push-based responses."""
conn_id = str(uuid.uuid4())
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
async with _http_connection_lock:
_http_connection_queues[conn_id] = queue
async def event_stream():
yield f"event: connected\ndata: {json.dumps({'conn_id': conn_id})}\n\n"
try:
while True:
response = await asyncio.wait_for(queue.get(), timeout=300)
yield f"event: message\ndata: {json.dumps(response)}\n\n"
if queue.empty():
yield "event: heartbeat\ndata: null\n\n"
except asyncio.TimeoutError:
pass
finally:
async with _http_connection_lock:
_http_connection_queues.pop(conn_id, None)
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
async def health_handler(_request):
return JSONResponse({"ok": True, "transport": "http+sse", "port": port})
app = Starlette(
routes=[
Route("/mcp", mcp_handler, methods=["POST"]),
Route("/mcp/stream", sse_handler, methods=["GET"]),
Route("/health", health_handler),
]
)
config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning")
server = uvicorn.Server(config)
logger.info(f"A2A MCP HTTP server listening on http://127.0.0.1:{port}/mcp")
await server.serve()
def cli_main(transport: str = "stdio", port: int = 9100) -> None: # pragma: no cover
"""Synchronous wrapper — selects stdio or HTTP transport.
Called by ``mcp_cli.main`` (the ``molecule-mcp`` console-script
entry point in scripts/build_runtime_package.py) AFTER env
validation and the standalone register + heartbeat thread setup.
Direct callers (in-container code that already validated env and
runs heartbeat.py separately) can also invoke this it's the
smallest possible "run the MCP stdio JSON-RPC loop" surface.
runs heartbeat.py separately) can also invoke this.
Wheel-smoke gates in scripts/wheel_smoke.py pin the importability
of this name (alongside ``mcp_cli.main``) so a silent rename can't
break every external-runtime operator's MCP install — the 0.1.16
``main_sync`` rename incident is the cautionary precedent.
Args:
transport: "stdio" (default) or "http" (HTTP+SSE for Hermes).
port: TCP port for HTTP transport (default 9100).
"""
_assert_stdio_is_pipe_compatible()
asyncio.run(main())
if transport == "http":
asyncio.run(_run_http_server(port))
else:
_assert_stdio_is_pipe_compatible()
asyncio.run(main())
if __name__ == "__main__": # pragma: no cover
cli_main()
parser = argparse.ArgumentParser(description="A2A MCP Server")
parser.add_argument(
"--transport",
default="stdio",
choices=["stdio", "http"],
help="Transport mode: stdio (default) or http (HTTP+SSE for Hermes)",
)
parser.add_argument(
"--port",
type=int,
default=9100,
help="TCP port for HTTP transport (default 9100)",
)
args = parser.parse_args()
cli_main(transport=args.transport, port=args.port)
+2
View File
@@ -34,6 +34,8 @@ async def list_peers() -> list[dict]:
async def delegate_task(workspace_id: str, task: str) -> str:
"""Send a task to a peer workspace via A2A and return the response text."""
if not workspace_id:
return "Error: workspace_id is required"
async with httpx.AsyncClient(timeout=120.0) as client:
# Discover target URL
try:
+567
View File
@@ -0,0 +1,567 @@
"""Tests for the HTTP/SSE transport of a2a_mcp_server.
Covers:
- _handle_http_mcp: JSON-RPC request parsing and routing
- Starlette app routes: POST /mcp, GET /mcp/stream, GET /health
- cli_main argparse: --transport and --port flags
"""
from __future__ import annotations
import asyncio
import json
import sys
import types
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
class _DummyRequest:
"""Minimal request duck-type for _handle_http_mcp."""
def __init__(self, body_json: dict, headers: dict | None = None):
self._body = body_json
self.headers = headers or {}
async def json(self) -> dict:
return self._body
# ---------------------------------------------------------------------------
# _handle_http_mcp — unit tests (no I/O)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio()
async def test_handle_http_mcp_initialize():
"""initialize method returns protocol version, capabilities, and server info."""
from a2a_mcp_server import _handle_http_mcp
req = _DummyRequest({"jsonrpc": "2.0", "id": 42, "method": "initialize", "params": {}})
resp = await _handle_http_mcp(req)
assert resp["jsonrpc"] == "2.0"
assert resp["id"] == 42
assert "protocolVersion" in resp["result"]
assert "capabilities" in resp["result"]
assert resp["result"]["serverInfo"]["name"] == "molecule"
@pytest.mark.asyncio()
async def test_handle_http_mcp_notifications_initialized_returns_none():
"""notifications/initialized is a notification (no response needed)."""
from a2a_mcp_server import _handle_http_mcp
req = _DummyRequest({"jsonrpc": "2.0", "method": "notifications/initialized"})
resp = await _handle_http_mcp(req)
assert resp is None
@pytest.mark.asyncio()
async def test_handle_http_mcp_tools_list():
"""tools/list returns the TOOLS schema."""
from a2a_mcp_server import _handle_http_mcp
req = _DummyRequest({"jsonrpc": "2.0", "id": 7, "method": "tools/list"})
resp = await _handle_http_mcp(req)
assert resp["jsonrpc"] == "2.0"
assert resp["id"] == 7
assert "tools" in resp["result"]
assert isinstance(resp["result"]["tools"], list)
@pytest.mark.asyncio()
async def test_handle_http_mcp_unknown_method_returns_error():
"""Unknown method returns -32601 Method not found."""
from a2a_mcp_server import _handle_http_mcp
req = _DummyRequest({"jsonrpc": "2.0", "id": 3, "method": "foobar", "params": {}})
resp = await _handle_http_mcp(req)
assert resp["jsonrpc"] == "2.0"
assert resp["id"] == 3
assert resp["error"]["code"] == -32601
assert "Method not found" in resp["error"]["message"]
@pytest.mark.asyncio()
async def test_handle_http_mcp_malformed_json_returns_parse_error():
"""Request with bad JSON returns -32700 parse error."""
from a2a_mcp_server import _handle_http_mcp
req = _DummyRequest.__new__(_DummyRequest)
req.headers = {}
req.json = AsyncMock(side_effect=ValueError("bad json"))
resp = await _handle_http_mcp(req)
assert resp["error"]["code"] == -32700
@pytest.mark.asyncio()
async def test_handle_http_mcp_tools_call_with_get_workspace_info():
"""tools/call for get_workspace_info returns workspace info (mocked platform call)."""
from a2a_mcp_server import _handle_http_mcp
with patch("a2a_mcp_server.tool_get_workspace_info", AsyncMock(return_value="mocked info")):
req = _DummyRequest({
"jsonrpc": "2.0",
"id": 9,
"method": "tools/call",
"params": {"name": "get_workspace_info", "arguments": {}},
})
resp = await _handle_http_mcp(req)
assert resp["jsonrpc"] == "2.0"
assert resp["id"] == 9
assert resp["result"]["content"][0]["text"] == "mocked info"
@pytest.mark.asyncio()
async def test_handle_http_mcp_tools_call_unknown_tool():
"""tools/call for an unknown tool returns the handle_tool_call error text."""
from a2a_mcp_server import _handle_http_mcp
req = _DummyRequest({
"jsonrpc": "2.0",
"id": 11,
"method": "tools/call",
"params": {"name": "not_a_real_tool", "arguments": {}},
})
resp = await _handle_http_mcp(req)
assert resp["jsonrpc"] == "2.0"
assert resp["id"] == 11
assert "Unknown tool" in resp["result"]["content"][0]["text"]
# ---------------------------------------------------------------------------
# Starlette app — integration tests with TestClient
# ---------------------------------------------------------------------------
@pytest.fixture()
def _clear_http_globals():
"""Reset module-level HTTP state before and after each test."""
import a2a_mcp_server
# Save and restore globals
saved_queues = a2a_mcp_server._http_connection_queues.copy()
saved_lock = a2a_mcp_server._http_connection_lock
a2a_mcp_server._http_connection_queues.clear()
yield
# Restore
a2a_mcp_server._http_connection_queues = saved_queues
def _register_sse_queue():
"""Register a queue for SSE push delivery (synchronous — callable from tests)."""
conn_id = str(uuid.uuid4())
queue = asyncio.Queue(maxsize=100)
import a2a_mcp_server
a2a_mcp_server._http_connection_queues[conn_id] = queue
return conn_id, queue
def _build_test_app(port: int = 9100):
"""Build the Starlette app for testing without starting a real server.
Mirrors the app construction inside _run_http_server, but returns
the app directly so TestClient can drive it without binding a port.
"""
from starlette.applications import Starlette
from starlette.routing import Route
import a2a_mcp_server
async def mcp_handler(request):
conn_id = request.headers.get("x-mcp-conn-id", "default")
response = await a2a_mcp_server._handle_http_mcp(request)
if response is None:
from starlette.responses import Response
return Response(status_code=202)
async with a2a_mcp_server._http_connection_lock:
queue = a2a_mcp_server._http_connection_queues.get(conn_id)
if queue is not None and not queue.full():
await queue.put(response)
from starlette.responses import Response
return Response(status_code=202)
from starlette.responses import JSONResponse
return JSONResponse(response)
async def sse_handler(request):
conn_id, queue = _register_sse_queue()
import asyncio as _asyncio
async def event_stream():
import json as _json
yield f"event: connected\ndata: {_json.dumps({'conn_id': conn_id})}\n\n"
try:
while True:
response = await _asyncio.wait_for(queue.get(), timeout=300)
import json as _json
yield f"event: message\ndata: {_json.dumps(response)}\n\n"
if queue.empty():
yield "event: heartbeat\ndata: null\n\n"
except _asyncio.TimeoutError:
pass
finally:
async with a2a_mcp_server._http_connection_lock:
a2a_mcp_server._http_connection_queues.pop(conn_id, None)
from starlette.responses import StreamingResponse
return StreamingResponse(
event_stream(),
media_type="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
"X-Accel-Buffering": "no",
},
)
async def health_handler(_request):
from starlette.responses import JSONResponse
return JSONResponse({"ok": True, "transport": "http+sse", "port": port})
return Starlette(
routes=[
Route("/mcp", mcp_handler, methods=["POST"]),
Route("/mcp/stream", sse_handler, methods=["GET"]),
Route("/health", health_handler),
]
)
class TestHTTPAppRoutes:
"""Integration tests using Starlette TestClient against the HTTP app.
Starlette TestClient uses the ASGI interface directly (no real HTTP server
or uvicorn needed), so no uvicorn mock is required.
"""
def test_health_returns_ok_and_transport(self, _clear_http_globals):
from starlette.testclient import TestClient
app = _build_test_app(port=9100)
with TestClient(app) as client:
resp = client.get("/health")
assert resp.status_code == 200
data = resp.json()
assert data["ok"] is True
assert data["transport"] == "http+sse"
assert data["port"] == 9100
def test_health_accepts_different_port(self, _clear_http_globals):
from starlette.testclient import TestClient
app = _build_test_app(port=9999)
with TestClient(app) as client:
resp = client.get("/health")
assert resp.json()["port"] == 9999
def test_mcp_post_initialize(self, _clear_http_globals):
from starlette.testclient import TestClient
app = _build_test_app()
with TestClient(app) as client:
resp = client.post("/mcp", json={
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": {},
})
assert resp.status_code == 200
data = resp.json()
assert data["id"] == 1
assert "protocolVersion" in data["result"]
def test_mcp_post_tools_list(self, _clear_http_globals):
from starlette.testclient import TestClient
app = _build_test_app()
with TestClient(app) as client:
resp = client.post("/mcp", json={
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list",
"params": {},
})
assert resp.status_code == 200
data = resp.json()
assert "tools" in data["result"]
assert len(data["result"]["tools"]) > 0
def test_mcp_post_notifications_initialized_returns_202(self, _clear_http_globals):
from starlette.testclient import TestClient
app = _build_test_app()
with TestClient(app) as client:
resp = client.post("/mcp", json={
"jsonrpc": "2.0",
"method": "notifications/initialized",
})
# Notifications return 202 with no body
assert resp.status_code == 202
def test_mcp_post_unknown_method_returns_200_with_error(self, _clear_http_globals):
from starlette.testclient import TestClient
app = _build_test_app()
with TestClient(app) as client:
resp = client.post("/mcp", json={
"jsonrpc": "2.0",
"id": 5,
"method": "no_such_method",
"params": {},
})
assert resp.status_code == 200
data = resp.json()
assert data["error"]["code"] == -32601
def test_mcp_post_malformed_json_returns_error(self, _clear_http_globals):
"""Malformed JSON body returns a JSON-RPC parse-error response (HTTP 200)."""
from starlette.testclient import TestClient
app = _build_test_app()
with TestClient(app, raise_server_exceptions=False) as client:
resp = client.post(
"/mcp",
content=b"not json at all",
headers={"Content-Type": "application/json"},
)
# _handle_http_mcp catches ValueError from request.json() and returns
# a JSON-RPC parse-error response with HTTP 200.
assert resp.status_code == 200
assert resp.json()["error"]["code"] == -32700
assert "Parse error" in resp.json()["error"]["message"]
@pytest.mark.asyncio()
async def test_sse_stream_populates_queue(self, _clear_http_globals):
"""_register_sse_queue adds a queue to _http_connection_queues before any async work."""
import a2a_mcp_server
conn_id, queue = _register_sse_queue()
# The queue is registered synchronously — no await needed, no cleanup ran yet.
assert conn_id in a2a_mcp_server._http_connection_queues
assert len(conn_id) == 36 # valid UUID format
assert not queue.full()
@pytest.mark.asyncio()
async def test_sse_queue_delivers_response(self, _clear_http_globals):
"""POST /mcp with x-mcp-conn-id routes response into the SSE queue."""
import uuid
import a2a_mcp_server
from starlette.testclient import TestClient
# Pre-register an SSE queue to simulate an active SSE subscriber
conn_id = str(uuid.uuid4())
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
async with a2a_mcp_server._http_connection_lock:
a2a_mcp_server._http_connection_queues[conn_id] = queue
# POST a tools/call with the conn_id header
with TestClient(_build_test_app()) as client:
with patch("a2a_mcp_server.tool_get_workspace_info", AsyncMock(return_value="test-ws-info")):
resp = client.post(
"/mcp",
headers={"x-mcp-conn-id": conn_id},
json={
"jsonrpc": "2.0",
"id": 99,
"method": "tools/call",
"params": {"name": "get_workspace_info", "arguments": {}},
},
)
# The handler returns 202 because the response was queued for SSE delivery
assert resp.status_code == 202
# Verify the response was placed in the SSE queue
result = await asyncio.wait_for(queue.get(), timeout=2.0)
assert result["id"] == 99
assert result["result"]["content"][0]["text"] == "test-ws-info"
# ---------------------------------------------------------------------------
# cli_main argparse — unit tests
# ---------------------------------------------------------------------------
def test_mcp_post_falls_back_to_json_when_sse_queue_is_full(_clear_http_globals):
"""When the SSE queue is full (>100 pending), the handler returns JSON directly."""
import a2a_mcp_server
from starlette.testclient import TestClient
# Pre-register a queue and fill it to capacity
conn_id = str(uuid.uuid4())
queue: asyncio.Queue = asyncio.Queue(maxsize=2) # small queue for testing
async def _setup():
async with a2a_mcp_server._http_connection_lock:
a2a_mcp_server._http_connection_queues[conn_id] = queue
queue.put_nowait({"id": 1})
queue.put_nowait({"id": 2})
_sync_run(_setup())
assert queue.full()
app = _build_test_app()
with TestClient(app) as client:
resp = client.post(
"/mcp",
headers={"x-mcp-conn-id": conn_id},
json={"jsonrpc": "2.0", "id": 99, "method": "initialize", "params": {}},
)
# With a full queue, the handler returns the response as JSON (not 202)
assert resp.status_code == 200
assert resp.json()["id"] == 99
assert "result" in resp.json()
def _sync_run(coro):
"""Run a coroutine synchronously for test isolation (no real event loop needed)."""
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(coro)
finally:
loop.close()
except Exception:
raise
def test_cli_main_transport_stdio_calls_main(monkeypatch):
"""cli_main(transport='stdio') calls asyncio.run(main) without HTTP."""
import a2a_mcp_server
run_calls: list = []
async def fake_main():
run_calls.append("called")
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
a2a_mcp_server.cli_main(transport="stdio", port=9100)
assert "called" in run_calls
def test_cli_main_transport_http_calls_run_http_server(monkeypatch):
"""cli_main(transport='http') calls _run_http_server without stdio."""
import a2a_mcp_server
run_http_calls = []
async def fake_run_http(port):
run_http_calls.append(port)
# asyncio.run must execute the coroutine for _run_http_server to be called
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
monkeypatch.setattr(a2a_mcp_server, "_run_http_server", fake_run_http)
# stdio path must not be entered
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
a2a_mcp_server.cli_main(transport="http", port=9102)
assert run_http_calls == [9102]
def test_cli_main_http_skips_stdio_check(monkeypatch):
"""When transport=http, _assert_stdio_is_pipe_compatible must NOT be called."""
import a2a_mcp_server
called = []
def fake_assert():
called.append("assert_called")
# Patch on the module object directly
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", fake_assert)
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", lambda fn: None)
a2a_mcp_server.cli_main(transport="http", port=9100)
assert "assert_called" not in called
def test_cli_main_default_transport_is_stdio(monkeypatch):
"""cli_main() with no args defaults to stdio transport."""
import a2a_mcp_server
called_as: list = []
async def fake_main():
called_as.append("called")
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
a2a_mcp_server.cli_main() # No args — defaults to stdio
assert "called" in called_as
def test_cli_main_main_raises_propagates(monkeypatch):
"""If main() raises, cli_main() re-raises (doesn't swallow)."""
import a2a_mcp_server
async def fake_main():
raise RuntimeError("boom")
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
with pytest.raises(RuntimeError, match="boom"):
a2a_mcp_server.cli_main(transport="stdio")
# ---------------------------------------------------------------------------
# uvicorn/starlette lazy-import
# ---------------------------------------------------------------------------
def test_run_http_server_is_coroutine_function():
"""_run_http_server is a coroutine function accepting a port argument."""
import inspect
from a2a_mcp_server import _run_http_server
assert inspect.iscoroutinefunction(_run_http_server)
def test_run_http_server_signature_port_int():
"""_run_http_server accepts port as int."""
import inspect
from a2a_mcp_server import _run_http_server
sig = inspect.signature(_run_http_server)
assert "port" in sig.parameters
assert sig.parameters["port"].annotation == int
+432
View File
@@ -0,0 +1,432 @@
"""Test coverage for ``builtin_tools.a2a_tools`` and ``send_message_wrapper``.
Issue #367: 21 new test cases targeting previously-uncovered branches.
HTTP mocking: each test patches ``builtin_tools.a2a_tools.httpx.AsyncClient``
with an ``AsyncMock`` so no real network I/O occurs. The patch target is
the attribute as seen inside the ``a2a_tools`` module (where httpx is imported
as ``import httpx``), so ``@pytest.fixture(autouse=True)`` from conftest.py is
harmless it replaces the module-level name *after* our patch exits.
"""
from __future__ import annotations
import asyncio
import os
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# conftest.py fixture — swap the MagicMock for the real module for THIS file
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _real_a2a_tools_module():
"""Replace conftest's MagicMock of builtin_tools.a2a_tools with the real module.
conftest.py sets sys.modules["builtin_tools.a2a_tools"] = <MagicMock> so that
adapter tests don't accidentally hit the platform. For THIS test file we
want the real module, so we restore it from disk and swap it back after.
"""
import builtin_tools.a2a_tools as real_module
# conftest.py may have clobbered builtin_tools.__path__; restore it so the
# import above finds builtin_tools/a2a_tools.py on disk.
if "builtin_tools" in sys.modules:
real_builtin = sys.modules["builtin_tools"]
if getattr(real_builtin, "__path__", None) == []:
builtin_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
real_builtin.__path__ = [os.path.join(builtin_dir, "builtin_tools")]
saved = sys.modules.get("builtin_tools.a2a_tools")
# Ensure we have the real module (reload if sys.modules already has it)
if saved is None or saved is real_module:
import importlib
importlib.reload(real_module)
sys.modules["builtin_tools.a2a_tools"] = real_module
yield
sys.modules["builtin_tools.a2a_tools"] = saved
@pytest.fixture(autouse=True)
def _require_env(monkeypatch):
"""Per-test: set required env vars."""
monkeypatch.setenv("WORKSPACE_ID", "00000000-0000-0000-0000-000000000001")
monkeypatch.setenv("PLATFORM_URL", "http://test.invalid")
yield
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_mock_response(
json_data, status_code: int = 200
) -> MagicMock:
"""Return a fully-configured AsyncMock that mirrors httpx.Response."""
resp = MagicMock()
resp.json = MagicMock(return_value=json_data)
resp.status_code = status_code
return resp
# =============================================================================
# builtin_tools/a2a_tools — list_peers
# =============================================================================
class TestListPeers:
"""Coverage for builtin_tools/a2a_tools.list_peers()."""
async def test_returns_peers_on_200(self):
"""Successful GET returns the peer list."""
from builtin_tools.a2a_tools import list_peers
peers = [
{"id": "ws-1", "name": "Alpha", "role": "sre", "status": "online"},
{"id": "ws-2", "name": "Beta", "role": "dev", "status": "busy"},
]
mock_resp = _make_mock_response(peers, 200)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
result = await list_peers()
assert result == peers
async def test_returns_empty_list_on_non_200(self):
"""list_peers swallows all non-200 responses gracefully."""
from builtin_tools.a2a_tools import list_peers
mock_resp = _make_mock_response({}, 500)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
result = await list_peers()
assert result == []
async def test_returns_empty_list_on_exception(self):
"""Network errors must not propagate — list_peers returns []. """
from builtin_tools.a2a_tools import list_peers
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(
side_effect=RuntimeError("dns failure")
)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
result = await list_peers()
assert result == []
# =============================================================================
# builtin_tools/a2a_tools — delegate_task
# =============================================================================
_DISCOVER_ROUTE = "http://test.invalid/registry/discover/ws-target"
class TestDelegateTask:
"""Coverage for builtin_tools/a2a_tools.delegate_task(workspace_id, task)."""
async def test_empty_workspace_id_returns_error(self):
"""Empty workspace_id is validated before any network call."""
from builtin_tools.a2a_tools import delegate_task
out = await delegate_task("", "do it")
assert "Error" in out
assert "workspace_id" in out.lower()
async def test_discover_returns_non_200(self):
"""Discovery 4xx/5xx → error message with status code."""
from builtin_tools.a2a_tools import delegate_task
discover_resp = _make_mock_response({}, 404)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
out = await delegate_task("ws-target", "do it")
assert "Error" in out
assert "404" in out
async def test_discover_returns_200_with_empty_url(self):
"""Discovery 200 but no url field → actionable error."""
from builtin_tools.a2a_tools import delegate_task
discover_resp = _make_mock_response({"name": "orphan"}, 200)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
out = await delegate_task("ws-target", "do it")
assert "Error" in out
assert "no URL" in out
async def test_a2a_post_returns_500(self):
"""A2A send 5xx with empty body → str(data) returned (code doesn't check status_code)."""
from builtin_tools.a2a_tools import delegate_task
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
a2a_resp = _make_mock_response({}, 500)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
out = await delegate_task("ws-target", "do it")
# Code checks json body, not status_code; empty body {} → str({})
assert out == "{}"
async def test_result_parts_empty_dict(self):
"""Regression #279: {"parts": []} → str(result), not "(no text)"."""
from builtin_tools.a2a_tools import delegate_task
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
a2a_resp = _make_mock_response({"result": {"parts": []}}, 200)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
out = await delegate_task("ws-target", "do it")
# Must return str(result), not "(no text)"
assert "parts" in out
assert "(no text)" not in out
async def test_result_is_plain_string(self):
"""A bare string result returns as-is."""
from builtin_tools.a2a_tools import delegate_task
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
a2a_resp = _make_mock_response({"result": "just a plain string"}, 200)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
out = await delegate_task("ws-target", "do it")
assert out == "just a plain string"
async def test_result_is_number(self):
"""Non-dict, non-string result → falls through to "(no text)"."""
from builtin_tools.a2a_tools import delegate_task
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
a2a_resp = _make_mock_response({"result": 12345}, 200)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
out = await delegate_task("ws-target", "do it")
assert out == "(no text)"
async def test_result_parts_non_dict_element(self):
"""parts[0] is not a dict → falls through to "(no text)".
The code checks if parts[0] is a dict; since 123 is an int, it hits
the else-branch and returns "(no text)".
"""
from builtin_tools.a2a_tools import delegate_task
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
a2a_resp = _make_mock_response({"result": {"parts": [123, "also a string"]}}, 200)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
out = await delegate_task("ws-target", "do it")
assert out == "(no text)"
async def test_error_dict_form(self):
"""{"error": {"message": "..."}} → "Error: ..."."""
from builtin_tools.a2a_tools import delegate_task
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
a2a_resp = _make_mock_response(
{"error": {"message": "peer overloaded", "code": 429}}, 200
)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
out = await delegate_task("ws-target", "do it")
assert out == "Error: peer overloaded"
async def test_error_string_form(self):
"""{"error": "string error"} → "Error: string error"."""
from builtin_tools.a2a_tools import delegate_task
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
a2a_resp = _make_mock_response({"error": "workspace offline"}, 200)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
out = await delegate_task("ws-target", "do it")
assert out == "Error: workspace offline"
async def test_error_null(self):
"""{"error": null} → "Error: None" (edge case — str(null) in message)."""
from builtin_tools.a2a_tools import delegate_task
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
a2a_resp = _make_mock_response({"error": None}, 200)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
out = await delegate_task("ws-target", "do it")
assert "Error" in out
async def test_a2a_post_raises_exception(self):
"""Network error during A2A POST → Error: sending A2A message: ..."""
from builtin_tools.a2a_tools import delegate_task
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
mock_client.__aenter__.return_value.post = AsyncMock(
side_effect=ConnectionError("connection refused")
)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
out = await delegate_task("ws-target", "do it")
assert "Error" in out
assert "connection refused" in out
# =============================================================================
# builtin_tools/a2a_tools — get_peers_summary
# =============================================================================
class TestGetPeersSummary:
"""Coverage for builtin_tools/a2a_tools.get_peers_summary()."""
async def test_empty_peers_returns_no_peers_available(self):
from builtin_tools.a2a_tools import get_peers_summary
mock_resp = _make_mock_response([], 200)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
out = await get_peers_summary()
assert "No peers" in out
async def test_peer_missing_fields(self):
"""Peers with missing name/id/role/status must not KeyError/TypeError."""
from builtin_tools.a2a_tools import get_peers_summary
mock_resp = _make_mock_response([{"id": "ws-x"}], 200)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
out = await get_peers_summary()
assert "ws-x" in out
assert isinstance(out, str)
async def test_healthy_peer_roundtrip(self):
"""Sanity: normal peer dicts produce a formatted list."""
from builtin_tools.a2a_tools import get_peers_summary
peers = [
{"id": "ws-alpha", "name": "Alpha", "role": "sre", "status": "online"},
]
mock_resp = _make_mock_response(peers, 200)
mock_client = AsyncMock()
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
out = await get_peers_summary()
assert "Alpha" in out
assert "ws-alpha" in out
assert "sre" in out
assert "online" in out
# =============================================================================
# send_message_wrapper — safe_send_message
# =============================================================================
from adapters.smolagents.send_message_wrapper import safe_send_message
class TestSafeSendMessage:
"""Coverage for adapters.smolagents.send_message_wrapper.safe_send_message()."""
def test_non_string_input_converted(self):
"""Non-str text is str()-converted before escaping."""
delivered = []
safe_send_message(42, send_fn=lambda s: delivered.append(s))
assert delivered == ["[smolagents] 42"]
assert isinstance(delivered[0], str)
def test_html_entities_escaped(self):
"""< > ' are escaped so rendered UIs cannot be injected.
The payload <script>alert('xss')</script> has no literal '&', so &amp;
does not appear. The escape output is: &lt;script&gt;alert(&#x27;xss&#x27;)&lt;/script&gt;
"""
delivered = []
safe_send_message(
"<script>alert('xss')</script>",
send_fn=lambda s: delivered.append(s),
)
assert "&lt;" in delivered[0]
assert "&gt;" in delivered[0]
assert "&#x27;" in delivered[0]
assert "&lt;script&gt;" in delivered[0]
# The angle brackets and quotes must NOT appear unescaped
assert "<script>" not in delivered[0]
assert "alert('" not in delivered[0]
def test_truncation_at_max_len(self):
"""Text > 2000 chars is truncated; caller is warned."""
delivered = []
with patch(
"adapters.smolagents.send_message_wrapper.logger"
) as mock_logger:
long_text = "A" * 2500
safe_send_message(long_text, send_fn=lambda s: delivered.append(s))
assert len(delivered[0]) < len(long_text)
mock_logger.warning.assert_called_once()
assert "truncating" in mock_logger.warning.call_args[0][0]
def test_no_truncation_under_max_len(self):
"""Text ≤ 2000 chars is passed through intact with no warning."""
delivered = []
with patch(
"adapters.smolagents.send_message_wrapper.logger"
) as mock_logger:
text = "A" * 1500
safe_send_message(text, send_fn=lambda s: delivered.append(s))
expected = f"[smolagents] {text}"
assert delivered[0] == expected
mock_logger.warning.assert_not_called()
def test_debug_log_emitted(self):
"""Every delivery logs at DEBUG with final payload length."""
delivered = []
with patch(
"adapters.smolagents.send_message_wrapper.logger"
) as mock_logger:
safe_send_message("hello", send_fn=lambda s: delivered.append(s))
mock_logger.debug.assert_called_once()
assert "delivering" in mock_logger.debug.call_args[0][0]
def test_label_prefix_always_present(self):
"""Every delivered payload starts with '[smolagents]'."""
delivered = []
safe_send_message("x", send_fn=lambda s: delivered.append(s))
assert delivered[0].startswith("[smolagents]")
@@ -0,0 +1,300 @@
"""Test coverage for shared_runtime helpers (issue #366).
Six helper functions previously had zero test coverage:
_extract_part_text, extract_message_text, format_conversation_history,
build_task_text, append_peer_guidance, brief_task
"""
from __future__ import annotations
from shared_runtime import (
_extract_part_text,
append_peer_guidance,
brief_task,
build_task_text,
extract_message_text,
format_conversation_history,
)
# =============================================================================
# _extract_part_text
# =============================================================================
class TestExtractPartText:
"""Coverage for shared_runtime._extract_part_text()."""
def test_dict_with_text_field(self):
assert _extract_part_text({"text": "hello"}) == "hello"
def test_dict_without_text_field(self):
assert _extract_part_text({"type": "image"}) == ""
def test_dict_with_empty_text_field(self):
assert _extract_part_text({"text": ""}) == ""
def test_dict_with_root_nesting(self):
"""Text buried in part['root']['text'] is extracted."""
assert _extract_part_text({"root": {"text": "nested"}}) == "nested"
def test_dict_with_root_non_dict(self):
"""part['root'] that is not a dict is safely skipped."""
assert _extract_part_text({"root": "string", "text": "top"}) == "top"
def test_object_with_text_attribute(self):
class FakePart:
text = "attr-text"
assert _extract_part_text(FakePart()) == "attr-text"
def test_object_with_root_object_with_text(self):
"""Object with root.attr.text is extracted (A2A v1 object style)."""
class FakeRoot:
text = "root-attr-text"
class FakePart:
root = FakeRoot()
assert _extract_part_text(FakePart()) == "root-attr-text"
def test_object_with_empty_text_attribute(self):
class FakePart:
text = ""
assert _extract_part_text(FakePart()) == ""
def test_none_input(self):
assert _extract_part_text(None) == ""
def test_unexpected_type(self):
"""Plain int/float/bool falls through to empty string."""
assert _extract_part_text(42) == ""
# =============================================================================
# extract_message_text
# =============================================================================
class TestExtractMessageText:
"""Coverage for shared_runtime.extract_message_text()."""
def test_list_of_dict_parts(self):
parts = [{"text": "hello"}, {"text": "world"}]
assert extract_message_text(parts) == "hello world"
def test_single_part(self):
assert extract_message_text([{"text": "single"}]) == "single"
def test_context_object_with_message_parts(self):
"""RequestContext-like: .message.parts is the parts list."""
class FakeContext:
class _Msg:
parts = [{"text": "from context"}]
message = _Msg()
assert extract_message_text(FakeContext()) == "from context"
def test_context_object_without_message(self):
"""No .message attr → falls back to treating input as a parts list."""
class FakeContext:
pass # no .message
# Pass a list directly as the context-like object
assert extract_message_text([{"text": "fallback"}]) == "fallback"
def test_whitespace_normalized(self):
"""Leading/trailing whitespace is stripped; internal newlines are preserved."""
parts = [{"text": " hello "}, {"text": "\nworld\n"}]
result = extract_message_text(parts)
# Leading/trailing stripped, but internal \n stays (join uses single space)
assert result == "hello \nworld"
assert not result.startswith(" ")
assert not result.endswith(" ")
def test_empty_parts_list(self):
assert extract_message_text([]) == ""
# =============================================================================
# format_conversation_history
# =============================================================================
class TestFormatConversationHistory:
"""Coverage for shared_runtime.format_conversation_history()."""
def test_single_user_message(self):
hist = [("human", "hello")]
out = format_conversation_history(hist)
assert out == "User: hello"
def test_single_agent_message(self):
hist = [("ai", "response")]
out = format_conversation_history(hist)
assert out == "Agent: response"
def test_interleaved_history(self):
hist = [
("human", "hello"),
("ai", "hi there"),
("human", "what is 2+2?"),
("ai", "four"),
]
out = format_conversation_history(hist)
lines = out.split("\n")
assert lines[0] == "User: hello"
assert lines[1] == "Agent: hi there"
assert lines[2] == "User: what is 2+2?"
assert lines[3] == "Agent: four"
def test_empty_history(self):
assert format_conversation_history([]) == ""
# =============================================================================
# build_task_text
# =============================================================================
class TestBuildTaskText:
"""Coverage for shared_runtime.build_task_text()."""
def test_no_history_returns_user_message_unchanged(self):
assert build_task_text("do the thing", []) == "do the thing"
def test_history_prepends_transcript(self):
hist = [("human", "hello"), ("ai", "hi")]
result = build_task_text("follow-up", hist)
assert "Conversation so far:" in result
assert "User: hello" in result
assert "Agent: hi" in result
assert "follow-up" in result
def test_user_message_after_conversation_header(self):
hist = [("human", "hello")]
result = build_task_text("do it", hist)
assert result.startswith("Conversation so far:")
assert result.endswith("Current request: do it")
def test_empty_user_message_with_history(self):
"""Empty user_message is still rendered with history."""
hist = [("human", "hello")]
result = build_task_text("", hist)
assert "Conversation so far:" in result
assert "Current request:" in result
# =============================================================================
# append_peer_guidance
# =============================================================================
class TestAppendPeerGuidance:
"""Coverage for shared_runtime.append_peer_guidance()."""
def test_base_text_appended(self):
result = append_peer_guidance(
"base text",
peers_info="alpha: ws-1",
default_text="default",
tool_name="delegate_task",
)
assert result.startswith("base text")
assert "## Peers" in result
assert "alpha: ws-1" in result
assert "Use delegate_task" in result
def test_null_base_text_uses_default(self):
result = append_peer_guidance(
None,
peers_info="peer info",
default_text="DEFAULT_TEXT",
tool_name="tool",
)
assert result.startswith("DEFAULT_TEXT")
def test_whitespace_base_text_strips_to_empty_peers_still_added(self):
"""Whitespace-only base_text is stripped but default_text is NOT used
(only None triggers the fallback). The peers section is still appended."""
result = append_peer_guidance(
" ",
peers_info="peer",
default_text="DEF",
tool_name="t",
)
# " ".strip() == ""; default_text is NOT substituted for whitespace
assert "## Peers" in result
assert "peer" in result
assert "DEF" not in result # default_text only on None, not whitespace
def test_none_base_text_uses_default(self):
"""None base_text triggers fallback to default_text."""
result = append_peer_guidance(
None,
peers_info="peer",
default_text="DEFAULT",
tool_name="tool",
)
assert result.startswith("DEFAULT")
assert "## Peers" in result
def test_empty_peers_info_skips_section(self):
result = append_peer_guidance(
"base",
peers_info="",
default_text="def",
tool_name="tool",
)
# No "## Peers" section when peers_info is empty
assert result == "base"
def test_whitespace_in_base_and_peers_normalized(self):
result = append_peer_guidance(
" base \n",
peers_info=" peer-1 \n",
default_text="def",
tool_name="tool",
)
# Base should be stripped of leading/trailing whitespace
assert result.startswith("base")
# Peer info should be appended
assert "peer-1" in result
# =============================================================================
# brief_task
# =============================================================================
class TestBriefTask:
"""Coverage for shared_runtime.brief_task()."""
def test_short_text_returned_unchanged(self):
assert brief_task("hello", limit=60) == "hello"
def test_exact_limit_no_ellipsis(self):
text = "A" * 60
assert brief_task(text, limit=60) == text
assert "..." not in text
def test_truncated_with_ellipsis(self):
text = "A" * 80
result = brief_task(text, limit=60)
assert len(result) == 63 # 60 chars + "..."
assert result.endswith("...")
def test_limit_10_shortens(self):
result = brief_task("hello world", limit=10)
assert len(result) == 13 # 10 chars + "..."
assert result.endswith("...")
def test_limit_0_returns_ellipsis(self):
"""limit=0 → 0-char slice + "..." since len("hello") > 0."""
result = brief_task("hello", limit=0)
assert result == "..."
def test_limit_1_single_char_plus_ellipsis(self):
result = brief_task("hello", limit=1)
assert len(result) == 4 # 1 char + "..."
assert result.startswith("h")
assert result.endswith("...")