Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| aba67633f4 | |||
| cc6057db7a | |||
| 0147cc1857 | |||
| b87a79aff4 | |||
| ef93dabaf9 | |||
| a4a6f52064 |
@@ -170,9 +170,12 @@ jobs:
|
||||
# CLI (molecli) moved to standalone repo: git.moleculesai.app/molecule-ai/molecule-cli
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
run: go vet ./...
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
name: Install golangci-lint
|
||||
run: go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@v2.12.2
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
name: Run golangci-lint
|
||||
run: golangci-lint run --timeout 3m ./...
|
||||
run: $(go env GOPATH)/bin/golangci-lint run --timeout 3m ./...
|
||||
- if: needs.changes.outputs.platform == 'true'
|
||||
name: Diagnostic — per-package verbose 60s
|
||||
run: |
|
||||
|
||||
@@ -7,14 +7,16 @@
|
||||
// in place rather than duplicating.
|
||||
//
|
||||
// Usage:
|
||||
// memory-backfill -dry-run # count + diff
|
||||
// memory-backfill -apply # actually copy
|
||||
// memory-backfill -apply -limit=10000 # cap rows per run
|
||||
// memory-backfill -apply -workspace=<uuid> # one workspace only
|
||||
//
|
||||
// memory-backfill -dry-run # count + diff
|
||||
// memory-backfill -apply # actually copy
|
||||
// memory-backfill -apply -limit=10000 # cap rows per run
|
||||
// memory-backfill -apply -workspace=<uuid> # one workspace only
|
||||
//
|
||||
// Required env:
|
||||
// DATABASE_URL — workspace-server DB (read agent_memories)
|
||||
// MEMORY_PLUGIN_URL — target plugin (write memory_records)
|
||||
//
|
||||
// DATABASE_URL — workspace-server DB (read agent_memories)
|
||||
// MEMORY_PLUGIN_URL — target plugin (write memory_records)
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -251,7 +253,7 @@ func mapScopeToNamespace(ctx context.Context, r backfillResolver, workspaceID, s
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("resolve writable: %w", err)
|
||||
}
|
||||
wantKind := contract.NamespaceKindWorkspace
|
||||
var wantKind contract.NamespaceKind
|
||||
switch scope {
|
||||
case "LOCAL":
|
||||
wantKind = contract.NamespaceKindWorkspace
|
||||
|
||||
@@ -522,7 +522,7 @@ func (m *Manager) FetchWorkspaceChannelContext(ctx context.Context, workspaceID
|
||||
if len(text) > 200 {
|
||||
text = text[:197] + "..."
|
||||
}
|
||||
sb.WriteString(fmt.Sprintf("- %s: %s\n", name, text))
|
||||
fmt.Fprintf(&sb, "- %s: %s\n", name, text)
|
||||
}
|
||||
return sb.String()
|
||||
}
|
||||
|
||||
@@ -134,9 +134,9 @@ var botCommands = []tgbotapi.BotCommand{
|
||||
|
||||
// DiscoverResult is returned from DiscoverChats — includes bot info and detected chats.
|
||||
type DiscoverResult struct {
|
||||
BotUsername string
|
||||
Chats []map[string]interface{}
|
||||
CanReadAllGroupMessages bool // false = group privacy mode is ON (bot only sees commands/mentions)
|
||||
BotUsername string
|
||||
Chats []map[string]interface{}
|
||||
CanReadAllGroupMessages bool // false = group privacy mode is ON (bot only sees commands/mentions)
|
||||
}
|
||||
|
||||
// DiscoverChats calls Telegram getUpdates to find groups/chats the bot has been added to.
|
||||
@@ -231,7 +231,6 @@ func (t *TelegramAdapter) DiscoverChats(ctx context.Context, botToken string) (*
|
||||
addChat(msg.Chat)
|
||||
}
|
||||
|
||||
|
||||
return &DiscoverResult{
|
||||
BotUsername: bot.Self.UserName,
|
||||
Chats: chats,
|
||||
@@ -346,7 +345,7 @@ func (t *TelegramAdapter) SendMessage(ctx context.Context, config map[string]int
|
||||
case 403:
|
||||
return fmt.Errorf("forbidden: bot was blocked or kicked from chat %s", chatID)
|
||||
case 429:
|
||||
retryAfter := time.Duration(apiErr.ResponseParameters.RetryAfter) * time.Second
|
||||
retryAfter := time.Duration(apiErr.RetryAfter) * time.Second
|
||||
log.Printf("Channels: Telegram rate-limited, retry after %s", retryAfter)
|
||||
time.Sleep(retryAfter)
|
||||
if _, retryErr := bot.Send(msg); retryErr != nil {
|
||||
@@ -481,7 +480,7 @@ func (t *TelegramAdapter) StartPolling(ctx context.Context, config map[string]in
|
||||
var apiErr *tgbotapi.Error
|
||||
if errors.As(err, &apiErr) {
|
||||
if apiErr.Code == 429 {
|
||||
retryAfter := time.Duration(apiErr.ResponseParameters.RetryAfter) * time.Second
|
||||
retryAfter := time.Duration(apiErr.RetryAfter) * time.Second
|
||||
log.Printf("Channels: Telegram poll rate-limited, sleeping %s", retryAfter)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
|
||||
@@ -108,7 +108,7 @@ func TestEventType_AllUppercaseSnakeCase(t *testing.T) {
|
||||
t.Errorf("EventType %q has consecutive underscores — disallowed", s)
|
||||
}
|
||||
for _, r := range s {
|
||||
if !((r >= 'A' && r <= 'Z') || (r >= '0' && r <= '9') || r == '_') {
|
||||
if (r < 'A' || r > 'Z') && (r < '0' || r > '9') && r != '_' {
|
||||
t.Errorf("EventType %q contains disallowed char %q", s, r)
|
||||
break
|
||||
}
|
||||
|
||||
@@ -42,7 +42,7 @@ func setupTestDBForQueueTests(t *testing.T) sqlmock.Sqlmock {
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
func TestPriorityConstants(t *testing.T) {
|
||||
if !(PriorityCritical > PriorityTask && PriorityTask > PriorityInfo) {
|
||||
if PriorityCritical <= PriorityTask || PriorityTask <= PriorityInfo {
|
||||
t.Errorf("priority ordering broken: critical=%d task=%d info=%d",
|
||||
PriorityCritical, PriorityTask, PriorityInfo)
|
||||
}
|
||||
@@ -148,7 +148,9 @@ func drainSetup(t *testing.T, workspaceID string) (sqlmock.Sqlmock, *WorkspaceHa
|
||||
}
|
||||
|
||||
// expectQueueBudgetCheck registers the mock for checkWorkspaceBudget's query:
|
||||
// SELECT budget_limit, COALESCE(monthly_spend, 0) FROM workspaces WHERE id = $1
|
||||
//
|
||||
// SELECT budget_limit, COALESCE(monthly_spend, 0) FROM workspaces WHERE id = $1
|
||||
//
|
||||
// Must be called AFTER expectDequeueNextOk — DequeueNext (BEGIN→SELECT→UPDATE→COMMIT)
|
||||
// runs before proxyA2ARequest which calls checkWorkspaceBudget.
|
||||
// Named distinctly from handlers_test.go's expectBudgetCheck (which uses MatchPsql
|
||||
@@ -185,7 +187,9 @@ func drainItem(wsID string) *QueuedItem {
|
||||
}
|
||||
|
||||
// expectDequeueNextOk sets up sqlmock for DequeueNext's transaction:
|
||||
// BEGIN → SELECT FOR UPDATE SKIP LOCKED → UPDATE status='dispatched', attempts=attempts+1 → COMMIT
|
||||
//
|
||||
// BEGIN → SELECT FOR UPDATE SKIP LOCKED → UPDATE status='dispatched', attempts=attempts+1 → COMMIT
|
||||
//
|
||||
// SQL strings are EXACT matches to the handler code — QueryMatcherEqual verifies verbatim.
|
||||
func expectDequeueNextOk(mock sqlmock.Sqlmock, item *QueuedItem) {
|
||||
mock.ExpectBegin()
|
||||
|
||||
@@ -474,12 +474,7 @@ func (h *ActivityHandler) Notify(c *gin.Context) {
|
||||
// Lark) hook in here too.
|
||||
attachments := make([]AgentMessageAttachment, 0, len(body.Attachments))
|
||||
for _, a := range body.Attachments {
|
||||
attachments = append(attachments, AgentMessageAttachment{
|
||||
URI: a.URI,
|
||||
Name: a.Name,
|
||||
MimeType: a.MimeType,
|
||||
Size: a.Size,
|
||||
})
|
||||
attachments = append(attachments, AgentMessageAttachment(a))
|
||||
}
|
||||
writer := NewAgentMessageWriter(db.DB, h.broadcaster)
|
||||
if err := writer.Send(c.Request.Context(), workspaceID, body.Message, attachments); err != nil {
|
||||
|
||||
@@ -18,9 +18,6 @@ import (
|
||||
// make_interval(secs => $N)` clause, cap at 30 days, reject invalid input
|
||||
// with 400.
|
||||
|
||||
const activityCols = `id, workspace_id, activity_type, source_id, target_id, method, ` +
|
||||
`summary, request_body, response_body, tool_trace, duration_ms, status, error_detail, created_at`
|
||||
|
||||
func newActivityRows() *sqlmock.Rows {
|
||||
cols := []string{
|
||||
"id", "workspace_id", "activity_type", "source_id", "target_id", "method",
|
||||
|
||||
@@ -262,16 +262,16 @@ func (h *AdminMemoriesHandler) Import(c *gin.Context) {
|
||||
// because workspaces sharing a team/org root see identical namespaces.
|
||||
//
|
||||
// New strategy:
|
||||
// 1. Single SQL pass walks parent_id chains, returning each
|
||||
// workspace's root_id alongside its name.
|
||||
// 2. Group workspaces by root → unique tree count is typically <<
|
||||
// workspace count.
|
||||
// 3. Resolve namespaces ONCE per root (any workspace under that
|
||||
// root produces the same readable list).
|
||||
// 4. Build a UNION of namespaces across all roots; single plugin
|
||||
// search call.
|
||||
// 5. Map each memory back to a workspace_name via a namespace→ws
|
||||
// lookup table built up from step 3.
|
||||
// 1. Single SQL pass walks parent_id chains, returning each
|
||||
// workspace's root_id alongside its name.
|
||||
// 2. Group workspaces by root → unique tree count is typically <<
|
||||
// workspace count.
|
||||
// 3. Resolve namespaces ONCE per root (any workspace under that
|
||||
// root produces the same readable list).
|
||||
// 4. Build a UNION of namespaces across all roots; single plugin
|
||||
// search call.
|
||||
// 5. Map each memory back to a workspace_name via a namespace→ws
|
||||
// lookup table built up from step 3.
|
||||
//
|
||||
// Net cost: 1 SQL + N_roots resolver calls + 1 plugin call (vs
|
||||
// N_workspaces resolver + N_workspaces plugin in the old code).
|
||||
@@ -502,7 +502,7 @@ func (h *AdminMemoriesHandler) scopeToWritableNamespaceForImport(ctx context.Con
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
wantKind := contract.NamespaceKindWorkspace
|
||||
var wantKind contract.NamespaceKind
|
||||
switch strings.ToUpper(scope) {
|
||||
case "", "LOCAL":
|
||||
wantKind = contract.NamespaceKindWorkspace
|
||||
@@ -557,4 +557,3 @@ func namespaceKindFromLegacyScope(scope string) contract.NamespaceKind {
|
||||
return contract.NamespaceKindWorkspace
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -131,10 +131,9 @@ func TestCutoverActive(t *testing.T) {
|
||||
|
||||
func TestWithMemoryV2_AttachesDeps(t *testing.T) {
|
||||
h := NewAdminMemoriesHandler().WithMemoryV2(nil, nil)
|
||||
// Both nil pointers — wiring still attaches them; cutoverActive
|
||||
// reports false because the interface values are nil.
|
||||
if h.plugin == nil && h.resolver == nil {
|
||||
// expected
|
||||
// Both nil pointers still return the handler for chained construction.
|
||||
if h == nil {
|
||||
t.Fatal("WithMemoryV2(nil, nil) returned nil handler")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -596,7 +595,7 @@ func (r perWorkspaceResolver) ReadableNamespaces(_ context.Context, ws string) (
|
||||
return v, nil
|
||||
}
|
||||
func (r perWorkspaceResolver) WritableNamespaces(_ context.Context, ws string) ([]namespace.Namespace, error) {
|
||||
return r.ReadableNamespaces(nil, ws)
|
||||
return r.ReadableNamespaces(context.TODO(), ws)
|
||||
}
|
||||
|
||||
// TestExport_IncludesEveryMembersPrivateNamespace pins the I3 follow-up
|
||||
|
||||
@@ -71,13 +71,6 @@ func (h *BudgetHandler) GetBudget(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, resp)
|
||||
}
|
||||
|
||||
// patchBudgetRequest is the expected JSON body for PATCH /workspaces/:id/budget.
|
||||
// budget_limit=null removes the ceiling; a positive integer sets it (USD cents).
|
||||
type patchBudgetRequest struct {
|
||||
// BudgetLimit pointer so JSON null → nil, absent → parse error (required field).
|
||||
BudgetLimit *int64 `json:"budget_limit"`
|
||||
}
|
||||
|
||||
// PatchBudget handles PATCH /workspaces/:id/budget.
|
||||
// Accepts {"budget_limit": <int64>} to set a new ceiling, or
|
||||
// {"budget_limit": null} to remove an existing ceiling.
|
||||
|
||||
@@ -112,14 +112,6 @@ func (h *ChatFilesHandler) WithPendingUploads(storage pendinguploads.Storage, br
|
||||
// network boundary before forwarding.
|
||||
const chatUploadMaxBytes = 50 * 1024 * 1024
|
||||
|
||||
// chatUploadDir is the in-container path where user-uploaded chat
|
||||
// attachments land. Kept here for documentation parity with the
|
||||
// workspace-side handler — the platform no longer writes files
|
||||
// directly, but the URI scheme returned in responses still uses this
|
||||
// path, so any consumer parsing those URIs has the constant to
|
||||
// reference.
|
||||
const chatUploadDir = "/workspace/.molecule/chat-uploads"
|
||||
|
||||
// resolveWorkspaceForwardCreds resolves the workspace's URL +
|
||||
// platform_inbound_secret for an /internal/* forward, applying
|
||||
// lazy-heal on a missing inbound secret (RFC #2312 backfill — the
|
||||
@@ -460,7 +452,6 @@ func (h *ChatFilesHandler) streamWorkspaceResponse(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// lookupUploadDeliveryMode returns the workspace's delivery_mode
|
||||
// for the chat upload branch. Returns ("", false) and writes the
|
||||
// HTTP error response on lookup failure (caller stops). NULL or
|
||||
|
||||
@@ -153,7 +153,7 @@ func TestMergeSystemMessages_EmptySlice(t *testing.T) {
|
||||
func TestMergeSystemMessages_NilSlice(t *testing.T) {
|
||||
var input []map[string]interface{}
|
||||
got := mergeSystemMessages(input)
|
||||
if got != nil && len(got) != 0 {
|
||||
if len(got) != 0 {
|
||||
t.Errorf("nil: got %v, want nil/empty", got)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -47,13 +47,13 @@ const defaultProvisionConcurrency = 3
|
||||
//
|
||||
// - unset / empty / non-numeric → defaultProvisionConcurrency (3)
|
||||
// - "0" → unlimited (a very large cap;
|
||||
// practically no semaphore — used on
|
||||
// SaaS where AWS RunInstances is the
|
||||
// rate-limiter, not us)
|
||||
// practically no semaphore — used on
|
||||
// SaaS where AWS RunInstances is the
|
||||
// rate-limiter, not us)
|
||||
// - any positive integer N → N
|
||||
// - negative integer → defaultProvisionConcurrency (3),
|
||||
// log warning so operator notices
|
||||
// the misconfiguration
|
||||
// log warning so operator notices
|
||||
// the misconfiguration
|
||||
//
|
||||
// The "0 = unlimited" mapping was a deliberate choice: an env var of "0"
|
||||
// is the natural shorthand for "no cap" without forcing operators to
|
||||
@@ -102,18 +102,6 @@ const (
|
||||
childGridColumnCount = 2
|
||||
)
|
||||
|
||||
// childSlot computes the child-relative position for the N-th sibling in
|
||||
// a parent's 2-column grid. Matches defaultChildSlot in
|
||||
// canvas-topology.ts exactly — change them together. Leaf-sized slots
|
||||
// only; for variable-size siblings use childSlotInGrid below.
|
||||
func childSlot(index int) (x, y float64) {
|
||||
col := index % childGridColumnCount
|
||||
row := index / childGridColumnCount
|
||||
x = parentSidePadding + float64(col)*(childDefaultWidth+childGutter)
|
||||
y = parentHeaderPadding + float64(row)*(childDefaultHeight+childGutter)
|
||||
return
|
||||
}
|
||||
|
||||
type nodeSize struct {
|
||||
width, height float64
|
||||
}
|
||||
@@ -342,10 +330,10 @@ func (e *EnvRequirement) UnmarshalJSON(data []byte) error {
|
||||
|
||||
// OrgTemplate is the YAML structure for an org hierarchy.
|
||||
type OrgTemplate struct {
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Description string `yaml:"description" json:"description"`
|
||||
Defaults OrgDefaults `yaml:"defaults" json:"defaults"`
|
||||
Workspaces []OrgWorkspace `yaml:"workspaces" json:"workspaces"`
|
||||
Name string `yaml:"name" json:"name"`
|
||||
Description string `yaml:"description" json:"description"`
|
||||
Defaults OrgDefaults `yaml:"defaults" json:"defaults"`
|
||||
Workspaces []OrgWorkspace `yaml:"workspaces" json:"workspaces"`
|
||||
// GlobalMemories is a list of org-wide memories seeded as GLOBAL scope
|
||||
// on the first root workspace (PM) during org import. Issue #1050.
|
||||
GlobalMemories []models.MemorySeed `yaml:"global_memories" json:"global_memories"`
|
||||
@@ -381,9 +369,9 @@ type OrgDefaults struct {
|
||||
// declare them — causing live configs to boot without idle_prompts
|
||||
// even when org.yaml had them. Phase 1 scalability work adds both
|
||||
// inline + file-ref forms.
|
||||
IdlePrompt string `yaml:"idle_prompt" json:"idle_prompt"`
|
||||
IdlePromptFile string `yaml:"idle_prompt_file" json:"idle_prompt_file"`
|
||||
IdleIntervalSeconds int `yaml:"idle_interval_seconds" json:"idle_interval_seconds"`
|
||||
IdlePrompt string `yaml:"idle_prompt" json:"idle_prompt"`
|
||||
IdlePromptFile string `yaml:"idle_prompt_file" json:"idle_prompt_file"`
|
||||
IdleIntervalSeconds int `yaml:"idle_interval_seconds" json:"idle_interval_seconds"`
|
||||
// CategoryRouting maps issue/audit category → list of target roles.
|
||||
// Per-workspace blocks UNION + override per-key with these defaults.
|
||||
// Rendered into each workspace's config.yaml so agent prompts can read it
|
||||
@@ -470,12 +458,12 @@ type OrgWorkspace struct {
|
||||
// time. If empty, defaults.initial_memories are used. Issue #1050.
|
||||
InitialMemories []models.MemorySeed `yaml:"initial_memories" json:"initial_memories"`
|
||||
// MaxConcurrentTasks: see models.CreateWorkspacePayload.
|
||||
MaxConcurrentTasks int `yaml:"max_concurrent_tasks" json:"max_concurrent_tasks"`
|
||||
Schedules []OrgSchedule `yaml:"schedules" json:"schedules"`
|
||||
Channels []OrgChannel `yaml:"channels" json:"channels"`
|
||||
External bool `yaml:"external" json:"external"`
|
||||
URL string `yaml:"url" json:"url"`
|
||||
Canvas struct {
|
||||
MaxConcurrentTasks int `yaml:"max_concurrent_tasks" json:"max_concurrent_tasks"`
|
||||
Schedules []OrgSchedule `yaml:"schedules" json:"schedules"`
|
||||
Channels []OrgChannel `yaml:"channels" json:"channels"`
|
||||
External bool `yaml:"external" json:"external"`
|
||||
URL string `yaml:"url" json:"url"`
|
||||
Canvas struct {
|
||||
X float64 `yaml:"x" json:"x"`
|
||||
Y float64 `yaml:"y" json:"y"`
|
||||
} `yaml:"canvas" json:"canvas"`
|
||||
@@ -714,10 +702,10 @@ func (h *OrgHandler) Import(c *gin.Context) {
|
||||
wsMissing := collectPerWorkspaceUnsatisfied(tmpl.Workspaces, orgBaseDir, configured)
|
||||
if len(wsMissing) > 0 {
|
||||
c.JSON(http.StatusPreconditionFailed, gin.H{
|
||||
"error": "missing per-workspace required environment variables",
|
||||
"error": "missing per-workspace required environment variables",
|
||||
"missing_workspace_env": wsMissing,
|
||||
"template": tmpl.Name,
|
||||
"suggestion": "add these keys to the workspace's .env file or set them as global secrets before importing",
|
||||
"template": tmpl.Name,
|
||||
"suggestion": "add these keys to the workspace's .env file or set them as global secrets before importing",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -952,4 +940,3 @@ func errString(err error) string {
|
||||
}
|
||||
return err.Error()
|
||||
}
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@ func TestSanitizeEnvMembers_MaxLength(t *testing.T) {
|
||||
}
|
||||
// 129 chars: invalid (exceeds {0,127} suffix in regex)
|
||||
tooLong := "A" + strings.Repeat("B", 128)
|
||||
got, ok = sanitizeEnvMembers([]string{tooLong}, "test")
|
||||
_, ok = sanitizeEnvMembers([]string{tooLong}, "test")
|
||||
if ok {
|
||||
t.Error("129 char invalid: ok should be false")
|
||||
}
|
||||
@@ -230,7 +230,7 @@ func TestFlattenAndSortRequirements_Empty(t *testing.T) {
|
||||
func TestFlattenAndSortRequirements_SingleFirst(t *testing.T) {
|
||||
// Singles come before groups; within singles, alphabetical
|
||||
reqs := map[string]EnvRequirement{
|
||||
envRequirementKey([]string{"ZETA"}): {Name: "ZETA"},
|
||||
envRequirementKey([]string{"ZETA"}): {Name: "ZETA"},
|
||||
envRequirementKey([]string{"ALPHA"}): {Name: "ALPHA"},
|
||||
}
|
||||
got := flattenAndSortRequirements(reqs)
|
||||
@@ -247,7 +247,7 @@ func TestFlattenAndSortRequirements_SingleFirst(t *testing.T) {
|
||||
|
||||
func TestFlattenAndSortRequirements_GroupsAfterSingles(t *testing.T) {
|
||||
reqs := map[string]EnvRequirement{
|
||||
envRequirementKey([]string{"X"}): {Name: "X"}, // single
|
||||
envRequirementKey([]string{"X"}): {Name: "X"}, // single
|
||||
envRequirementKey([]string{"A", "B"}): {AnyOf: []string{"A", "B"}}, // group
|
||||
}
|
||||
got := flattenAndSortRequirements(reqs)
|
||||
@@ -429,8 +429,8 @@ func TestCollectOrgEnv_WorkspaceLevel(t *testing.T) {
|
||||
tmpl := &OrgTemplate{
|
||||
Workspaces: []OrgWorkspace{
|
||||
{
|
||||
Name: "Dev",
|
||||
RequiredEnv: []EnvRequirement{{Name: "DEV_KEY"}},
|
||||
Name: "Dev",
|
||||
RequiredEnv: []EnvRequirement{{Name: "DEV_KEY"}},
|
||||
RecommendedEnv: []EnvRequirement{{Name: "DEV_TOOL"}},
|
||||
},
|
||||
},
|
||||
@@ -456,12 +456,12 @@ func TestCollectOrgEnv_DeepNesting(t *testing.T) {
|
||||
RequiredEnv: []EnvRequirement{{Name: "ORG_LEVEL"}},
|
||||
Workspaces: []OrgWorkspace{
|
||||
{
|
||||
Name: "Root",
|
||||
RequiredEnv: []EnvRequirement{{Name: "ROOT_LEVEL"}},
|
||||
Name: "Root",
|
||||
RequiredEnv: []EnvRequirement{{Name: "ROOT_LEVEL"}},
|
||||
Children: []OrgWorkspace{
|
||||
{
|
||||
Name: "Child",
|
||||
RequiredEnv: []EnvRequirement{{Name: "CHILD_LEVEL"}},
|
||||
Name: "Child",
|
||||
RequiredEnv: []EnvRequirement{{Name: "CHILD_LEVEL"}},
|
||||
Children: []OrgWorkspace{
|
||||
{Name: "GrandChild", RecommendedEnv: []EnvRequirement{{Name: "GRANDCHILD_TOOL"}}},
|
||||
},
|
||||
@@ -536,4 +536,3 @@ func TestCollectOrgEnv_MixedCasePreservesSort(t *testing.T) {
|
||||
t.Errorf("A,B group should come first: got %+v", req[2])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -33,11 +33,11 @@ GITEA_SSH_KEY_PATH=/etc/molecule-bootstrap/personas/dev-lead/ssh_priv
|
||||
loadPersonaEnvFile("dev-lead", out)
|
||||
|
||||
want := map[string]string{
|
||||
"GITEA_USER": "dev-lead",
|
||||
"GITEA_USER_EMAIL": "dev-lead@agents.moleculesai.app",
|
||||
"GITEA_TOKEN": "abc123",
|
||||
"GITEA_TOKEN_SCOPES": "write:repository,write:issue,read:user",
|
||||
"GITEA_SSH_KEY_PATH": "/etc/molecule-bootstrap/personas/dev-lead/ssh_priv",
|
||||
"GITEA_USER": "dev-lead",
|
||||
"GITEA_USER_EMAIL": "dev-lead@agents.moleculesai.app",
|
||||
"GITEA_TOKEN": "abc123",
|
||||
"GITEA_TOKEN_SCOPES": "write:repository,write:issue,read:user",
|
||||
"GITEA_SSH_KEY_PATH": "/etc/molecule-bootstrap/personas/dev-lead/ssh_priv",
|
||||
}
|
||||
if len(out) != len(want) {
|
||||
t.Fatalf("got %d keys, want %d: %#v", len(out), len(want), out)
|
||||
@@ -153,12 +153,6 @@ func TestIsSafeRoleName_Acceptance(t *testing.T) {
|
||||
}
|
||||
}
|
||||
bad := []string{
|
||||
"", ".", "..", "with/slash", "/abs", "dot.in.middle",
|
||||
"with space", "back\\slash", "trailing-", // trailing-hyphen is fine actually
|
||||
"with$dollar", "with?question", "newline\nsplit",
|
||||
}
|
||||
// trailing-hyphen IS allowed; remove from "bad" list:
|
||||
bad = []string{
|
||||
"", ".", "..", "with/slash", "/abs", "dot.in.middle",
|
||||
"with space", "back\\slash", "with$dollar", "with?question",
|
||||
"newline\nsplit",
|
||||
|
||||
@@ -2,7 +2,6 @@ package handlers
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -19,7 +18,6 @@ import (
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/envx"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/plugins"
|
||||
"github.com/docker/docker/api/types/container"
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
@@ -436,53 +434,6 @@ func regexpEscapeForAwk(s string) string {
|
||||
return b.String()
|
||||
}
|
||||
|
||||
// copyPluginToContainer creates a tar from a host directory and copies it into /configs/plugins/<name>/.
|
||||
// The tar entries are prefixed with plugins/<name>/ so Docker creates the directory structure.
|
||||
func (h *PluginsHandler) copyPluginToContainer(ctx context.Context, containerName, hostDir, pluginName string) error {
|
||||
var buf bytes.Buffer
|
||||
tw := tar.NewWriter(&buf)
|
||||
|
||||
err := filepath.Walk(hostDir, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
rel, err := filepath.Rel(hostDir, path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
header, err := tar.FileInfoHeader(info, "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Prefix: plugins/<pluginName>/<rel> → extracts under /configs/
|
||||
header.Name = filepath.Join("plugins", pluginName, rel)
|
||||
|
||||
if err := tw.WriteHeader(header); err != nil {
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := tw.Write(data); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create tar from %s: %w", hostDir, err)
|
||||
}
|
||||
if err := tw.Close(); err != nil {
|
||||
return fmt.Errorf("failed to close tar: %w", err)
|
||||
}
|
||||
|
||||
// Copy to /configs — the tar's plugins/<name>/ prefix creates the directory
|
||||
return h.docker.CopyToContainer(ctx, containerName, "/configs", &buf, container.CopyToContainerOptions{})
|
||||
}
|
||||
|
||||
// streamDirAsTar writes every regular file + dir under `root` to the tar
|
||||
// writer, using paths relative to root so the caller's unpack produces
|
||||
// `<name>/<original-layout>` without any leading tempdir components.
|
||||
|
||||
@@ -119,7 +119,7 @@ func TestResolveAgentURLForRestartSignal_CacheHit(t *testing.T) {
|
||||
// returned and propagated when neither Redis cache nor DB lookup succeeds.
|
||||
func TestResolveAgentURLForRestartSignal_DBError(t *testing.T) {
|
||||
mock := setupTestDB(t) // must come before setupTestRedis so db.DB is correct
|
||||
_ = setupTestRedis(t) // empty → cache miss
|
||||
_ = setupTestRedis(t) // empty → cache miss
|
||||
|
||||
h := newHandlerWithTestDeps(t)
|
||||
|
||||
@@ -209,10 +209,10 @@ func TestGracefulPreRestart_Success(t *testing.T) {
|
||||
// Pre-populate Redis cache with the test server URL
|
||||
_ = setupTestRedisWithURL(t, srv.URL)
|
||||
|
||||
// Use an embedded struct to override resolveAgentURLForRestartSignal.
|
||||
// Use a wrapper so gracefulPreRestart runs through the embedded handler.
|
||||
hWrapper := &resolveURLTestWrapper{
|
||||
WorkspaceHandler: newHandlerWithTestDeps(t),
|
||||
testURL: srv.URL + "/agent",
|
||||
testURL: srv.URL + "/agent",
|
||||
}
|
||||
|
||||
// gracefulPreRestart runs in a goroutine with its own timeout.
|
||||
@@ -235,7 +235,7 @@ func TestGracefulPreRestart_NotImplemented(t *testing.T) {
|
||||
|
||||
hWrapper := &resolveURLTestWrapper{
|
||||
WorkspaceHandler: newHandlerWithTestDeps(t),
|
||||
testURL: srv.URL + "/agent",
|
||||
testURL: srv.URL + "/agent",
|
||||
}
|
||||
|
||||
hWrapper.gracefulPreRestart(context.Background(), "ws-noimpl-999")
|
||||
@@ -253,7 +253,7 @@ func TestGracefulPreRestart_ConnectionRefused(t *testing.T) {
|
||||
|
||||
hWrapper := &resolveURLTestWrapper{
|
||||
WorkspaceHandler: newHandlerWithTestDeps(t),
|
||||
testURL: "http://localhost:19999/agent",
|
||||
testURL: "http://localhost:19999/agent",
|
||||
}
|
||||
|
||||
hWrapper.gracefulPreRestart(context.Background(), "ws-unreachable-000")
|
||||
@@ -269,7 +269,7 @@ func TestGracefulPreRestart_URLResolutionError(t *testing.T) {
|
||||
|
||||
hWrapper := &resolveURLTestWrapper{
|
||||
WorkspaceHandler: newHandlerWithTestDeps(t),
|
||||
errToReturn: context.DeadlineExceeded,
|
||||
errToReturn: context.DeadlineExceeded,
|
||||
}
|
||||
|
||||
hWrapper.gracefulPreRestart(context.Background(), "ws-url-err-111")
|
||||
@@ -279,21 +279,14 @@ func TestGracefulPreRestart_URLResolutionError(t *testing.T) {
|
||||
|
||||
// ─── helpers ─────────────────────────────────────────────────────────────────
|
||||
|
||||
// resolveURLTestWrapper embeds *WorkspaceHandler and overrides
|
||||
// resolveAgentURLForRestartSignal so tests can inject a fixed URL or error.
|
||||
// resolveURLTestWrapper embeds *WorkspaceHandler for tests that exercise
|
||||
// gracefulPreRestart through a wrapper value.
|
||||
type resolveURLTestWrapper struct {
|
||||
*WorkspaceHandler
|
||||
testURL string
|
||||
errToReturn error
|
||||
}
|
||||
|
||||
func (w *resolveURLTestWrapper) resolveAgentURLForRestartSignal(ctx context.Context, workspaceID string) (string, error) {
|
||||
if w.errToReturn != nil {
|
||||
return "", w.errToReturn
|
||||
}
|
||||
return w.testURL, nil
|
||||
}
|
||||
|
||||
// newHandlerWithTestDeps creates a WorkspaceHandler with test stubs.
|
||||
func newHandlerWithTestDeps(t *testing.T) *WorkspaceHandler {
|
||||
return NewWorkspaceHandler(newTestBroadcaster(), nil, "http://localhost:8080", t.TempDir())
|
||||
@@ -313,4 +306,4 @@ func setupTestRedisWithURL(t *testing.T, url string) *miniredis.Miniredis {
|
||||
}
|
||||
t.Cleanup(func() { mr.Close() })
|
||||
return mr
|
||||
}
|
||||
}
|
||||
|
||||
@@ -61,7 +61,6 @@ func resolveRestartTemplate(configsDir, wsName, dbRuntime string, body restartTe
|
||||
candidatePath, resolveErr := resolveInsideRoot(configsDir, template)
|
||||
if resolveErr != nil {
|
||||
log.Printf("Restart: invalid template %q: %v — proceeding without it", template, resolveErr)
|
||||
template = ""
|
||||
} else if _, err := os.Stat(candidatePath); err == nil {
|
||||
return candidatePath, template
|
||||
} else {
|
||||
|
||||
@@ -3,8 +3,6 @@ package handlers
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
|
||||
)
|
||||
|
||||
// Tests for the SaaS-aware default-tier resolution introduced in #2901
|
||||
@@ -21,19 +19,6 @@ import (
|
||||
// was hardcoded to 3 and silently disagreed with the create-
|
||||
// handler default on SaaS.
|
||||
|
||||
// stubCPProv is a minimal stand-in for the CP provisioner — only
|
||||
// exercises the IsSaaS / HasProvisioner contract, never invoked in
|
||||
// these tests.
|
||||
type stubCPProv struct{}
|
||||
|
||||
func (stubCPProv) Start(_ interface{}, _ provisioner.WorkspaceConfig) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
func (stubCPProv) Stop(_ interface{}, _ string) error { return nil }
|
||||
func (stubCPProv) Restart(_ interface{}, _ provisioner.WorkspaceConfig) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func TestIsSaaS_TrueWhenCPProvWired(t *testing.T) {
|
||||
h := &WorkspaceHandler{cpProv: &trackingCPProv{}}
|
||||
if !h.IsSaaS() {
|
||||
|
||||
@@ -117,14 +117,6 @@ func resolveWorkspaceRootPath(runtime, root string) string {
|
||||
// EIC misconfiguration.
|
||||
const eicFileOpTimeout = 30 * time.Second
|
||||
|
||||
// eicFileOpTimeout was historically named eicFileWriteTimeout when the
|
||||
// only EIC op was writeFile. Keep an alias so any external test that
|
||||
// pinned the old name still compiles; rename can land as a follow-up
|
||||
// once we've gone a release without the alias being touched.
|
||||
//
|
||||
//nolint:revive // intentional alias for back-compat with prior tests.
|
||||
const eicFileWriteTimeout = eicFileOpTimeout
|
||||
|
||||
// eicSSHSession describes an open EIC tunnel ready for an ssh subprocess.
|
||||
// Only valid inside the closure passed to withEICTunnel — the underlying
|
||||
// keypair + tunnel are torn down when the closure returns.
|
||||
|
||||
@@ -88,7 +88,7 @@ func generateDefaultConfig(name string, files map[string]string, tier int) strin
|
||||
tier = 3
|
||||
}
|
||||
cfg.WriteString("version: 1.0.0\n")
|
||||
cfg.WriteString(fmt.Sprintf("tier: %d\n", tier))
|
||||
fmt.Fprintf(&cfg, "tier: %d\n", tier)
|
||||
cfg.WriteString("model: anthropic:claude-haiku-4-5-20251001\n")
|
||||
cfg.WriteString("\nprompt_files:\n")
|
||||
if len(promptFiles) > 0 {
|
||||
|
||||
@@ -278,7 +278,7 @@ func (h *TemplatesHandler) ListFiles(c *gin.Context) {
|
||||
// 1:1, but Go can't implicit-convert named struct types).
|
||||
out := make([]fileEntry, 0, len(entries))
|
||||
for _, e := range entries {
|
||||
out = append(out, fileEntry{Path: e.Path, Size: e.Size, Dir: e.Dir})
|
||||
out = append(out, fileEntry(e))
|
||||
}
|
||||
c.JSON(http.StatusOK, out)
|
||||
return
|
||||
@@ -373,9 +373,7 @@ func (h *TemplatesHandler) ListFiles(c *gin.Context) {
|
||||
func (h *TemplatesHandler) ReadFile(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
filePath := c.Param("path")
|
||||
if strings.HasPrefix(filePath, "/") {
|
||||
filePath = filePath[1:]
|
||||
}
|
||||
filePath = strings.TrimPrefix(filePath, "/")
|
||||
|
||||
if err := validateRelPath(filePath); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid path"})
|
||||
@@ -480,9 +478,7 @@ func (h *TemplatesHandler) ReadFile(c *gin.Context) {
|
||||
func (h *TemplatesHandler) WriteFile(c *gin.Context) {
|
||||
workspaceID := c.Param("id")
|
||||
filePath := c.Param("path")
|
||||
if strings.HasPrefix(filePath, "/") {
|
||||
filePath = filePath[1:]
|
||||
}
|
||||
filePath = strings.TrimPrefix(filePath, "/")
|
||||
|
||||
if err := validateRelPath(filePath); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid path"})
|
||||
@@ -636,4 +632,3 @@ func (h *TemplatesHandler) DeleteFile(c *gin.Context) {
|
||||
go h.wh.RestartByID(workspaceID)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -63,13 +63,6 @@ const workspacesUniqueIndexName = "workspaces_parent_name_uniq"
|
||||
// Conflict — the user must rename and re-try.
|
||||
var errWorkspaceNameExhausted = errors.New("workspace name exhausted: too many duplicates of base name under same parent")
|
||||
|
||||
// dbExec is the minimum surface our retry helper needs from
|
||||
// *sql.Tx (or *sql.DB). Declared as an interface so tests can
|
||||
// substitute a fake without standing up a real DB connection.
|
||||
type dbExec interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
// insertWorkspaceWithNameRetry runs the workspace INSERT and, if it
|
||||
// hits the parent-name unique-violation, retries with a suffixed
|
||||
// name. Returns the name actually persisted (which the caller MUST
|
||||
|
||||
@@ -109,21 +109,6 @@ func (h *WorkspaceHandler) State(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// sensitiveUpdateFields documents fields that carry elevated risk — kept as
|
||||
// an explicit list for code readability and future audits. Auth is now fully
|
||||
// enforced at the router layer (WorkspaceAuth middleware, #680 IDOR fix);
|
||||
// this map is no longer used for in-handler gate logic but is preserved to
|
||||
// surface the risk classification clearly.
|
||||
//
|
||||
// budget_limit is intentionally NOT here — the dedicated PATCH
|
||||
// /workspaces/:id/budget (AdminAuth) is the only write path (#611).
|
||||
var sensitiveUpdateFields = map[string]struct{}{
|
||||
"tier": {},
|
||||
"parent_id": {},
|
||||
"runtime": {},
|
||||
"workspace_dir": {},
|
||||
}
|
||||
|
||||
// Update handles PATCH /workspaces/:id
|
||||
func (h *WorkspaceHandler) Update(c *gin.Context) {
|
||||
id := c.Param("id")
|
||||
|
||||
@@ -156,10 +156,7 @@ func TestProvisionWorkspaceAuto_RoutesToCPWhenSet(t *testing.T) {
|
||||
|
||||
// Wait for the goroutine to land in cpProv.Start (or give up).
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
if len(rec.startedSnapshot()) > 0 {
|
||||
break
|
||||
}
|
||||
for len(rec.startedSnapshot()) == 0 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("timed out waiting for cpProv.Start; recorded=%v", rec.startedSnapshot())
|
||||
}
|
||||
@@ -626,10 +623,7 @@ func TestRestartWorkspaceAuto_RoutesToCPWhenSet(t *testing.T) {
|
||||
// the tracking stub, so we expect at least one Stop and (eventually)
|
||||
// at least one Start.
|
||||
deadline := time.Now().Add(2 * time.Second)
|
||||
for {
|
||||
if len(rec.stoppedSnapshot()) > 0 && len(rec.startedSnapshot()) > 0 {
|
||||
break
|
||||
}
|
||||
for len(rec.stoppedSnapshot()) == 0 || len(rec.startedSnapshot()) == 0 {
|
||||
if time.Now().After(deadline) {
|
||||
t.Fatalf("timed out waiting for cpProv.Stop + cpProv.Start; stopped=%v started=%v",
|
||||
rec.stoppedSnapshot(), rec.startedSnapshot())
|
||||
@@ -907,7 +901,7 @@ func stripGoComments(src []byte) []byte {
|
||||
// Block comment
|
||||
if i+1 < len(src) && src[i] == '/' && src[i+1] == '*' {
|
||||
i += 2
|
||||
for i+1 < len(src) && !(src[i] == '*' && src[i+1] == '/') {
|
||||
for i+1 < len(src) && (src[i] != '*' || src[i+1] != '/') {
|
||||
i++
|
||||
}
|
||||
i++ // skip closing /
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/models"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/plugins"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner"
|
||||
"github.com/Molecule-AI/molecule-monorepo/platform/pkg/provisionhook"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
@@ -49,7 +48,7 @@ func TestConfigDirName(t *testing.T) {
|
||||
{"abc-def-ghi", "ws-abc-def-ghi"},
|
||||
{"abcdefghijklmnop", "ws-abcdefghijkl"}, // truncated at 12
|
||||
{"short", "ws-short"},
|
||||
{"123456789012", "ws-123456789012"}, // exactly 12
|
||||
{"123456789012", "ws-123456789012"}, // exactly 12
|
||||
{"1234567890123", "ws-123456789012"}, // 13 chars, truncated
|
||||
}
|
||||
|
||||
@@ -483,11 +482,11 @@ func TestSanitizeRuntime_Allowlist(t *testing.T) {
|
||||
{"openclaw", "openclaw"},
|
||||
{"hermes", "hermes"},
|
||||
{"codex", "codex"},
|
||||
{"langgraph", "claude-code"}, // deprecated → default
|
||||
{"deepagents", "claude-code"}, // deprecated → default
|
||||
{"crewai", "claude-code"}, // deprecated → default
|
||||
{"autogen", "claude-code"}, // deprecated → default
|
||||
{"not-a-runtime", "claude-code"}, // unknown → default
|
||||
{"langgraph", "claude-code"}, // deprecated → default
|
||||
{"deepagents", "claude-code"}, // deprecated → default
|
||||
{"crewai", "claude-code"}, // deprecated → default
|
||||
{"autogen", "claude-code"}, // deprecated → default
|
||||
{"not-a-runtime", "claude-code"}, // unknown → default
|
||||
{"../../sensitive", "claude-code"}, // path traversal probe → default
|
||||
{"langgraph\nevil", "claude-code"}, // newline injection → default (not in allowlist)
|
||||
}
|
||||
@@ -533,7 +532,7 @@ func TestSeedInitialMemories_TruncatesOversizedContent(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "well under limit — passes through unchanged",
|
||||
contentLen: 50_000,
|
||||
contentLen: 50_000,
|
||||
expectInsert: true,
|
||||
},
|
||||
}
|
||||
@@ -1008,13 +1007,6 @@ func TestSeedInitialMemories_OversizedWithSecrets(t *testing.T) {
|
||||
// Each test injects a known-internal error and verifies the response body
|
||||
// or broadcast payload contains ONLY the generic prod-safe message.
|
||||
|
||||
// errInternalDB is a pkg-level error whose .Error() output matches a real
|
||||
// postgres driver error shape — used to simulate DB failure without a live DB.
|
||||
var errInternalDB = fmt.Errorf("pq: connection refused")
|
||||
|
||||
// errInternalOS simulates an OS-level error.
|
||||
var errInternalOS = fmt.Errorf("operation failed: no such file or directory")
|
||||
|
||||
// captureBroadcaster is a test broadcaster that captures the last data
|
||||
// payload passed to RecordAndBroadcast so tests can inspect it. Now
|
||||
// satisfies events.EventEmitter (#1814) directly — RecordAndBroadcast
|
||||
@@ -1022,7 +1014,6 @@ var errInternalOS = fmt.Errorf("operation failed: no such file or directory")
|
||||
// WorkspaceHandler paths under test call it.
|
||||
type captureBroadcaster struct {
|
||||
lastData map[string]interface{}
|
||||
lastErr error
|
||||
}
|
||||
|
||||
// BroadcastOnly is required to satisfy events.EventEmitter. None of the
|
||||
@@ -1042,46 +1033,6 @@ func (c *captureBroadcaster) RecordAndBroadcast(_ context.Context, _, _ string,
|
||||
return nil
|
||||
}
|
||||
|
||||
// unsafeErrorStrings lists substrings that must NEVER appear in external-facing
|
||||
// error responses. Covers DB driver errors, OS errors, and internal paths.
|
||||
var unsafeErrorStrings = []string{
|
||||
"pq:",
|
||||
"pq ",
|
||||
"connection refused",
|
||||
"deadlock",
|
||||
"no such file",
|
||||
"/var/",
|
||||
"/tmp/",
|
||||
"postgres",
|
||||
"PostgreSQL",
|
||||
"sql: ",
|
||||
":8080",
|
||||
"127.0.0.1",
|
||||
"localhost",
|
||||
"secret",
|
||||
"token",
|
||||
}
|
||||
|
||||
// containsUnsafeString checks whether any prohibited substring appears in
|
||||
// a string value recursively (handles nested maps for safety).
|
||||
func containsUnsafeString(v interface{}) bool {
|
||||
switch v := v.(type) {
|
||||
case string:
|
||||
for _, unsafe := range unsafeErrorStrings {
|
||||
if strings.Contains(v, unsafe) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
case map[string]interface{}:
|
||||
for _, val := range v {
|
||||
if containsUnsafeString(val) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// TestProvisionWorkspace_NoInternalErrorsInBroadcast asserts that provisionWorkspace
|
||||
// never leaks internal error details in WORKSPACE_PROVISION_FAILED broadcasts.
|
||||
// Regression test for issue #1206 — drives the global-secrets decrypt-fail
|
||||
@@ -1251,12 +1202,12 @@ func TestProvisionWorkspaceCP_NoInternalErrorsInBroadcast(t *testing.T) {
|
||||
continue
|
||||
}
|
||||
for _, leakMarker := range []string{
|
||||
"t3.large", // machine type
|
||||
"ami-0abcd1234efgh5678", // AMI id
|
||||
"vpc-deadbeef", // VPC id
|
||||
"subnet-cafef00d", // subnet id
|
||||
"InvalidSubnet.Conflict", // raw upstream HTTP body
|
||||
"CP API rejected", // raw error string head
|
||||
"t3.large", // machine type
|
||||
"ami-0abcd1234efgh5678", // AMI id
|
||||
"vpc-deadbeef", // VPC id
|
||||
"subnet-cafef00d", // subnet id
|
||||
"InvalidSubnet.Conflict", // raw upstream HTTP body
|
||||
"CP API rejected", // raw error string head
|
||||
} {
|
||||
if strings.Contains(s, leakMarker) {
|
||||
t.Errorf("broadcast leaked %q in payload value %q", leakMarker, s)
|
||||
@@ -1268,17 +1219,6 @@ func TestProvisionWorkspaceCP_NoInternalErrorsInBroadcast(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// mockEnvMutator is a provisionhook.Registry stub that always returns a fixed error.
|
||||
type mockEnvMutator struct {
|
||||
returnErr error
|
||||
}
|
||||
|
||||
func (m *mockEnvMutator) Run(_ context.Context, _ string, _ map[string]string) error {
|
||||
return m.returnErr
|
||||
}
|
||||
|
||||
func (m *mockEnvMutator) Register(_ provisionhook.EnvMutator) {}
|
||||
|
||||
// TestResolveAndStage_NoInternalErrorsInHTTPErr asserts that
|
||||
// resolveAndStage never puts internal error detail (resolver error
|
||||
// strings, file-system paths, upstream rate-limit text, auth tokens
|
||||
|
||||
@@ -794,6 +794,7 @@ func TestDoJSON_204OnEndpointExpectingBody(t *testing.T) {
|
||||
}
|
||||
if got == nil {
|
||||
t.Error("got nil SearchResponse, want zero value")
|
||||
return
|
||||
}
|
||||
if len(got.Memories) != 0 {
|
||||
t.Errorf("memories = %v, want empty", got.Memories)
|
||||
|
||||
@@ -109,7 +109,7 @@ func (p *flatPlugin) handleNamespace(w http.ResponseWriter, r *http.Request) {
|
||||
p.mu.Unlock()
|
||||
w.WriteHeader(204)
|
||||
default:
|
||||
http.Error(w, "method not allowed", 405)
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -22,14 +22,7 @@ const chainQuerySnippet = "WITH RECURSIVE chain"
|
||||
// Helper makes per-test mock setup terser.
|
||||
func setupMockDB(t *testing.T) (*sql.DB, sqlmock.Sqlmock) {
|
||||
t.Helper()
|
||||
db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock new: %v", err)
|
||||
}
|
||||
t.Cleanup(func() { _ = db.Close() })
|
||||
// We use QueryMatcherEqual but with regex-based ExpectQuery elsewhere
|
||||
// for flexibility. Actually swap to regex for the recursive query:
|
||||
db, mock, err = sqlmock.New() // default = regex
|
||||
db, mock, err := sqlmock.New() // default = regex
|
||||
if err != nil {
|
||||
t.Fatalf("sqlmock new: %v", err)
|
||||
}
|
||||
@@ -186,8 +179,8 @@ func TestWalkChain_RowsErr(t *testing.T) {
|
||||
|
||||
func TestDerive(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
chain []chainNode
|
||||
name string
|
||||
chain []chainNode
|
||||
wantWS, wantTeam, wantOrg string
|
||||
}{
|
||||
{
|
||||
|
||||
@@ -80,7 +80,6 @@ func (s *Store) PatchNamespace(ctx context.Context, name string, body contract.N
|
||||
}
|
||||
parts = append(parts, fmt.Sprintf("metadata = $%d", idx))
|
||||
args = append(args, metadata)
|
||||
idx++
|
||||
}
|
||||
query := fmt.Sprintf(`
|
||||
UPDATE memory_namespaces SET %s
|
||||
@@ -294,7 +293,9 @@ func (s *Store) Search(ctx context.Context, body contract.SearchRequest) (*contr
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func scanNamespace(row interface{ Scan(dest ...interface{}) error }) (*contract.Namespace, error) {
|
||||
func scanNamespace(row interface {
|
||||
Scan(dest ...interface{}) error
|
||||
}) (*contract.Namespace, error) {
|
||||
var ns contract.Namespace
|
||||
var kindStr string
|
||||
var expires sql.NullTime
|
||||
@@ -315,7 +316,9 @@ func scanNamespace(row interface{ Scan(dest ...interface{}) error }) (*contract.
|
||||
return &ns, nil
|
||||
}
|
||||
|
||||
func scanMemory(row interface{ Scan(dest ...interface{}) error }) (*contract.Memory, error) {
|
||||
func scanMemory(row interface {
|
||||
Scan(dest ...interface{}) error
|
||||
}) (*contract.Memory, error) {
|
||||
var m contract.Memory
|
||||
var kindStr, sourceStr string
|
||||
var expires sql.NullTime
|
||||
@@ -375,7 +378,7 @@ func vectorString(v []float32) string {
|
||||
if i > 0 {
|
||||
b.WriteByte(',')
|
||||
}
|
||||
b.WriteString(fmt.Sprintf("%g", x))
|
||||
fmt.Fprintf(&b, "%g", x)
|
||||
}
|
||||
b.WriteByte(']')
|
||||
return b.String()
|
||||
|
||||
@@ -120,7 +120,6 @@ func WorkspaceAuth(database *sql.DB) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -325,7 +324,6 @@ func CanvasOrBearer(database *sql.DB) gin.HandlerFunc {
|
||||
}
|
||||
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "admin auth required"})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -37,16 +37,6 @@ const validateAnyTokenSelectQuery = "SELECT t\\.id, t\\.workspace_id.*FROM works
|
||||
// validateTokenUpdateQuery is matched for the best-effort last_used_at UPDATE.
|
||||
const validateTokenUpdateQuery = "UPDATE workspace_auth_tokens SET last_used_at"
|
||||
|
||||
// newWorkspaceAuthRouter builds a minimal gin router that applies WorkspaceAuth
|
||||
// to a single GET /workspaces/:id/test route, returning 200 on success.
|
||||
func newWorkspaceAuthRouter(db sqlmock.Sqlmock, realDB interface{ Close() error }) *gin.Engine {
|
||||
_ = db // unused directly; sqlmock intercepts calls via the *sql.DB pointer
|
||||
r := gin.New()
|
||||
// We need the *sql.DB, not the mock. The caller passes mockDB via the
|
||||
// test-local var — this helper is only used to build the router topology.
|
||||
return r
|
||||
}
|
||||
|
||||
// TestWorkspaceAuth_351_NoBearer_Returns401 — strict contract: every request
|
||||
// under /workspaces/:id/* must carry a valid bearer, period. No fail-open,
|
||||
// no grace period, no existence check. The middleware goes straight to
|
||||
@@ -483,10 +473,6 @@ func TestAdminAuth_InvalidBearer_Returns401(t *testing.T) {
|
||||
// (no ::text cast — sql.NullString handles the NULL scan natively).
|
||||
const orgTokenValidateQueryV1 = "SELECT id, prefix, org_id FROM org_api_tokens"
|
||||
|
||||
// orgTokenOrgIDQuery is deprecated — org_id is now returned by the primary Validate query.
|
||||
// Kept here to avoid breaking other test files that may reference it.
|
||||
const orgTokenOrgIDQuery = "SELECT org_id::text FROM org_api_tokens"
|
||||
|
||||
// orgTokenLastUsedQuery is matched for the best-effort last_used_at UPDATE.
|
||||
const orgTokenLastUsedQuery = "UPDATE org_api_tokens SET last_used_at"
|
||||
|
||||
@@ -495,10 +481,10 @@ const orgTokenLastUsedQuery = "UPDATE org_api_tokens SET last_used_at"
|
||||
// and orgCallerID can look it up downstream.
|
||||
func TestAdminAuth_OrgToken_SetsOrgID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
orgIDFromDB interface{} // sqlmock row value: nil, "", or "ws-org-1"
|
||||
wantOrgIDCtx bool // expect c.Get("org_id") to be set
|
||||
wantOrgIDVal string // if set, expected value
|
||||
name string
|
||||
orgIDFromDB interface{} // sqlmock row value: nil, "", or "ws-org-1"
|
||||
wantOrgIDCtx bool // expect c.Get("org_id") to be set
|
||||
wantOrgIDVal string // if set, expected value
|
||||
}{
|
||||
{
|
||||
name: "post-fix token has org_id set in context",
|
||||
|
||||
@@ -3,6 +3,8 @@ package plugins
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os"
|
||||
"os/exec"
|
||||
"testing"
|
||||
)
|
||||
|
||||
@@ -64,31 +66,6 @@ func TestResolveRef_MapsNotFoundToErrPluginNotFound(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
// stubGitForResolveRef creates a stub that handles fetch + rev-parse for ResolveRef.
|
||||
func stubGitForResolveRef(t *testing.T, sha string) func(ctx context.Context, dir string, args ...string) error {
|
||||
return func(ctx context.Context, dir string, args ...string) error {
|
||||
if ctx.Err() != nil {
|
||||
return ctx.Err()
|
||||
}
|
||||
if len(args) < 1 {
|
||||
return errors.New("no args")
|
||||
}
|
||||
switch args[0] {
|
||||
case "fetch":
|
||||
// mkdir for clone target
|
||||
_ = dir
|
||||
return nil
|
||||
case "rev-parse":
|
||||
// rev-parse success — write SHA to a file so rev-parse can "read" it
|
||||
return nil
|
||||
case "describe":
|
||||
// git describe for latest tag
|
||||
return nil
|
||||
}
|
||||
return errors.New("unexpected git command: " + args[0])
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveRef_SucceedsForTagRef(t *testing.T) {
|
||||
// This test verifies the happy path: fetch + rev-parse succeed.
|
||||
// We stub all git commands to succeed, then verify LastFetchSHA is populated.
|
||||
@@ -99,18 +76,43 @@ func TestResolveRef_SucceedsForTagRef(t *testing.T) {
|
||||
return ctx.Err()
|
||||
}
|
||||
calls[args[0]] = true
|
||||
if args[0] == "fetch" {
|
||||
run := func(name string, args ...string) error {
|
||||
cmd := exec.CommandContext(ctx, name, args...)
|
||||
cmd.Dir = dir
|
||||
cmd.Env = append(os.Environ(),
|
||||
"GIT_AUTHOR_NAME=test",
|
||||
"GIT_AUTHOR_EMAIL=test@example.invalid",
|
||||
"GIT_COMMITTER_NAME=test",
|
||||
"GIT_COMMITTER_EMAIL=test@example.invalid",
|
||||
)
|
||||
return cmd.Run()
|
||||
}
|
||||
if err := run("git", "init"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := os.WriteFile(dir+"/README.md", []byte("test\n"), 0o644); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := run("git", "add", "README.md"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := run("git", "commit", "-m", "test"); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := run("git", "tag", "v1.0.0"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
_, err := r.ResolveRef(context.Background(), "org/repo#tag:v1.0.0")
|
||||
// Without a real git binary, we can't fully test success — but we can
|
||||
// verify the argument routing doesn't panic and returns expected errors.
|
||||
if err != nil && !errors.Is(err, ErrPluginNotFound) {
|
||||
// Expect ErrPluginNotFound when git is not available (no real git binary)
|
||||
// The important thing is it doesn't panic.
|
||||
if err != nil {
|
||||
t.Fatalf("ResolveRef returned unexpected error: %v", err)
|
||||
}
|
||||
if !calls["fetch"] && !calls["rev-parse"] {
|
||||
// At least one git command should have been called
|
||||
t.Fatal("expected at least one git command")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -149,7 +151,7 @@ func TestPluginUpdateQueueRow_Struct(t *testing.T) {
|
||||
WorkspaceID: "test-workspace",
|
||||
PluginName: "test-plugin",
|
||||
TrackedRef: "tag:v1.0.0",
|
||||
CurrentSHA: "abc123",
|
||||
CurrentSHA: "abc123",
|
||||
LatestSHA: "def456",
|
||||
Status: "pending",
|
||||
}
|
||||
|
||||
@@ -57,11 +57,11 @@ func (r *GithubResolver) Scheme() string { return "github" }
|
||||
// - Owner / repo: must start with alphanumeric, then 0–99 chars from
|
||||
// [a-zA-Z0-9_.-]. Matches GitHub's validation.
|
||||
// - Ref: must NOT start with `-` (prevents ref-as-flag injection like
|
||||
// "-exec=/evil"). Then 0–254 chars from [a-zA-Z0-9_./-]. Disallows
|
||||
// "-exec=/evil"). Then 0–254 chars from [a-zA-Z0-9_./:-]. Disallows
|
||||
// whitespace and shell metacharacters. The handler additionally
|
||||
// passes `--` before the URL when invoking git, for defense in depth.
|
||||
var repoRE = regexp.MustCompile(
|
||||
`^([a-zA-Z0-9][a-zA-Z0-9_.\-]{0,99})/([a-zA-Z0-9][a-zA-Z0-9_.\-]{0,99})(?:#([a-zA-Z0-9_.][a-zA-Z0-9_./\-]{0,254}))?$`,
|
||||
`^([a-zA-Z0-9][a-zA-Z0-9_.\-]{0,99})/([a-zA-Z0-9][a-zA-Z0-9_.\-]{0,99})(?:#([a-zA-Z0-9_.][a-zA-Z0-9_./:\-]{0,254}))?$`,
|
||||
)
|
||||
|
||||
// Fetch clones the repository and copies its contents (minus .git) into dst.
|
||||
|
||||
@@ -31,7 +31,6 @@ import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
@@ -104,8 +103,8 @@ func writeManifestJSON(t *testing.T, dir, digest string) {
|
||||
func writeStagedPlugin(t *testing.T, dir string) {
|
||||
t.Helper()
|
||||
files := map[string]string{
|
||||
"plugin.yaml": "name: test-plugin\nversion: 1.0.0\ndescription: supply chain test\n",
|
||||
"rules/guidelines.md": "# Plugin Guidelines\nFollow the rules.\n",
|
||||
"plugin.yaml": "name: test-plugin\nversion: 1.0.0\ndescription: supply chain test\n",
|
||||
"rules/guidelines.md": "# Plugin Guidelines\nFollow the rules.\n",
|
||||
"skills/helper/SKILL.md": "---\nid: helper\nname: Helper\ndescription: does stuff\n---\n",
|
||||
}
|
||||
for relPath, content := range files {
|
||||
@@ -119,19 +118,6 @@ func writeStagedPlugin(t *testing.T, dir string) {
|
||||
}
|
||||
}
|
||||
|
||||
// stubGitSuccess returns a GitRunner that creates the target directory and
|
||||
// returns nil (simulating a successful shallow clone). Does NOT write any
|
||||
// repo content — tests that need files should write them into dst separately.
|
||||
func stubGitSuccess() func(ctx context.Context, dir string, args ...string) error {
|
||||
return func(ctx context.Context, dir string, args ...string) error {
|
||||
if len(args) == 0 {
|
||||
return fmt.Errorf("stubGitSuccess: no args")
|
||||
}
|
||||
target := args[len(args)-1]
|
||||
return os.MkdirAll(target, 0o755)
|
||||
}
|
||||
}
|
||||
|
||||
// ──────────────────────────────────────────────────────────────────────────────
|
||||
// SHA256 content-integrity tests (#768 Control 1)
|
||||
//
|
||||
|
||||
@@ -445,16 +445,16 @@ func parseGiteaBranchHeadSha(body []byte) (string, error) {
|
||||
// Look for `"id":"<40-hex>"` inside the commit object.
|
||||
idx := strings.Index(string(body), `"id":"`)
|
||||
if idx < 0 {
|
||||
return "", errors.New("Gitea branch response missing commit.id field")
|
||||
return "", errors.New("gitea branch response missing commit.id field")
|
||||
}
|
||||
rest := string(body[idx+len(`"id":"`):])
|
||||
end := strings.IndexByte(rest, '"')
|
||||
if end < 0 {
|
||||
return "", errors.New("Gitea branch response has malformed commit.id (no closing quote)")
|
||||
return "", errors.New("gitea branch response has malformed commit.id (no closing quote)")
|
||||
}
|
||||
sha := rest[:end]
|
||||
if len(sha) < 7 {
|
||||
return "", fmt.Errorf("Gitea returned suspiciously short sha %q", sha)
|
||||
return "", fmt.Errorf("gitea returned suspiciously short sha %q", sha)
|
||||
}
|
||||
return sha, nil
|
||||
}
|
||||
|
||||
@@ -442,7 +442,7 @@ func (p *Provisioner) Start(ctx context.Context, cfg WorkspaceConfig) (string, e
|
||||
// contents are by definition immutable.
|
||||
// The pull is best-effort: if it fails (network, auth, rate limit) the
|
||||
// subsequent ContainerCreate still surfaces the actionable error below.
|
||||
imgInspect, _, imgErr := p.cli.ImageInspectWithRaw(ctx, image)
|
||||
imgInspect, imgErr := p.cli.ImageInspect(ctx, image)
|
||||
moving := imageTagIsMoving(image)
|
||||
switch {
|
||||
case imgErr != nil:
|
||||
@@ -541,12 +541,12 @@ func (p *Provisioner) Start(ctx context.Context, cfg WorkspaceConfig) (string, e
|
||||
//
|
||||
// Selection matrix:
|
||||
//
|
||||
// cfg.WorkspacePath | cfg.WorkspaceAccess | mount
|
||||
// ------------------+-------------------------+--------------------------------
|
||||
// "" | "" / "none" | <named-volume>:/workspace (isolated, current default)
|
||||
// "<host-dir>" | "" / "read_write" | <host-dir>:/workspace (current PM behaviour)
|
||||
// "<host-dir>" | "read_only" | <host-dir>:/workspace:ro (research agents get read access without write risk)
|
||||
// "" | "read_only"/"read_write"| <named-volume>:/workspace (degraded — access requires a mount; validated at handler layer)
|
||||
// cfg.WorkspacePath | cfg.WorkspaceAccess | mount
|
||||
// ------------------+-------------------------+--------------------------------
|
||||
// "" | "" / "none" | <named-volume>:/workspace (isolated, current default)
|
||||
// "<host-dir>" | "" / "read_write" | <host-dir>:/workspace (current PM behaviour)
|
||||
// "<host-dir>" | "read_only" | <host-dir>:/workspace:ro (research agents get read access without write risk)
|
||||
// "" | "read_only"/"read_write"| <named-volume>:/workspace (degraded — access requires a mount; validated at handler layer)
|
||||
//
|
||||
// Kept pure + side-effect-free so it's unit-testable.
|
||||
func buildWorkspaceMount(cfg WorkspaceConfig) string {
|
||||
@@ -700,11 +700,11 @@ func applyTierResources(hostCfg *container.HostConfig, tier int) (memMB, cpuShar
|
||||
memMB = getTierMemoryMB(tier)
|
||||
cpuShares = getTierCPUShares(tier)
|
||||
if memMB > 0 {
|
||||
hostCfg.Resources.Memory = memMB * 1024 * 1024
|
||||
hostCfg.Memory = memMB * 1024 * 1024
|
||||
}
|
||||
if cpuShares > 0 {
|
||||
// shares -> NanoCPUs: 1024 shares == 1 CPU == 1e9 NanoCPUs
|
||||
hostCfg.Resources.NanoCPUs = (cpuShares * 1_000_000_000) / 1024
|
||||
hostCfg.NanoCPUs = (cpuShares * 1_000_000_000) / 1024
|
||||
}
|
||||
return memMB, cpuShares
|
||||
}
|
||||
@@ -1000,20 +1000,6 @@ func (p *Provisioner) WriteAuthTokenToVolume(ctx context.Context, workspaceID, t
|
||||
return nil
|
||||
}
|
||||
|
||||
// execInContainer runs a command inside a running container as root.
|
||||
// Best-effort: logs errors but does not fail the caller.
|
||||
func (p *Provisioner) execInContainer(ctx context.Context, containerID string, cmd []string) {
|
||||
execCfg := container.ExecOptions{Cmd: cmd, User: "root"}
|
||||
execID, err := p.cli.ContainerExecCreate(ctx, containerID, execCfg)
|
||||
if err != nil {
|
||||
log.Printf("Provisioner: exec create failed: %v", err)
|
||||
return
|
||||
}
|
||||
if err := p.cli.ContainerExecStart(ctx, execID.ID, container.ExecStartOptions{}); err != nil {
|
||||
log.Printf("Provisioner: exec start failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveVolume removes the config volume for a workspace.
|
||||
// Also removes the claude-sessions volume (best-effort, may not exist
|
||||
// for non claude-code runtimes). Issue #12.
|
||||
@@ -1127,12 +1113,12 @@ func (p *Provisioner) IsRunning(ctx context.Context, workspaceID string) (bool,
|
||||
//
|
||||
// - ("ws-<id>", nil): container is running. Caller can exec into it.
|
||||
// - ("", nil): container does not exist OR exists but is stopped
|
||||
// (NotFound, Exited, Created, Restarting…). Caller
|
||||
// should treat as a definitive "not running."
|
||||
// (NotFound, Exited, Created, Restarting…). Caller
|
||||
// should treat as a definitive "not running."
|
||||
// - ("", err): transient daemon error (timeout, socket EOF, ctx
|
||||
// cancel). Caller should NOT infer "not running" —
|
||||
// this could be a flaky daemon under load. Decide
|
||||
// per-callsite whether to fail soft or hard.
|
||||
// cancel). Caller should NOT infer "not running" —
|
||||
// this could be a flaky daemon under load. Decide
|
||||
// per-callsite whether to fail soft or hard.
|
||||
//
|
||||
// Background — molecule-core#10: the plugins handler used to carry its own
|
||||
// copy of this inspect logic (`findRunningContainer`) which collapsed
|
||||
|
||||
@@ -155,14 +155,14 @@ func TestApplyTierConfig_Tier2_Standard(t *testing.T) {
|
||||
|
||||
// Memory limit: 512 MiB
|
||||
expectedMemory := int64(512 * 1024 * 1024)
|
||||
if hc.Resources.Memory != expectedMemory {
|
||||
t.Errorf("T2: expected Memory=%d (512m), got %d", expectedMemory, hc.Resources.Memory)
|
||||
if hc.Memory != expectedMemory {
|
||||
t.Errorf("T2: expected Memory=%d (512m), got %d", expectedMemory, hc.Memory)
|
||||
}
|
||||
|
||||
// CPU limit: 1.0 CPU (1e9 NanoCPUs)
|
||||
expectedCPU := int64(1_000_000_000)
|
||||
if hc.Resources.NanoCPUs != expectedCPU {
|
||||
t.Errorf("T2: expected NanoCPUs=%d (1.0 CPU), got %d", expectedCPU, hc.Resources.NanoCPUs)
|
||||
if hc.NanoCPUs != expectedCPU {
|
||||
t.Errorf("T2: expected NanoCPUs=%d (1.0 CPU), got %d", expectedCPU, hc.NanoCPUs)
|
||||
}
|
||||
|
||||
// Must NOT be privileged
|
||||
@@ -270,13 +270,13 @@ func TestApplyTierConfig_UnknownTier_DefaultsToT2(t *testing.T) {
|
||||
|
||||
// Unknown tiers should get T2 resource limits as a safe default
|
||||
expectedMemory := int64(512 * 1024 * 1024)
|
||||
if hc.Resources.Memory != expectedMemory {
|
||||
t.Errorf("Unknown tier: expected Memory=%d (512m), got %d", expectedMemory, hc.Resources.Memory)
|
||||
if hc.Memory != expectedMemory {
|
||||
t.Errorf("Unknown tier: expected Memory=%d (512m), got %d", expectedMemory, hc.Memory)
|
||||
}
|
||||
|
||||
expectedCPU := int64(1_000_000_000)
|
||||
if hc.Resources.NanoCPUs != expectedCPU {
|
||||
t.Errorf("Unknown tier: expected NanoCPUs=%d (1.0 CPU), got %d", expectedCPU, hc.Resources.NanoCPUs)
|
||||
if hc.NanoCPUs != expectedCPU {
|
||||
t.Errorf("Unknown tier: expected NanoCPUs=%d (1.0 CPU), got %d", expectedCPU, hc.NanoCPUs)
|
||||
}
|
||||
|
||||
// Must NOT be privileged
|
||||
@@ -298,8 +298,8 @@ func TestApplyTierConfig_ZeroTier_DefaultsToT2(t *testing.T) {
|
||||
|
||||
// Zero tier (default int value) should also get T2 resource limits
|
||||
expectedMemory := int64(512 * 1024 * 1024)
|
||||
if hc.Resources.Memory != expectedMemory {
|
||||
t.Errorf("Tier 0: expected Memory=%d, got %d", expectedMemory, hc.Resources.Memory)
|
||||
if hc.Memory != expectedMemory {
|
||||
t.Errorf("Tier 0: expected Memory=%d, got %d", expectedMemory, hc.Memory)
|
||||
}
|
||||
if hc.Privileged {
|
||||
t.Error("Tier 0: must not be privileged")
|
||||
@@ -944,12 +944,12 @@ func TestApplyTierConfig_T3_UsesEnvOverride(t *testing.T) {
|
||||
ApplyTierConfig(hc, cfg, "ws-abc123-configs:/configs", "ws-abc123")
|
||||
|
||||
wantMem := int64(8192) * 1024 * 1024
|
||||
if hc.Resources.Memory != wantMem {
|
||||
t.Errorf("T3 memory override: got %d, want %d", hc.Resources.Memory, wantMem)
|
||||
if hc.Memory != wantMem {
|
||||
t.Errorf("T3 memory override: got %d, want %d", hc.Memory, wantMem)
|
||||
}
|
||||
wantCPU := int64(4_000_000_000)
|
||||
if hc.Resources.NanoCPUs != wantCPU {
|
||||
t.Errorf("T3 CPU override: got %d NanoCPUs, want %d", hc.Resources.NanoCPUs, wantCPU)
|
||||
if hc.NanoCPUs != wantCPU {
|
||||
t.Errorf("T3 CPU override: got %d NanoCPUs, want %d", hc.NanoCPUs, wantCPU)
|
||||
}
|
||||
if !hc.Privileged || hc.PidMode != "host" {
|
||||
t.Errorf("T3 override should preserve privileged/pid-host flags, got Privileged=%v PidMode=%q",
|
||||
@@ -968,11 +968,11 @@ func TestApplyTierConfig_T3_DefaultCap(t *testing.T) {
|
||||
ApplyTierConfig(hc, cfg, "ws-abc123-configs:/configs", "ws-abc123")
|
||||
|
||||
wantMem := int64(defaultTier3MemoryMB) * 1024 * 1024
|
||||
if hc.Resources.Memory != wantMem {
|
||||
t.Errorf("T3 default memory: got %d, want %d", hc.Resources.Memory, wantMem)
|
||||
if hc.Memory != wantMem {
|
||||
t.Errorf("T3 default memory: got %d, want %d", hc.Memory, wantMem)
|
||||
}
|
||||
wantCPU := int64(defaultTier3CPUShares) * 1_000_000_000 / 1024
|
||||
if hc.Resources.NanoCPUs != wantCPU {
|
||||
t.Errorf("T3 default NanoCPUs: got %d, want %d", hc.Resources.NanoCPUs, wantCPU)
|
||||
if hc.NanoCPUs != wantCPU {
|
||||
t.Errorf("T3 default NanoCPUs: got %d, want %d", hc.NanoCPUs, wantCPU)
|
||||
}
|
||||
}
|
||||
|
||||
+148
-7
@@ -12,12 +12,14 @@ Environment variables (set by the workspace container):
|
||||
PLATFORM_URL — platform API base URL (e.g. http://platform:8080)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import stat
|
||||
import sys
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
# Top-level (not inside main()) so the wheel rewriter expands this to
|
||||
@@ -765,24 +767,163 @@ async def main(): # pragma: no cover
|
||||
break
|
||||
|
||||
|
||||
def cli_main() -> None: # pragma: no cover
|
||||
"""Synchronous wrapper around the async MCP stdio loop.
|
||||
# --- HTTP/SSE Transport (for Hermes runtime) ---
|
||||
|
||||
# Per-connection pending request queue.
|
||||
# Maps connection-id → asyncio.Queue of JSON-RPC responses.
|
||||
_http_connection_queues: dict[str, asyncio.Queue] = {}
|
||||
_http_connection_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def _handle_http_mcp(request) -> dict | None:
|
||||
"""Handle an incoming JSON-RPC request over HTTP. Returns the JSON-RPC response dict, or None for notifications."""
|
||||
try:
|
||||
body = await request.json()
|
||||
except Exception:
|
||||
return {"jsonrpc": "2.0", "id": None, "error": {"code": -32700, "message": "Parse error"}}
|
||||
|
||||
req_id = body.get("id")
|
||||
method = body.get("method", "")
|
||||
|
||||
if method == "initialize":
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": _build_initialize_result(),
|
||||
}
|
||||
elif method == "notifications/initialized":
|
||||
return None # No response needed
|
||||
elif method == "tools/list":
|
||||
return {"jsonrpc": "2.0", "id": req_id, "result": {"tools": TOOLS}}
|
||||
elif method == "tools/call":
|
||||
params = body.get("params", {})
|
||||
tool_name = params.get("name", "")
|
||||
tool_args = params.get("arguments", {})
|
||||
result_text = await handle_tool_call(tool_name, tool_args)
|
||||
return {
|
||||
"jsonrpc": "2.0",
|
||||
"id": req_id,
|
||||
"result": {"content": [{"type": "text", "text": result_text}]},
|
||||
}
|
||||
else:
|
||||
return {"jsonrpc": "2.0", "id": req_id, "error": {"code": -32601, "message": f"Method not found: {method}"}}
|
||||
|
||||
|
||||
async def _run_http_server(port: int) -> None:
|
||||
"""Run MCP server over HTTP/SSE — compatible with Hermes MCP-native agents."""
|
||||
try:
|
||||
from starlette.applications import Starlette # noqa: F401
|
||||
from starlette.routing import Route # noqa: F401
|
||||
from starlette.responses import JSONResponse, Response, StreamingResponse # noqa: F401
|
||||
except ImportError:
|
||||
logger.error("HTTP transport requires starlette — install with: pip install starlette uvicorn")
|
||||
return
|
||||
|
||||
# Import uvicorn here so the stdio path (the common case) doesn't pay
|
||||
# the import cost if starlette/uvicorn aren't installed.
|
||||
import uvicorn # noqa: F401
|
||||
|
||||
_http_connection_queues.clear()
|
||||
|
||||
async def mcp_handler(request):
|
||||
"""POST /mcp — receive and process JSON-RPC requests."""
|
||||
conn_id = request.headers.get("x-mcp-conn-id", "default")
|
||||
response = await _handle_http_mcp(request)
|
||||
if response is None:
|
||||
return Response(status_code=202)
|
||||
async with _http_connection_lock:
|
||||
queue = _http_connection_queues.get(conn_id)
|
||||
if queue is not None and not queue.full():
|
||||
await queue.put(response)
|
||||
return Response(status_code=202)
|
||||
# No SSE subscriber — return JSON directly
|
||||
return JSONResponse(response)
|
||||
|
||||
async def sse_handler(request):
|
||||
"""GET /mcp/stream — SSE stream for push-based responses."""
|
||||
conn_id = str(uuid.uuid4())
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
|
||||
async with _http_connection_lock:
|
||||
_http_connection_queues[conn_id] = queue
|
||||
|
||||
async def event_stream():
|
||||
yield f"event: connected\ndata: {json.dumps({'conn_id': conn_id})}\n\n"
|
||||
try:
|
||||
while True:
|
||||
response = await asyncio.wait_for(queue.get(), timeout=300)
|
||||
yield f"event: message\ndata: {json.dumps(response)}\n\n"
|
||||
if queue.empty():
|
||||
yield "event: heartbeat\ndata: null\n\n"
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
async with _http_connection_lock:
|
||||
_http_connection_queues.pop(conn_id, None)
|
||||
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
async def health_handler(_request):
|
||||
return JSONResponse({"ok": True, "transport": "http+sse", "port": port})
|
||||
|
||||
app = Starlette(
|
||||
routes=[
|
||||
Route("/mcp", mcp_handler, methods=["POST"]),
|
||||
Route("/mcp/stream", sse_handler, methods=["GET"]),
|
||||
Route("/health", health_handler),
|
||||
]
|
||||
)
|
||||
config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning")
|
||||
server = uvicorn.Server(config)
|
||||
logger.info(f"A2A MCP HTTP server listening on http://127.0.0.1:{port}/mcp")
|
||||
await server.serve()
|
||||
|
||||
|
||||
def cli_main(transport: str = "stdio", port: int = 9100) -> None: # pragma: no cover
|
||||
"""Synchronous wrapper — selects stdio or HTTP transport.
|
||||
|
||||
Called by ``mcp_cli.main`` (the ``molecule-mcp`` console-script
|
||||
entry point in scripts/build_runtime_package.py) AFTER env
|
||||
validation and the standalone register + heartbeat thread setup.
|
||||
Direct callers (in-container code that already validated env and
|
||||
runs heartbeat.py separately) can also invoke this — it's the
|
||||
smallest possible "run the MCP stdio JSON-RPC loop" surface.
|
||||
runs heartbeat.py separately) can also invoke this.
|
||||
|
||||
Wheel-smoke gates in scripts/wheel_smoke.py pin the importability
|
||||
of this name (alongside ``mcp_cli.main``) so a silent rename can't
|
||||
break every external-runtime operator's MCP install — the 0.1.16
|
||||
``main_sync`` rename incident is the cautionary precedent.
|
||||
|
||||
Args:
|
||||
transport: "stdio" (default) or "http" (HTTP+SSE for Hermes).
|
||||
port: TCP port for HTTP transport (default 9100).
|
||||
"""
|
||||
_assert_stdio_is_pipe_compatible()
|
||||
asyncio.run(main())
|
||||
if transport == "http":
|
||||
asyncio.run(_run_http_server(port))
|
||||
else:
|
||||
_assert_stdio_is_pipe_compatible()
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
if __name__ == "__main__": # pragma: no cover
|
||||
cli_main()
|
||||
parser = argparse.ArgumentParser(description="A2A MCP Server")
|
||||
parser.add_argument(
|
||||
"--transport",
|
||||
default="stdio",
|
||||
choices=["stdio", "http"],
|
||||
help="Transport mode: stdio (default) or http (HTTP+SSE for Hermes)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=9100,
|
||||
help="TCP port for HTTP transport (default 9100)",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
cli_main(transport=args.transport, port=args.port)
|
||||
|
||||
@@ -34,6 +34,8 @@ async def list_peers() -> list[dict]:
|
||||
|
||||
async def delegate_task(workspace_id: str, task: str) -> str:
|
||||
"""Send a task to a peer workspace via A2A and return the response text."""
|
||||
if not workspace_id:
|
||||
return "Error: workspace_id is required"
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
# Discover target URL
|
||||
try:
|
||||
|
||||
@@ -0,0 +1,567 @@
|
||||
"""Tests for the HTTP/SSE transport of a2a_mcp_server.
|
||||
|
||||
Covers:
|
||||
- _handle_http_mcp: JSON-RPC request parsing and routing
|
||||
- Starlette app routes: POST /mcp, GET /mcp/stream, GET /health
|
||||
- cli_main argparse: --transport and --port flags
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _DummyRequest:
|
||||
"""Minimal request duck-type for _handle_http_mcp."""
|
||||
|
||||
def __init__(self, body_json: dict, headers: dict | None = None):
|
||||
self._body = body_json
|
||||
self.headers = headers or {}
|
||||
|
||||
async def json(self) -> dict:
|
||||
return self._body
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _handle_http_mcp — unit tests (no I/O)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_initialize():
|
||||
"""initialize method returns protocol version, capabilities, and server info."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({"jsonrpc": "2.0", "id": 42, "method": "initialize", "params": {}})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 42
|
||||
assert "protocolVersion" in resp["result"]
|
||||
assert "capabilities" in resp["result"]
|
||||
assert resp["result"]["serverInfo"]["name"] == "molecule"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_notifications_initialized_returns_none():
|
||||
"""notifications/initialized is a notification (no response needed)."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({"jsonrpc": "2.0", "method": "notifications/initialized"})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_list():
|
||||
"""tools/list returns the TOOLS schema."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({"jsonrpc": "2.0", "id": 7, "method": "tools/list"})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 7
|
||||
assert "tools" in resp["result"]
|
||||
assert isinstance(resp["result"]["tools"], list)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_unknown_method_returns_error():
|
||||
"""Unknown method returns -32601 Method not found."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({"jsonrpc": "2.0", "id": 3, "method": "foobar", "params": {}})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 3
|
||||
assert resp["error"]["code"] == -32601
|
||||
assert "Method not found" in resp["error"]["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_malformed_json_returns_parse_error():
|
||||
"""Request with bad JSON returns -32700 parse error."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest.__new__(_DummyRequest)
|
||||
req.headers = {}
|
||||
req.json = AsyncMock(side_effect=ValueError("bad json"))
|
||||
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["error"]["code"] == -32700
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_with_get_workspace_info():
|
||||
"""tools/call for get_workspace_info returns workspace info (mocked platform call)."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
with patch("a2a_mcp_server.tool_get_workspace_info", AsyncMock(return_value="mocked info")):
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 9,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "get_workspace_info", "arguments": {}},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 9
|
||||
assert resp["result"]["content"][0]["text"] == "mocked info"
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_handle_http_mcp_tools_call_unknown_tool():
|
||||
"""tools/call for an unknown tool returns the handle_tool_call error text."""
|
||||
from a2a_mcp_server import _handle_http_mcp
|
||||
|
||||
req = _DummyRequest({
|
||||
"jsonrpc": "2.0",
|
||||
"id": 11,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "not_a_real_tool", "arguments": {}},
|
||||
})
|
||||
resp = await _handle_http_mcp(req)
|
||||
|
||||
assert resp["jsonrpc"] == "2.0"
|
||||
assert resp["id"] == 11
|
||||
assert "Unknown tool" in resp["result"]["content"][0]["text"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Starlette app — integration tests with TestClient
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _clear_http_globals():
|
||||
"""Reset module-level HTTP state before and after each test."""
|
||||
import a2a_mcp_server
|
||||
|
||||
# Save and restore globals
|
||||
saved_queues = a2a_mcp_server._http_connection_queues.copy()
|
||||
saved_lock = a2a_mcp_server._http_connection_lock
|
||||
a2a_mcp_server._http_connection_queues.clear()
|
||||
yield
|
||||
# Restore
|
||||
a2a_mcp_server._http_connection_queues = saved_queues
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def _register_sse_queue():
|
||||
"""Register a queue for SSE push delivery (synchronous — callable from tests)."""
|
||||
conn_id = str(uuid.uuid4())
|
||||
queue = asyncio.Queue(maxsize=100)
|
||||
import a2a_mcp_server
|
||||
a2a_mcp_server._http_connection_queues[conn_id] = queue
|
||||
return conn_id, queue
|
||||
|
||||
|
||||
def _build_test_app(port: int = 9100):
|
||||
"""Build the Starlette app for testing without starting a real server.
|
||||
|
||||
Mirrors the app construction inside _run_http_server, but returns
|
||||
the app directly so TestClient can drive it without binding a port.
|
||||
"""
|
||||
from starlette.applications import Starlette
|
||||
from starlette.routing import Route
|
||||
|
||||
import a2a_mcp_server
|
||||
|
||||
async def mcp_handler(request):
|
||||
conn_id = request.headers.get("x-mcp-conn-id", "default")
|
||||
response = await a2a_mcp_server._handle_http_mcp(request)
|
||||
if response is None:
|
||||
from starlette.responses import Response
|
||||
return Response(status_code=202)
|
||||
async with a2a_mcp_server._http_connection_lock:
|
||||
queue = a2a_mcp_server._http_connection_queues.get(conn_id)
|
||||
if queue is not None and not queue.full():
|
||||
await queue.put(response)
|
||||
from starlette.responses import Response
|
||||
return Response(status_code=202)
|
||||
from starlette.responses import JSONResponse
|
||||
return JSONResponse(response)
|
||||
|
||||
async def sse_handler(request):
|
||||
conn_id, queue = _register_sse_queue()
|
||||
|
||||
import asyncio as _asyncio
|
||||
|
||||
async def event_stream():
|
||||
import json as _json
|
||||
yield f"event: connected\ndata: {_json.dumps({'conn_id': conn_id})}\n\n"
|
||||
try:
|
||||
while True:
|
||||
response = await _asyncio.wait_for(queue.get(), timeout=300)
|
||||
import json as _json
|
||||
yield f"event: message\ndata: {_json.dumps(response)}\n\n"
|
||||
if queue.empty():
|
||||
yield "event: heartbeat\ndata: null\n\n"
|
||||
except _asyncio.TimeoutError:
|
||||
pass
|
||||
finally:
|
||||
async with a2a_mcp_server._http_connection_lock:
|
||||
a2a_mcp_server._http_connection_queues.pop(conn_id, None)
|
||||
|
||||
from starlette.responses import StreamingResponse
|
||||
return StreamingResponse(
|
||||
event_stream(),
|
||||
media_type="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
"X-Accel-Buffering": "no",
|
||||
},
|
||||
)
|
||||
|
||||
async def health_handler(_request):
|
||||
from starlette.responses import JSONResponse
|
||||
return JSONResponse({"ok": True, "transport": "http+sse", "port": port})
|
||||
|
||||
return Starlette(
|
||||
routes=[
|
||||
Route("/mcp", mcp_handler, methods=["POST"]),
|
||||
Route("/mcp/stream", sse_handler, methods=["GET"]),
|
||||
Route("/health", health_handler),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TestHTTPAppRoutes:
|
||||
"""Integration tests using Starlette TestClient against the HTTP app.
|
||||
|
||||
Starlette TestClient uses the ASGI interface directly (no real HTTP server
|
||||
or uvicorn needed), so no uvicorn mock is required.
|
||||
"""
|
||||
|
||||
def test_health_returns_ok_and_transport(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app(port=9100)
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/health")
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["ok"] is True
|
||||
assert data["transport"] == "http+sse"
|
||||
assert data["port"] == 9100
|
||||
|
||||
def test_health_accepts_different_port(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app(port=9999)
|
||||
with TestClient(app) as client:
|
||||
resp = client.get("/health")
|
||||
|
||||
assert resp.json()["port"] == 9999
|
||||
|
||||
def test_mcp_post_initialize(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/mcp", json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": {},
|
||||
})
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["id"] == 1
|
||||
assert "protocolVersion" in data["result"]
|
||||
|
||||
def test_mcp_post_tools_list(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/mcp", json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/list",
|
||||
"params": {},
|
||||
})
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert "tools" in data["result"]
|
||||
assert len(data["result"]["tools"]) > 0
|
||||
|
||||
def test_mcp_post_notifications_initialized_returns_202(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/mcp", json={
|
||||
"jsonrpc": "2.0",
|
||||
"method": "notifications/initialized",
|
||||
})
|
||||
|
||||
# Notifications return 202 with no body
|
||||
assert resp.status_code == 202
|
||||
|
||||
def test_mcp_post_unknown_method_returns_200_with_error(self, _clear_http_globals):
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post("/mcp", json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 5,
|
||||
"method": "no_such_method",
|
||||
"params": {},
|
||||
})
|
||||
|
||||
assert resp.status_code == 200
|
||||
data = resp.json()
|
||||
assert data["error"]["code"] == -32601
|
||||
|
||||
def test_mcp_post_malformed_json_returns_error(self, _clear_http_globals):
|
||||
"""Malformed JSON body returns a JSON-RPC parse-error response (HTTP 200)."""
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app, raise_server_exceptions=False) as client:
|
||||
resp = client.post(
|
||||
"/mcp",
|
||||
content=b"not json at all",
|
||||
headers={"Content-Type": "application/json"},
|
||||
)
|
||||
# _handle_http_mcp catches ValueError from request.json() and returns
|
||||
# a JSON-RPC parse-error response with HTTP 200.
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["error"]["code"] == -32700
|
||||
assert "Parse error" in resp.json()["error"]["message"]
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_sse_stream_populates_queue(self, _clear_http_globals):
|
||||
"""_register_sse_queue adds a queue to _http_connection_queues before any async work."""
|
||||
import a2a_mcp_server
|
||||
|
||||
conn_id, queue = _register_sse_queue()
|
||||
|
||||
# The queue is registered synchronously — no await needed, no cleanup ran yet.
|
||||
assert conn_id in a2a_mcp_server._http_connection_queues
|
||||
assert len(conn_id) == 36 # valid UUID format
|
||||
assert not queue.full()
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_sse_queue_delivers_response(self, _clear_http_globals):
|
||||
"""POST /mcp with x-mcp-conn-id routes response into the SSE queue."""
|
||||
import uuid
|
||||
|
||||
import a2a_mcp_server
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
# Pre-register an SSE queue to simulate an active SSE subscriber
|
||||
conn_id = str(uuid.uuid4())
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=100)
|
||||
async with a2a_mcp_server._http_connection_lock:
|
||||
a2a_mcp_server._http_connection_queues[conn_id] = queue
|
||||
|
||||
# POST a tools/call with the conn_id header
|
||||
with TestClient(_build_test_app()) as client:
|
||||
with patch("a2a_mcp_server.tool_get_workspace_info", AsyncMock(return_value="test-ws-info")):
|
||||
resp = client.post(
|
||||
"/mcp",
|
||||
headers={"x-mcp-conn-id": conn_id},
|
||||
json={
|
||||
"jsonrpc": "2.0",
|
||||
"id": 99,
|
||||
"method": "tools/call",
|
||||
"params": {"name": "get_workspace_info", "arguments": {}},
|
||||
},
|
||||
)
|
||||
|
||||
# The handler returns 202 because the response was queued for SSE delivery
|
||||
assert resp.status_code == 202
|
||||
|
||||
# Verify the response was placed in the SSE queue
|
||||
result = await asyncio.wait_for(queue.get(), timeout=2.0)
|
||||
assert result["id"] == 99
|
||||
assert result["result"]["content"][0]["text"] == "test-ws-info"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# cli_main argparse — unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_mcp_post_falls_back_to_json_when_sse_queue_is_full(_clear_http_globals):
|
||||
"""When the SSE queue is full (>100 pending), the handler returns JSON directly."""
|
||||
import a2a_mcp_server
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
# Pre-register a queue and fill it to capacity
|
||||
conn_id = str(uuid.uuid4())
|
||||
queue: asyncio.Queue = asyncio.Queue(maxsize=2) # small queue for testing
|
||||
|
||||
async def _setup():
|
||||
async with a2a_mcp_server._http_connection_lock:
|
||||
a2a_mcp_server._http_connection_queues[conn_id] = queue
|
||||
queue.put_nowait({"id": 1})
|
||||
queue.put_nowait({"id": 2})
|
||||
|
||||
_sync_run(_setup())
|
||||
assert queue.full()
|
||||
|
||||
app = _build_test_app()
|
||||
with TestClient(app) as client:
|
||||
resp = client.post(
|
||||
"/mcp",
|
||||
headers={"x-mcp-conn-id": conn_id},
|
||||
json={"jsonrpc": "2.0", "id": 99, "method": "initialize", "params": {}},
|
||||
)
|
||||
|
||||
# With a full queue, the handler returns the response as JSON (not 202)
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["id"] == 99
|
||||
assert "result" in resp.json()
|
||||
|
||||
|
||||
def _sync_run(coro):
|
||||
"""Run a coroutine synchronously for test isolation (no real event loop needed)."""
|
||||
try:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
except Exception:
|
||||
raise
|
||||
|
||||
|
||||
def test_cli_main_transport_stdio_calls_main(monkeypatch):
|
||||
"""cli_main(transport='stdio') calls asyncio.run(main) without HTTP."""
|
||||
import a2a_mcp_server
|
||||
|
||||
run_calls: list = []
|
||||
|
||||
async def fake_main():
|
||||
run_calls.append("called")
|
||||
|
||||
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
|
||||
a2a_mcp_server.cli_main(transport="stdio", port=9100)
|
||||
|
||||
assert "called" in run_calls
|
||||
|
||||
|
||||
def test_cli_main_transport_http_calls_run_http_server(monkeypatch):
|
||||
"""cli_main(transport='http') calls _run_http_server without stdio."""
|
||||
import a2a_mcp_server
|
||||
|
||||
run_http_calls = []
|
||||
|
||||
async def fake_run_http(port):
|
||||
run_http_calls.append(port)
|
||||
|
||||
# asyncio.run must execute the coroutine for _run_http_server to be called
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_run_http_server", fake_run_http)
|
||||
# stdio path must not be entered
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
|
||||
a2a_mcp_server.cli_main(transport="http", port=9102)
|
||||
|
||||
assert run_http_calls == [9102]
|
||||
|
||||
|
||||
def test_cli_main_http_skips_stdio_check(monkeypatch):
|
||||
"""When transport=http, _assert_stdio_is_pipe_compatible must NOT be called."""
|
||||
import a2a_mcp_server
|
||||
|
||||
called = []
|
||||
|
||||
def fake_assert():
|
||||
called.append("assert_called")
|
||||
|
||||
# Patch on the module object directly
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", fake_assert)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", lambda fn: None)
|
||||
|
||||
a2a_mcp_server.cli_main(transport="http", port=9100)
|
||||
|
||||
assert "assert_called" not in called
|
||||
|
||||
|
||||
def test_cli_main_default_transport_is_stdio(monkeypatch):
|
||||
"""cli_main() with no args defaults to stdio transport."""
|
||||
import a2a_mcp_server
|
||||
|
||||
called_as: list = []
|
||||
|
||||
async def fake_main():
|
||||
called_as.append("called")
|
||||
|
||||
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
|
||||
a2a_mcp_server.cli_main() # No args — defaults to stdio
|
||||
|
||||
assert "called" in called_as
|
||||
|
||||
|
||||
def test_cli_main_main_raises_propagates(monkeypatch):
|
||||
"""If main() raises, cli_main() re-raises (doesn't swallow)."""
|
||||
import a2a_mcp_server
|
||||
|
||||
async def fake_main():
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(a2a_mcp_server, "main", fake_main)
|
||||
monkeypatch.setattr(a2a_mcp_server.asyncio, "run", _sync_run)
|
||||
monkeypatch.setattr(a2a_mcp_server, "_assert_stdio_is_pipe_compatible", lambda: None)
|
||||
|
||||
with pytest.raises(RuntimeError, match="boom"):
|
||||
a2a_mcp_server.cli_main(transport="stdio")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# uvicorn/starlette lazy-import
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_run_http_server_is_coroutine_function():
|
||||
"""_run_http_server is a coroutine function accepting a port argument."""
|
||||
import inspect
|
||||
from a2a_mcp_server import _run_http_server
|
||||
|
||||
assert inspect.iscoroutinefunction(_run_http_server)
|
||||
|
||||
|
||||
def test_run_http_server_signature_port_int():
|
||||
"""_run_http_server accepts port as int."""
|
||||
import inspect
|
||||
from a2a_mcp_server import _run_http_server
|
||||
|
||||
sig = inspect.signature(_run_http_server)
|
||||
assert "port" in sig.parameters
|
||||
assert sig.parameters["port"].annotation == int
|
||||
@@ -0,0 +1,432 @@
|
||||
"""Test coverage for ``builtin_tools.a2a_tools`` and ``send_message_wrapper``.
|
||||
|
||||
Issue #367: 21 new test cases targeting previously-uncovered branches.
|
||||
|
||||
HTTP mocking: each test patches ``builtin_tools.a2a_tools.httpx.AsyncClient``
|
||||
with an ``AsyncMock`` so no real network I/O occurs. The patch target is
|
||||
the attribute as seen inside the ``a2a_tools`` module (where httpx is imported
|
||||
as ``import httpx``), so ``@pytest.fixture(autouse=True)`` from conftest.py is
|
||||
harmless — it replaces the module-level name *after* our patch exits.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# conftest.py fixture — swap the MagicMock for the real module for THIS file
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _real_a2a_tools_module():
|
||||
"""Replace conftest's MagicMock of builtin_tools.a2a_tools with the real module.
|
||||
|
||||
conftest.py sets sys.modules["builtin_tools.a2a_tools"] = <MagicMock> so that
|
||||
adapter tests don't accidentally hit the platform. For THIS test file we
|
||||
want the real module, so we restore it from disk and swap it back after.
|
||||
"""
|
||||
import builtin_tools.a2a_tools as real_module
|
||||
|
||||
# conftest.py may have clobbered builtin_tools.__path__; restore it so the
|
||||
# import above finds builtin_tools/a2a_tools.py on disk.
|
||||
if "builtin_tools" in sys.modules:
|
||||
real_builtin = sys.modules["builtin_tools"]
|
||||
if getattr(real_builtin, "__path__", None) == []:
|
||||
builtin_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
real_builtin.__path__ = [os.path.join(builtin_dir, "builtin_tools")]
|
||||
|
||||
saved = sys.modules.get("builtin_tools.a2a_tools")
|
||||
# Ensure we have the real module (reload if sys.modules already has it)
|
||||
if saved is None or saved is real_module:
|
||||
import importlib
|
||||
importlib.reload(real_module)
|
||||
sys.modules["builtin_tools.a2a_tools"] = real_module
|
||||
yield
|
||||
sys.modules["builtin_tools.a2a_tools"] = saved
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _require_env(monkeypatch):
|
||||
"""Per-test: set required env vars."""
|
||||
monkeypatch.setenv("WORKSPACE_ID", "00000000-0000-0000-0000-000000000001")
|
||||
monkeypatch.setenv("PLATFORM_URL", "http://test.invalid")
|
||||
yield
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_mock_response(
|
||||
json_data, status_code: int = 200
|
||||
) -> MagicMock:
|
||||
"""Return a fully-configured AsyncMock that mirrors httpx.Response."""
|
||||
resp = MagicMock()
|
||||
resp.json = MagicMock(return_value=json_data)
|
||||
resp.status_code = status_code
|
||||
return resp
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# builtin_tools/a2a_tools — list_peers
|
||||
# =============================================================================
|
||||
|
||||
class TestListPeers:
|
||||
"""Coverage for builtin_tools/a2a_tools.list_peers()."""
|
||||
|
||||
async def test_returns_peers_on_200(self):
|
||||
"""Successful GET returns the peer list."""
|
||||
from builtin_tools.a2a_tools import list_peers
|
||||
|
||||
peers = [
|
||||
{"id": "ws-1", "name": "Alpha", "role": "sre", "status": "online"},
|
||||
{"id": "ws-2", "name": "Beta", "role": "dev", "status": "busy"},
|
||||
]
|
||||
mock_resp = _make_mock_response(peers, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
result = await list_peers()
|
||||
assert result == peers
|
||||
|
||||
async def test_returns_empty_list_on_non_200(self):
|
||||
"""list_peers swallows all non-200 responses gracefully."""
|
||||
from builtin_tools.a2a_tools import list_peers
|
||||
|
||||
mock_resp = _make_mock_response({}, 500)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
result = await list_peers()
|
||||
assert result == []
|
||||
|
||||
async def test_returns_empty_list_on_exception(self):
|
||||
"""Network errors must not propagate — list_peers returns []. """
|
||||
from builtin_tools.a2a_tools import list_peers
|
||||
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(
|
||||
side_effect=RuntimeError("dns failure")
|
||||
)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
result = await list_peers()
|
||||
assert result == []
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# builtin_tools/a2a_tools — delegate_task
|
||||
# =============================================================================
|
||||
|
||||
_DISCOVER_ROUTE = "http://test.invalid/registry/discover/ws-target"
|
||||
|
||||
|
||||
class TestDelegateTask:
|
||||
"""Coverage for builtin_tools/a2a_tools.delegate_task(workspace_id, task)."""
|
||||
|
||||
async def test_empty_workspace_id_returns_error(self):
|
||||
"""Empty workspace_id is validated before any network call."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
out = await delegate_task("", "do it")
|
||||
assert "Error" in out
|
||||
assert "workspace_id" in out.lower()
|
||||
|
||||
async def test_discover_returns_non_200(self):
|
||||
"""Discovery 4xx/5xx → error message with status code."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({}, 404)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert "Error" in out
|
||||
assert "404" in out
|
||||
|
||||
async def test_discover_returns_200_with_empty_url(self):
|
||||
"""Discovery 200 but no url field → actionable error."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"name": "orphan"}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert "Error" in out
|
||||
assert "no URL" in out
|
||||
|
||||
async def test_a2a_post_returns_500(self):
|
||||
"""A2A send 5xx with empty body → str(data) returned (code doesn't check status_code)."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response({}, 500)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
# Code checks json body, not status_code; empty body {} → str({})
|
||||
assert out == "{}"
|
||||
|
||||
async def test_result_parts_empty_dict(self):
|
||||
"""Regression #279: {"parts": []} → str(result), not "(no text)"."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response({"result": {"parts": []}}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
# Must return str(result), not "(no text)"
|
||||
assert "parts" in out
|
||||
assert "(no text)" not in out
|
||||
|
||||
async def test_result_is_plain_string(self):
|
||||
"""A bare string result returns as-is."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response({"result": "just a plain string"}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert out == "just a plain string"
|
||||
|
||||
async def test_result_is_number(self):
|
||||
"""Non-dict, non-string result → falls through to "(no text)"."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response({"result": 12345}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert out == "(no text)"
|
||||
|
||||
async def test_result_parts_non_dict_element(self):
|
||||
"""parts[0] is not a dict → falls through to "(no text)".
|
||||
|
||||
The code checks if parts[0] is a dict; since 123 is an int, it hits
|
||||
the else-branch and returns "(no text)".
|
||||
"""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response({"result": {"parts": [123, "also a string"]}}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert out == "(no text)"
|
||||
|
||||
async def test_error_dict_form(self):
|
||||
"""{"error": {"message": "..."}} → "Error: ..."."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response(
|
||||
{"error": {"message": "peer overloaded", "code": 429}}, 200
|
||||
)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert out == "Error: peer overloaded"
|
||||
|
||||
async def test_error_string_form(self):
|
||||
"""{"error": "string error"} → "Error: string error"."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response({"error": "workspace offline"}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert out == "Error: workspace offline"
|
||||
|
||||
async def test_error_null(self):
|
||||
"""{"error": null} → "Error: None" (edge case — str(null) in message)."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
a2a_resp = _make_mock_response({"error": None}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(return_value=a2a_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert "Error" in out
|
||||
|
||||
async def test_a2a_post_raises_exception(self):
|
||||
"""Network error during A2A POST → Error: sending A2A message: ..."""
|
||||
from builtin_tools.a2a_tools import delegate_task
|
||||
|
||||
discover_resp = _make_mock_response({"url": "http://peer.invalid/a2a"}, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=discover_resp)
|
||||
mock_client.__aenter__.return_value.post = AsyncMock(
|
||||
side_effect=ConnectionError("connection refused")
|
||||
)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await delegate_task("ws-target", "do it")
|
||||
assert "Error" in out
|
||||
assert "connection refused" in out
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# builtin_tools/a2a_tools — get_peers_summary
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGetPeersSummary:
|
||||
"""Coverage for builtin_tools/a2a_tools.get_peers_summary()."""
|
||||
|
||||
async def test_empty_peers_returns_no_peers_available(self):
|
||||
from builtin_tools.a2a_tools import get_peers_summary
|
||||
|
||||
mock_resp = _make_mock_response([], 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await get_peers_summary()
|
||||
assert "No peers" in out
|
||||
|
||||
async def test_peer_missing_fields(self):
|
||||
"""Peers with missing name/id/role/status must not KeyError/TypeError."""
|
||||
from builtin_tools.a2a_tools import get_peers_summary
|
||||
|
||||
mock_resp = _make_mock_response([{"id": "ws-x"}], 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await get_peers_summary()
|
||||
assert "ws-x" in out
|
||||
assert isinstance(out, str)
|
||||
|
||||
async def test_healthy_peer_roundtrip(self):
|
||||
"""Sanity: normal peer dicts produce a formatted list."""
|
||||
from builtin_tools.a2a_tools import get_peers_summary
|
||||
|
||||
peers = [
|
||||
{"id": "ws-alpha", "name": "Alpha", "role": "sre", "status": "online"},
|
||||
]
|
||||
mock_resp = _make_mock_response(peers, 200)
|
||||
mock_client = AsyncMock()
|
||||
mock_client.__aenter__.return_value.get = AsyncMock(return_value=mock_resp)
|
||||
|
||||
with patch("builtin_tools.a2a_tools.httpx.AsyncClient", return_value=mock_client):
|
||||
out = await get_peers_summary()
|
||||
assert "Alpha" in out
|
||||
assert "ws-alpha" in out
|
||||
assert "sre" in out
|
||||
assert "online" in out
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# send_message_wrapper — safe_send_message
|
||||
# =============================================================================
|
||||
|
||||
from adapters.smolagents.send_message_wrapper import safe_send_message
|
||||
|
||||
|
||||
class TestSafeSendMessage:
|
||||
"""Coverage for adapters.smolagents.send_message_wrapper.safe_send_message()."""
|
||||
|
||||
def test_non_string_input_converted(self):
|
||||
"""Non-str text is str()-converted before escaping."""
|
||||
delivered = []
|
||||
safe_send_message(42, send_fn=lambda s: delivered.append(s))
|
||||
assert delivered == ["[smolagents] 42"]
|
||||
assert isinstance(delivered[0], str)
|
||||
|
||||
def test_html_entities_escaped(self):
|
||||
"""< > ' are escaped so rendered UIs cannot be injected.
|
||||
|
||||
The payload <script>alert('xss')</script> has no literal '&', so &
|
||||
does not appear. The escape output is: <script>alert('xss')</script>
|
||||
"""
|
||||
delivered = []
|
||||
safe_send_message(
|
||||
"<script>alert('xss')</script>",
|
||||
send_fn=lambda s: delivered.append(s),
|
||||
)
|
||||
assert "<" in delivered[0]
|
||||
assert ">" in delivered[0]
|
||||
assert "'" in delivered[0]
|
||||
assert "<script>" in delivered[0]
|
||||
# The angle brackets and quotes must NOT appear unescaped
|
||||
assert "<script>" not in delivered[0]
|
||||
assert "alert('" not in delivered[0]
|
||||
|
||||
def test_truncation_at_max_len(self):
|
||||
"""Text > 2000 chars is truncated; caller is warned."""
|
||||
delivered = []
|
||||
with patch(
|
||||
"adapters.smolagents.send_message_wrapper.logger"
|
||||
) as mock_logger:
|
||||
long_text = "A" * 2500
|
||||
safe_send_message(long_text, send_fn=lambda s: delivered.append(s))
|
||||
assert len(delivered[0]) < len(long_text)
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "truncating" in mock_logger.warning.call_args[0][0]
|
||||
|
||||
def test_no_truncation_under_max_len(self):
|
||||
"""Text ≤ 2000 chars is passed through intact with no warning."""
|
||||
delivered = []
|
||||
with patch(
|
||||
"adapters.smolagents.send_message_wrapper.logger"
|
||||
) as mock_logger:
|
||||
text = "A" * 1500
|
||||
safe_send_message(text, send_fn=lambda s: delivered.append(s))
|
||||
expected = f"[smolagents] {text}"
|
||||
assert delivered[0] == expected
|
||||
mock_logger.warning.assert_not_called()
|
||||
|
||||
def test_debug_log_emitted(self):
|
||||
"""Every delivery logs at DEBUG with final payload length."""
|
||||
delivered = []
|
||||
with patch(
|
||||
"adapters.smolagents.send_message_wrapper.logger"
|
||||
) as mock_logger:
|
||||
safe_send_message("hello", send_fn=lambda s: delivered.append(s))
|
||||
mock_logger.debug.assert_called_once()
|
||||
assert "delivering" in mock_logger.debug.call_args[0][0]
|
||||
|
||||
def test_label_prefix_always_present(self):
|
||||
"""Every delivered payload starts with '[smolagents]'."""
|
||||
delivered = []
|
||||
safe_send_message("x", send_fn=lambda s: delivered.append(s))
|
||||
assert delivered[0].startswith("[smolagents]")
|
||||
@@ -0,0 +1,300 @@
|
||||
"""Test coverage for shared_runtime helpers (issue #366).
|
||||
|
||||
Six helper functions previously had zero test coverage:
|
||||
_extract_part_text, extract_message_text, format_conversation_history,
|
||||
build_task_text, append_peer_guidance, brief_task
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
from shared_runtime import (
|
||||
_extract_part_text,
|
||||
append_peer_guidance,
|
||||
brief_task,
|
||||
build_task_text,
|
||||
extract_message_text,
|
||||
format_conversation_history,
|
||||
)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# _extract_part_text
|
||||
# =============================================================================
|
||||
|
||||
class TestExtractPartText:
|
||||
"""Coverage for shared_runtime._extract_part_text()."""
|
||||
|
||||
def test_dict_with_text_field(self):
|
||||
assert _extract_part_text({"text": "hello"}) == "hello"
|
||||
|
||||
def test_dict_without_text_field(self):
|
||||
assert _extract_part_text({"type": "image"}) == ""
|
||||
|
||||
def test_dict_with_empty_text_field(self):
|
||||
assert _extract_part_text({"text": ""}) == ""
|
||||
|
||||
def test_dict_with_root_nesting(self):
|
||||
"""Text buried in part['root']['text'] is extracted."""
|
||||
assert _extract_part_text({"root": {"text": "nested"}}) == "nested"
|
||||
|
||||
def test_dict_with_root_non_dict(self):
|
||||
"""part['root'] that is not a dict is safely skipped."""
|
||||
assert _extract_part_text({"root": "string", "text": "top"}) == "top"
|
||||
|
||||
def test_object_with_text_attribute(self):
|
||||
class FakePart:
|
||||
text = "attr-text"
|
||||
|
||||
assert _extract_part_text(FakePart()) == "attr-text"
|
||||
|
||||
def test_object_with_root_object_with_text(self):
|
||||
"""Object with root.attr.text is extracted (A2A v1 object style)."""
|
||||
|
||||
class FakeRoot:
|
||||
text = "root-attr-text"
|
||||
|
||||
class FakePart:
|
||||
root = FakeRoot()
|
||||
|
||||
assert _extract_part_text(FakePart()) == "root-attr-text"
|
||||
|
||||
def test_object_with_empty_text_attribute(self):
|
||||
class FakePart:
|
||||
text = ""
|
||||
|
||||
assert _extract_part_text(FakePart()) == ""
|
||||
|
||||
def test_none_input(self):
|
||||
assert _extract_part_text(None) == ""
|
||||
|
||||
def test_unexpected_type(self):
|
||||
"""Plain int/float/bool falls through to empty string."""
|
||||
assert _extract_part_text(42) == ""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# extract_message_text
|
||||
# =============================================================================
|
||||
|
||||
class TestExtractMessageText:
|
||||
"""Coverage for shared_runtime.extract_message_text()."""
|
||||
|
||||
def test_list_of_dict_parts(self):
|
||||
parts = [{"text": "hello"}, {"text": "world"}]
|
||||
assert extract_message_text(parts) == "hello world"
|
||||
|
||||
def test_single_part(self):
|
||||
assert extract_message_text([{"text": "single"}]) == "single"
|
||||
|
||||
def test_context_object_with_message_parts(self):
|
||||
"""RequestContext-like: .message.parts is the parts list."""
|
||||
|
||||
class FakeContext:
|
||||
class _Msg:
|
||||
parts = [{"text": "from context"}]
|
||||
|
||||
message = _Msg()
|
||||
|
||||
assert extract_message_text(FakeContext()) == "from context"
|
||||
|
||||
def test_context_object_without_message(self):
|
||||
"""No .message attr → falls back to treating input as a parts list."""
|
||||
|
||||
class FakeContext:
|
||||
pass # no .message
|
||||
|
||||
# Pass a list directly as the context-like object
|
||||
assert extract_message_text([{"text": "fallback"}]) == "fallback"
|
||||
|
||||
def test_whitespace_normalized(self):
|
||||
"""Leading/trailing whitespace is stripped; internal newlines are preserved."""
|
||||
parts = [{"text": " hello "}, {"text": "\nworld\n"}]
|
||||
result = extract_message_text(parts)
|
||||
# Leading/trailing stripped, but internal \n stays (join uses single space)
|
||||
assert result == "hello \nworld"
|
||||
assert not result.startswith(" ")
|
||||
assert not result.endswith(" ")
|
||||
|
||||
def test_empty_parts_list(self):
|
||||
assert extract_message_text([]) == ""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# format_conversation_history
|
||||
# =============================================================================
|
||||
|
||||
class TestFormatConversationHistory:
|
||||
"""Coverage for shared_runtime.format_conversation_history()."""
|
||||
|
||||
def test_single_user_message(self):
|
||||
hist = [("human", "hello")]
|
||||
out = format_conversation_history(hist)
|
||||
assert out == "User: hello"
|
||||
|
||||
def test_single_agent_message(self):
|
||||
hist = [("ai", "response")]
|
||||
out = format_conversation_history(hist)
|
||||
assert out == "Agent: response"
|
||||
|
||||
def test_interleaved_history(self):
|
||||
hist = [
|
||||
("human", "hello"),
|
||||
("ai", "hi there"),
|
||||
("human", "what is 2+2?"),
|
||||
("ai", "four"),
|
||||
]
|
||||
out = format_conversation_history(hist)
|
||||
lines = out.split("\n")
|
||||
assert lines[0] == "User: hello"
|
||||
assert lines[1] == "Agent: hi there"
|
||||
assert lines[2] == "User: what is 2+2?"
|
||||
assert lines[3] == "Agent: four"
|
||||
|
||||
def test_empty_history(self):
|
||||
assert format_conversation_history([]) == ""
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# build_task_text
|
||||
# =============================================================================
|
||||
|
||||
class TestBuildTaskText:
|
||||
"""Coverage for shared_runtime.build_task_text()."""
|
||||
|
||||
def test_no_history_returns_user_message_unchanged(self):
|
||||
assert build_task_text("do the thing", []) == "do the thing"
|
||||
|
||||
def test_history_prepends_transcript(self):
|
||||
hist = [("human", "hello"), ("ai", "hi")]
|
||||
result = build_task_text("follow-up", hist)
|
||||
assert "Conversation so far:" in result
|
||||
assert "User: hello" in result
|
||||
assert "Agent: hi" in result
|
||||
assert "follow-up" in result
|
||||
|
||||
def test_user_message_after_conversation_header(self):
|
||||
hist = [("human", "hello")]
|
||||
result = build_task_text("do it", hist)
|
||||
assert result.startswith("Conversation so far:")
|
||||
assert result.endswith("Current request: do it")
|
||||
|
||||
def test_empty_user_message_with_history(self):
|
||||
"""Empty user_message is still rendered with history."""
|
||||
hist = [("human", "hello")]
|
||||
result = build_task_text("", hist)
|
||||
assert "Conversation so far:" in result
|
||||
assert "Current request:" in result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# append_peer_guidance
|
||||
# =============================================================================
|
||||
|
||||
class TestAppendPeerGuidance:
|
||||
"""Coverage for shared_runtime.append_peer_guidance()."""
|
||||
|
||||
def test_base_text_appended(self):
|
||||
result = append_peer_guidance(
|
||||
"base text",
|
||||
peers_info="alpha: ws-1",
|
||||
default_text="default",
|
||||
tool_name="delegate_task",
|
||||
)
|
||||
assert result.startswith("base text")
|
||||
assert "## Peers" in result
|
||||
assert "alpha: ws-1" in result
|
||||
assert "Use delegate_task" in result
|
||||
|
||||
def test_null_base_text_uses_default(self):
|
||||
result = append_peer_guidance(
|
||||
None,
|
||||
peers_info="peer info",
|
||||
default_text="DEFAULT_TEXT",
|
||||
tool_name="tool",
|
||||
)
|
||||
assert result.startswith("DEFAULT_TEXT")
|
||||
|
||||
def test_whitespace_base_text_strips_to_empty_peers_still_added(self):
|
||||
"""Whitespace-only base_text is stripped but default_text is NOT used
|
||||
(only None triggers the fallback). The peers section is still appended."""
|
||||
result = append_peer_guidance(
|
||||
" ",
|
||||
peers_info="peer",
|
||||
default_text="DEF",
|
||||
tool_name="t",
|
||||
)
|
||||
# " ".strip() == ""; default_text is NOT substituted for whitespace
|
||||
assert "## Peers" in result
|
||||
assert "peer" in result
|
||||
assert "DEF" not in result # default_text only on None, not whitespace
|
||||
|
||||
def test_none_base_text_uses_default(self):
|
||||
"""None base_text triggers fallback to default_text."""
|
||||
result = append_peer_guidance(
|
||||
None,
|
||||
peers_info="peer",
|
||||
default_text="DEFAULT",
|
||||
tool_name="tool",
|
||||
)
|
||||
assert result.startswith("DEFAULT")
|
||||
assert "## Peers" in result
|
||||
|
||||
def test_empty_peers_info_skips_section(self):
|
||||
result = append_peer_guidance(
|
||||
"base",
|
||||
peers_info="",
|
||||
default_text="def",
|
||||
tool_name="tool",
|
||||
)
|
||||
# No "## Peers" section when peers_info is empty
|
||||
assert result == "base"
|
||||
|
||||
def test_whitespace_in_base_and_peers_normalized(self):
|
||||
result = append_peer_guidance(
|
||||
" base \n",
|
||||
peers_info=" peer-1 \n",
|
||||
default_text="def",
|
||||
tool_name="tool",
|
||||
)
|
||||
# Base should be stripped of leading/trailing whitespace
|
||||
assert result.startswith("base")
|
||||
# Peer info should be appended
|
||||
assert "peer-1" in result
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# brief_task
|
||||
# =============================================================================
|
||||
|
||||
class TestBriefTask:
|
||||
"""Coverage for shared_runtime.brief_task()."""
|
||||
|
||||
def test_short_text_returned_unchanged(self):
|
||||
assert brief_task("hello", limit=60) == "hello"
|
||||
|
||||
def test_exact_limit_no_ellipsis(self):
|
||||
text = "A" * 60
|
||||
assert brief_task(text, limit=60) == text
|
||||
assert "..." not in text
|
||||
|
||||
def test_truncated_with_ellipsis(self):
|
||||
text = "A" * 80
|
||||
result = brief_task(text, limit=60)
|
||||
assert len(result) == 63 # 60 chars + "..."
|
||||
assert result.endswith("...")
|
||||
|
||||
def test_limit_10_shortens(self):
|
||||
result = brief_task("hello world", limit=10)
|
||||
assert len(result) == 13 # 10 chars + "..."
|
||||
assert result.endswith("...")
|
||||
|
||||
def test_limit_0_returns_ellipsis(self):
|
||||
"""limit=0 → 0-char slice + "..." since len("hello") > 0."""
|
||||
result = brief_task("hello", limit=0)
|
||||
assert result == "..."
|
||||
|
||||
def test_limit_1_single_char_plus_ellipsis(self):
|
||||
result = brief_task("hello", limit=1)
|
||||
assert len(result) == 4 # 1 char + "..."
|
||||
assert result.startswith("h")
|
||||
assert result.endswith("...")
|
||||
Reference in New Issue
Block a user