diff --git a/canvas/src/app/orgs/page.tsx b/canvas/src/app/orgs/page.tsx index 29a32632..e8163e24 100644 --- a/canvas/src/app/orgs/page.tsx +++ b/canvas/src/app/orgs/page.tsx @@ -154,7 +154,7 @@ function CheckoutBanner() {

✓ Payment confirmed. Your workspace is spinning up now — this page - refreshes automatically when it's ready. + refreshes automatically when it's ready.

); @@ -318,7 +318,7 @@ function EmptyState({ banner }: { banner?: React.ReactNode }) { {banner}

- You don't have any organizations yet. Create one to get started — your + You don't have any organizations yet. Create one to get started — your workspace spins up automatically once billing is set up.

diff --git a/canvas/src/components/Tooltip.tsx b/canvas/src/components/Tooltip.tsx index 29a32632..087fcd7c 100644 --- a/canvas/src/components/Tooltip.tsx +++ b/canvas/src/components/Tooltip.tsx @@ -48,6 +48,7 @@ export function Tooltip({ text, children }: Props) {
  }, []);
  const onBlur = useCallback(() => {
+   clearTimeout(timerRef.current);
    setShow(false);
  }, []);
diff --git a/docs/agent-runtime/workspace-runtime.md b/docs/agent-runtime/workspace-runtime.md index 72ad0dac..2ee03b6b 100644 --- a/docs/agent-runtime/workspace-runtime.md +++ b/docs/agent-runtime/workspace-runtime.md @@ -144,7 +144,7 @@ External workspaces run outside the platform's Docker infrastructure — on your
 | Liveness | Docker health sweep | Heartbeat TTL (90s offline threshold) |
 | Registration | Automatic at container start | Manual: `POST /workspaces` + `POST /registry/register` |
 | Token | Inherited from container env | Minted at registration, shown once |
-| Secrets | Baked in image or env var | Pulled from platform at boot via `GET /workspaces/:id/secrets/values` |
+| Secrets | Baked in image or env var | Pulled from platform at boot via `GET /workspaces/:id/secrets` |

 ### Registration flow

@@ -185,7 +185,7 @@ The platform returns a 256-bit bearer token — save it, it is shown only once.
 **3. Pull secrets at boot:**

 ```bash
-curl http://localhost:8080/workspaces/ws-xyz/secrets/values \
+curl http://localhost:8080/workspaces/ws-xyz/secrets \
   -H "Authorization: Bearer <token>"
 ```
diff --git a/docs/blog/2026-04-21-skills-vs-bundled-tools/index.md b/docs/blog/2026-04-21-skills-vs-bundled-tools/index.md index ff01cd97..17048281 100644 --- a/docs/blog/2026-04-21-skills-vs-bundled-tools/index.md +++ b/docs/blog/2026-04-21-skills-vs-bundled-tools/index.md @@ -95,7 +95,7 @@ Here's how the comparison lands:

 If you want to evaluate Molecule AI's skills coverage, start here:

-→ [MCP browser automation guide](/docs/blog/browser-automation-ai-agents-mcp) — browser tools via Chrome DevTools Protocol, same capability as Hermes' built-in browser
+→ [MCP browser automation guide](/blog/browser-automation-ai-agents-mcp) — browser tools via Chrome DevTools Protocol, same capability as Hermes' built-in browser
 → [TTS and image generation skills](/docs/guides/skill-catalog) — community-contributed, versioned, swappable
 → [Org-scoped API keys](/docs/guides/org-api-keys.md) — production auth and audit
diff --git a/docs/guides/skill-catalog.md b/docs/guides/skill-catalog.md new file mode 100644 index 00000000..337becc2 --- /dev/null +++ b/docs/guides/skill-catalog.md @@ -0,0 +1,196 @@
+# Skill Catalog
+
+Skills extend what a workspace agent can do — from browser automation
+and TTS to research tools and custom API integrations. This page covers
+available skill types, how to install them, and how to manage their
+versions.
+
+> **Note:** Molecule AI does not ship a hosted skill marketplace. All
+> skills are installed from local packages, GitHub URLs, or community
+> bundles. See [Skill Lifecycle](#skill-lifecycle) for how to publish and
+> distribute skills within your org.
+
+## Available Skill Types
+
+The skills ecosystem covers the same capabilities as Hermes Tool Gateway
+and more:
+
+| Category | Skill | What it does | Provider options |
+|----------|-------|-------------|-----------------|
+| **Browser** | `browser-automation` | Chrome DevTools Protocol via MCP — navigate, query DOM, screenshot, fill forms. Same engine as Hermes' built-in browser tool. | Built-in (CDP); swap via skill version |
+| **TTS** | `tts` | Text-to-speech generation. Streams audio to output. | OpenAI, ElevenLabs, or self-hosted |
+| **Image gen** | `image-generation` | Generates images from text prompts. | OpenAI DALL·E, Stability AI, or self-hosted |
+| **Web search** | `web-search` | Structured web search with result parsing. | Brave, SerpAPI, or custom |
+| **Research** | `arxiv-research` | Searches and summarizes arXiv papers. | Community bundle |
+| **Code** | `code-analysis` | Static analysis, diff review, complexity scoring. | Built-in |
+| **SEO** | `seo-audit` | Lighthouse audit + GSC keyword extraction. | Built-in |
+| **Social** | `social-post` | Formats and posts to social channels. | Built-in |
+
+All skills are open source. Source is visible — inspect the `SKILL.md`
+and `tools/` before installing.
+
+## Installing a Skill
+
+### From the built-in catalog
+
+```bash
+# Install browser automation
+molecule skills install browser-automation
+
+# Install TTS with a specific provider
+molecule skills install tts --provider openai
+
+# Install a specific version
+molecule skills install browser-automation --version 1.2.0
+```
+
+### From GitHub
+
+```bash
+molecule skills install \
+  https://github.com/acme/molecule-skills/tree/main/browser-automation
+```
+
+### From a community bundle
+
+Community skills are hosted on GitHub and referenced by slug:
+
+```bash
+molecule skills install arxiv-research --from community
+```
+
+Community skills are reviewed by the Molecule AI team before being
+listed. Submit a skill for review by opening a PR against
+[`molecule-ai/skills`](https://github.com/Molecule-AI/skills).
+
+## Installing via config.yaml
+
+Skills can also be declared in the workspace config file:
+
+```yaml
+skills:
+  - name: browser-automation
+    source: builtin
+  - name: tts
+    source: builtin
+    config:
+      provider: openai
+  - name: arxiv-research
+    source: community
+```
+
+On workspace boot, the runtime validates each skill and loads the
+`SKILL.md` + tools into the agent's context.
+
+## Version Management
+
+Skills are versioned with semantic versioning. Pin to a known-good
+release to prevent unexpected behavior changes:
+
+```bash
+# Pin to a specific version
+molecule skills install tts --version 1.1.0
+
+# Upgrade to latest
+molecule skills upgrade tts
+
+# View installed version
+molecule skills list
+```
+
+Upgrading is safe — the skill loader validates the new package on
+installation. If the new version has breaking changes, the workspace logs
+a warning and keeps the previous version active until you restart.
+
+## Custom Skills
+
+Write a skill for your team's specific workflow:
+
+```bash
+# Scaffold a new skill
+molecule skills init my-custom-skill
+```
+
+This creates:
+
+```
+skills/my-custom-skill/
++-- SKILL.md          # instructions + frontmatter
++-- tools/
+|   +-- my_tool.py    # MCP tool using @tool decorator
++-- examples/         # few-shot examples
++-- templates/        # reference files
+```
+
+See [Skills Reference](../agent-runtime/skills.md) for the full
+`SKILL.md` format and frontmatter schema.
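+
+Before publishing, it can help to sanity-check that a scaffolded package
+still matches the layout above. A minimal sketch of such a check
+(illustrative Go, not the platform's actual loader; the only invariants
+assumed are the ones the scaffold shows):
+
+```go
+package main
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+)
+
+// validateSkillDir enforces the scaffold invariants shown above: a
+// SKILL.md file plus a tools/ directory with at least one Python tool.
+func validateSkillDir(dir string) error {
+	if _, err := os.Stat(filepath.Join(dir, "SKILL.md")); err != nil {
+		return fmt.Errorf("missing SKILL.md: %w", err)
+	}
+	tools, err := filepath.Glob(filepath.Join(dir, "tools", "*.py"))
+	if err != nil || len(tools) == 0 {
+		return fmt.Errorf("tools/ must contain at least one .py tool file")
+	}
+	return nil
+}
+
+func main() {
+	if err := validateSkillDir("skills/my-custom-skill"); err != nil {
+		fmt.Println("skill package invalid:", err)
+		os.Exit(1)
+	}
+	fmt.Println("skill package looks structurally sound")
+}
+```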
+ +## Skill Lifecycle + +``` +Author writes SKILL.md + tools/ + | + v +Install into workspace (local or GitHub) + | + v +Workspace loads skill on next boot / hot-reload + | + v +Agent sees skill in tool context + | + v +(Optional) Publish to org bundle or community +``` + +**Publishing to your org:** Bundle skills with workspace templates so +every new workspace in a role gets the same capability set: + +```bash +molecule skills bundle my-custom-skill --output ./org-templates/my-role/ +``` + +**Publishing to the community:** Open a PR against +[`molecule-ai/skills`](https://github.com/Molecule-AI/skills) with a +complete skill package. Community skills are reviewed for security and +correctness before listing. + +## Removing a Skill + +```bash +molecule skills uninstall browser-automation +``` + +Or remove from `config.yaml` and trigger a hot-reload by touching the +file: + +```bash +touch /configs/config.yaml +``` + +The workspace detects the change, rescans skills, and updates the Agent +Card within ~3 seconds. + +## Troubleshooting + +**Skill not found:** Check the skill name matches the catalog exactly. +Skill names are lowercase with hyphens (`browser-automation`, not +`browser_automation` or `BrowserAutomation`). + +**Skill loads but tools are missing:** Verify the `tools/` folder +contains valid Python files with `@tool`-decorated functions. See +[Skills Reference — Tool Interface](../agent-runtime/skills.md#tool-interface). + +**Provider auth error:** Ensure the required environment variable (e.g. +`OPENAI_API_KEY`) is set in the workspace config or secrets. + +## Related Docs + +- [Skills Reference](../agent-runtime/skills.md) — Full SKILL.md format, + frontmatter schema, and tool interface +- [Config Format](../agent-runtime/config-format.md) — How skills are + declared in `config.yaml` +- [Plugin System](../plugins/overview.md) — Installing full plugin + packages (skills + MCP servers + shared rules) +- [Remote Agent Tutorial](../tutorials/register-remote-agent.md) — + Installing skills on remote (external) agents \ No newline at end of file diff --git a/docs/marketing/discord-adapter-day2/announcement.md b/docs/marketing/discord-adapter-day2/announcement.md new file mode 100644 index 00000000..a75db641 --- /dev/null +++ b/docs/marketing/discord-adapter-day2/announcement.md @@ -0,0 +1,81 @@ +# Discord Adapter Day 2 — Community Copy + +> Posted 2026-04-21. Discord adapter launched Day 1; Day 2 covers Reddit, Hacker News. +> Blog URL: https://moleculesai.app/blog/discord-adapter-launch +> PR: https://github.com/Molecule-AI/molecule-core/pull/656 + +--- + +## Reddit r/LocalLLaMA + +**Title:** Molecule AI now connects to Discord via a webhook — no bot account, no Gateway, no OAuth + +``` +Molecule AI workspaces can now send messages to Discord and receive slash commands using only a webhook URL. No Discord Developer Portal, no intents, no bot token — just an inbound webhook and your agent is in the channel. + +Built it as a proof-of-concept to keep our own team workflow on Discord without the overhead of a full bot app. Figured other people might want the same thing. + +The adapter uses Discord's built-in webhook delivery for outbound + slash command reception. No polling. No Gateway connection. Works behind NAT — the agent initiates all outbound connections to the platform, which proxies to Discord. 
+Here's the architecture gist:
+- Outbound: POST to Discord webhook URL (standard, no auth beyond the URL token)
+- Inbound: Discord delivers slash command payloads to a platform endpoint; platform fans out to the relevant workspace via A2A
+- No Discord bot app required. No Developer Portal setup.
+
+If your team lives in Discord and you want an AI agent that can post summaries, respond to /ask commands, and route alerts — it's now a webhook URL and a config line.
+
+Demo repo and docs: https://github.com/Molecule-AI/molecule-core/tree/main/docs/blog/2026-04-21-discord-adapter
+
+Happy to answer questions about the adapter design.
+```
+
+**Tags:** `discord`, `mcp`, `molecule-ai`, `webhook`, `ai-agents`
+
+---
+
+## Reddit r/MachineLearning
+
+**Title:** Molecule AI Discord adapter — AI agents in Discord via webhook, no bot account needed
+
+```
+Molecule AI Discord adapter — webhook-only, no Gateway connection required
+
+Built a Discord integration for Molecule AI workspaces that requires zero bot app setup. It's just a webhook URL and an agent config.
+
+The problem: Discord bot integrations typically require a Developer Portal app, OAuth flow, Gateway connection management, intent configuration, and rate limit handling. That's a meaningful chunk of work before your agent can say hello.
+
+The approach: use Discord's native webhook delivery for inbound slash commands (no Gateway) and standard webhook POST for outbound messages. The platform acts as a proxy — Discord delivers to the platform endpoint, the platform routes to the relevant workspace via A2A. Works behind NAT since the agent initiates outbound connections.
+
+No bot token. No intents. No Gateway.
+
+Code: https://github.com/Molecule-AI/molecule-core/tree/main/docs/blog/2026-04-21-discord-adapter
+Launch post: https://moleculesai.app/blog/discord-adapter-launch
+```
+
+---
+
+## Hacker News
+
+**Title:** Molecule AI — Discord adapter via webhook (no bot account, no Gateway)
+
+**Body:**
+
+Built a Discord integration for Molecule AI workspaces that works with just a webhook URL — no Discord Developer Portal setup, no bot token, no Gateway connection.
+
+**Why**
+
+Our own team lives in Discord. We wanted a lightweight way to have an AI agent respond to slash commands and post updates without the overhead of a full bot app. Realized Discord's native webhook primitives cover both inbound (slash command delivery) and outbound (channel messages) if you proxy through a platform endpoint.
+
+**How it works**
+
+- Outbound: agent POSTs to a Discord webhook URL (standard, URL contains the auth token)
+- Inbound: Discord delivers slash command payloads to a platform endpoint; platform fans out to the relevant workspace via A2A
+- No bot account required. No Gateway. Works behind NAT — the agent only initiates outbound connections.
+
+The adapter lives in the MCP server (`mcp-server/src/tools/channels/discord.go`) alongside Telegram and other channel adapters. Each workspace configures its own Discord channel with a webhook URL.
+
+**Links**
+
+- Docs: https://moleculesai.app/blog/discord-adapter-launch
+- Code + examples: https://github.com/Molecule-AI/molecule-core/tree/main/docs/blog/2026-04-21-discord-adapter
+- PR: https://github.com/Molecule-AI/molecule-core/pull/656
diff --git a/docs/quickstart.md b/docs/quickstart.md index 337c168c..a0483d74 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -152,7 +152,7 @@ The response includes your bearer token — save it now. It is shown only once.
```bash AGENT_TOKEN="the-token-from-step-2" -curl "$PLATFORM/workspaces/$WORKSPACE_ID/secrets/values" \ +curl "$PLATFORM/workspaces/$WORKSPACE_ID/secrets" \ -H "Authorization: Bearer $AGENT_TOKEN" ``` diff --git a/workspace-server/internal/handlers/a2a_proxy.go b/workspace-server/internal/handlers/a2a_proxy.go index fd6cd50f..18991f38 100644 --- a/workspace-server/internal/handlers/a2a_proxy.go +++ b/workspace-server/internal/handlers/a2a_proxy.go @@ -1,5 +1,10 @@ package handlers +// a2a_proxy.go — A2A JSON-RPC proxy: routes canvas and agent-to-agent +// requests to workspace containers. Core proxy path, URL resolution, +// payload normalization, and HTTP dispatch. Error handling, logging, and +// SSRF helpers live in a2a_proxy_helpers.go. + import ( "bytes" "context" @@ -9,9 +14,7 @@ import ( "fmt" "io" "log" - "net" "net/http" - "net/url" "os" "strconv" "strings" @@ -20,7 +23,6 @@ import ( "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" "github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner" "github.com/Molecule-AI/molecule-monorepo/platform/internal/registry" - "github.com/Molecule-AI/molecule-monorepo/platform/internal/wsauth" "github.com/gin-gonic/gin" "github.com/google/uuid" ) @@ -473,402 +475,3 @@ func (h *WorkspaceHandler) dispatchA2A(ctx context.Context, agentURL string, bod resp, doErr := a2aClient.Do(req) return resp, cancel, doErr } - -// proxyDispatchBuildError is a sentinel wrapper for failures inside -// http.NewRequestWithContext. handleA2ADispatchError unwraps it to emit the -// "failed to create proxy request" 500 instead of the standard 502/503 paths. -type proxyDispatchBuildError struct{ err error } - -func (e *proxyDispatchBuildError) Error() string { return e.err.Error() } - -// handleA2ADispatchError translates a forward-call failure into a proxyA2AError, -// runs the reactive container-health check, and (when `logActivity` is true) -// schedules a detached LogActivity goroutine for the failed attempt. -func (h *WorkspaceHandler) handleA2ADispatchError(ctx context.Context, workspaceID, callerID string, body []byte, a2aMethod string, err error, durationMs int, logActivity bool) (int, []byte, *proxyA2AError) { - // Build-time failure (couldn't even create the http.Request) — return - // a 500 without the reactive-health / busy-retry paths. - if buildErr, ok := err.(*proxyDispatchBuildError); ok { - _ = buildErr - return 0, nil, &proxyA2AError{ - Status: http.StatusInternalServerError, - Response: gin.H{"error": "failed to create proxy request"}, - } - } - - log.Printf("ProxyA2A forward error: %v", err) - - containerDead := h.maybeMarkContainerDead(ctx, workspaceID) - - if logActivity { - h.logA2AFailure(ctx, workspaceID, callerID, body, a2aMethod, err, durationMs) - } - if containerDead { - return 0, nil, &proxyA2AError{ - Status: http.StatusServiceUnavailable, - Response: gin.H{"error": "workspace agent unreachable — container restart triggered", "restarting": true}, - } - } - // Container is alive but upstream Do() failed with a timeout/EOF- - // shaped error — the agent is most likely mid-synthesis on a - // previous request (single-threaded main loop). Surface as 503 - // Busy with a Retry-After hint so callers can distinguish this - // from a real unreachable-agent (502) and retry with backoff. - // Issue #110. 
- if isUpstreamBusyError(err) { - return 0, nil, &proxyA2AError{ - Status: http.StatusServiceUnavailable, - Headers: map[string]string{"Retry-After": strconv.Itoa(busyRetryAfterSeconds)}, - Response: gin.H{ - "error": "workspace agent busy — retry after a short backoff", - "busy": true, - "retry_after": busyRetryAfterSeconds, - }, - } - } - return 0, nil, &proxyA2AError{ - Status: http.StatusBadGateway, - Response: gin.H{"error": "failed to reach workspace agent"}, - } -} - -// maybeMarkContainerDead runs the reactive health check after a forward error. -// If the workspace's Docker container is no longer running (and the workspace -// isn't external), it marks the workspace offline, clears Redis state, -// broadcasts WORKSPACE_OFFLINE, and triggers an async restart. Returns true -// when the container was found dead. -func (h *WorkspaceHandler) maybeMarkContainerDead(ctx context.Context, workspaceID string) bool { - var wsRuntime string - db.DB.QueryRowContext(ctx, `SELECT COALESCE(runtime, 'langgraph') FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsRuntime) - if h.provisioner == nil || wsRuntime == "external" { - return false - } - running, inspectErr := h.provisioner.IsRunning(ctx, workspaceID) - if inspectErr != nil { - // Transient Docker-daemon error (timeout, socket EOF, etc.). Post- - // #386, IsRunning returns (true, err) in this case — caller stays - // on the alive path and does not trigger a restart cascade. Log - // so the defect is visible without being destructive. - log.Printf("ProxyA2A: IsRunning for %s returned transient error (assuming alive): %v", workspaceID, inspectErr) - } - if running { - return false - } - log.Printf("ProxyA2A: container for %s is dead — marking offline and triggering restart", workspaceID) - if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = 'offline', updated_at = now() WHERE id = $1 AND status NOT IN ('removed', 'provisioning')`, workspaceID); err != nil { - log.Printf("ProxyA2A: failed to mark workspace %s offline: %v", workspaceID, err) - } - db.ClearWorkspaceKeys(ctx, workspaceID) - h.broadcaster.RecordAndBroadcast(ctx, "WORKSPACE_OFFLINE", workspaceID, map[string]interface{}{}) - go h.RestartByID(workspaceID) - return true -} - -// logA2AFailure records a failed A2A attempt to activity_logs in a detached -// goroutine (the request context may already be done by the time it runs). -func (h *WorkspaceHandler) logA2AFailure(ctx context.Context, workspaceID, callerID string, body []byte, a2aMethod string, err error, durationMs int) { - errMsg := err.Error() - var errWsName string - db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&errWsName) - if errWsName == "" { - errWsName = workspaceID - } - summary := "A2A request to " + errWsName + " failed: " + errMsg - go func(parent context.Context) { - logCtx, cancel := context.WithTimeout(context.WithoutCancel(parent), 30*time.Second) - defer cancel() - LogActivity(logCtx, h.broadcaster, ActivityParams{ - WorkspaceID: workspaceID, - ActivityType: "a2a_receive", - SourceID: nilIfEmpty(callerID), - TargetID: &workspaceID, - Method: &a2aMethod, - Summary: &summary, - RequestBody: json.RawMessage(body), - DurationMs: &durationMs, - Status: "error", - ErrorDetail: &errMsg, - }) - }(ctx) -} - -// logA2ASuccess records a successful A2A round-trip and (for canvas-initiated -// 2xx/3xx responses) broadcasts an A2A_RESPONSE event so the frontend can -// receive the reply without polling. 
-func (h *WorkspaceHandler) logA2ASuccess(ctx context.Context, workspaceID, callerID string, body, respBody []byte, a2aMethod string, statusCode, durationMs int) { - logStatus := "ok" - if statusCode >= 400 { - logStatus = "error" - } - var wsNameForLog string - db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsNameForLog) - if wsNameForLog == "" { - wsNameForLog = workspaceID - } - - // #817: track outbound activity on the CALLER so orchestrators can detect - // silent workspaces. Only update when callerID is a real workspace (not - // canvas, not a system caller) and the target returned 2xx/3xx. - if callerID != "" && !isSystemCaller(callerID) && statusCode < 400 { - go func() { - bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) - defer cancel() - if _, err := db.DB.ExecContext(bgCtx, - `UPDATE workspaces SET last_outbound_at = NOW() WHERE id = $1`, callerID); err != nil { - log.Printf("last_outbound_at update failed for %s: %v", callerID, err) - } - }() - } - summary := a2aMethod + " → " + wsNameForLog - go func(parent context.Context) { - logCtx, cancel := context.WithTimeout(context.WithoutCancel(parent), 30*time.Second) - defer cancel() - LogActivity(logCtx, h.broadcaster, ActivityParams{ - WorkspaceID: workspaceID, - ActivityType: "a2a_receive", - SourceID: nilIfEmpty(callerID), - TargetID: &workspaceID, - Method: &a2aMethod, - Summary: &summary, - RequestBody: json.RawMessage(body), - ResponseBody: json.RawMessage(respBody), - DurationMs: &durationMs, - Status: logStatus, - }) - }(ctx) - - if callerID == "" && statusCode < 400 { - h.broadcaster.BroadcastOnly(workspaceID, "A2A_RESPONSE", map[string]interface{}{ - "response_body": json.RawMessage(respBody), - "method": a2aMethod, - "duration_ms": durationMs, - }) - } -} - -func nilIfEmpty(s string) *string { - if s == "" { - return nil - } - return &s -} - -// validateCallerToken enforces the Phase 30.5 auth-token contract on the -// caller of an A2A proxy request. Same lazy-bootstrap shape as -// registry.requireWorkspaceToken: if the caller workspace has any live -// token on file, the Authorization header is mandatory and must match; -// if the caller has zero live tokens, they're grandfathered through -// (their next /registry/register will mint their first token, after -// which this branch never fires again for them). -// -// On auth failure this writes the 401 via c and returns an error so the -// handler aborts without running the proxy. -func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) error { - hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, callerID) - if err != nil { - // Fail-open here matches the heartbeat path — A2A caller auth is - // defense-in-depth on top of access-control hierarchy, not the - // sole gate on the secret material. A DB hiccup shouldn't take - // the whole A2A path down. 
- log.Printf("wsauth: caller HasAnyLiveToken(%s) failed: %v — allowing A2A", callerID, err) - return nil - } - if !hasLive { - return nil // legacy / pre-upgrade caller - } - tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")) - if tok == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "missing caller auth token"}) - return errInvalidCallerToken - } - if err := wsauth.ValidateToken(ctx, db.DB, callerID, tok); err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid caller auth token"}) - return err - } - return nil -} - -// errInvalidCallerToken is a sentinel for validateCallerToken's "missing -// token" branch so the handler-level guard can detect it without string -// matching (the wsauth errors are typed for the invalid case). -var errInvalidCallerToken = errors.New("missing caller auth token") - -// extractAndUpsertTokenUsage parses LLM usage from a raw A2A response body -// and persists it via upsertTokenUsage. Safe to call in a goroutine — logs -// errors but never panics. ctx must already be detached from the request. -func extractAndUpsertTokenUsage(ctx context.Context, workspaceID string, respBody []byte) { - in, out := parseUsageFromA2AResponse(respBody) - if in > 0 || out > 0 { - upsertTokenUsage(ctx, workspaceID, in, out) - } -} - -// parseUsageFromA2AResponse extracts input_tokens / output_tokens from an A2A -// JSON-RPC response. Inspects two locations in order of preference: -// 1. result.usage — the JSON-RPC 2.0 result envelope from workspace agents. -// 2. usage — top-level, for non-JSON-RPC or direct Anthropic-shaped payloads. -// -// Returns (0, 0) when no recognisable usage data is found. -func parseUsageFromA2AResponse(body []byte) (inputTokens, outputTokens int64) { - if len(body) == 0 { - return 0, 0 - } - var top map[string]json.RawMessage - if err := json.Unmarshal(body, &top); err != nil { - return 0, 0 - } - - // 1. result.usage (JSON-RPC 2.0 wrapper produced by workspace agents). - if rawResult, ok := top["result"]; ok { - var result map[string]json.RawMessage - if err := json.Unmarshal(rawResult, &result); err == nil { - if in, out, ok := readUsageMap(result); ok { - return in, out - } - } - } - - // 2. Fallback: top-level usage (direct Anthropic or non-JSON-RPC response). - if in, out, ok := readUsageMap(top); ok { - return in, out - } - return 0, 0 -} - -// isSafeURL validates that a URL resolves to a publicly-routable address, -// preventing A2A requests from being redirected to internal/cloud-metadata -// infrastructure (SSRF, CWE-918). Workspace URLs come from DB/Redis caches -// so we validate before making any outbound HTTP call. -func isSafeURL(rawURL string) error { - u, err := url.Parse(rawURL) - if err != nil { - return fmt.Errorf("invalid URL: %w", err) - } - // Reject non-HTTP(S) schemes. - if u.Scheme != "http" && u.Scheme != "https" { - return fmt.Errorf("forbidden scheme: %s (only http/https allowed)", u.Scheme) - } - host := u.Hostname() - if host == "" { - return fmt.Errorf("empty hostname") - } - // Block direct IP addresses. - if ip := net.ParseIP(host); ip != nil { - if ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() { - return fmt.Errorf("forbidden loopback/unspecified IP: %s", ip) - } - if isPrivateOrMetadataIP(ip) { - return fmt.Errorf("forbidden private/metadata IP: %s", ip) - } - return nil - } - // For hostnames, resolve and validate each returned IP. - addrs, err := net.LookupHost(host) - if err != nil { - // DNS resolution failure — block it. Could be an internal hostname. 
- return fmt.Errorf("DNS resolution blocked for hostname: %s (%v)", host, err) - } - if len(addrs) == 0 { - return fmt.Errorf("DNS returned no addresses for: %s", host) - } - for _, addr := range addrs { - ip := net.ParseIP(addr) - if ip != nil && (ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || isPrivateOrMetadataIP(ip)) { - return fmt.Errorf("hostname %s resolves to forbidden IP: %s", host, ip) - } - } - return nil -} - -// isPrivateOrMetadataIP returns true for cloud-metadata / loopback / link-local -// ranges (always) and RFC-1918 / IPv6 ULA ranges (self-hosted only). -// -// In SaaS cross-EC2 mode (see saasMode() in registry.go) the tenant platform -// and its workspaces share a VPC, so workspaces register with their -// VPC-private IP — typically 172.31.x.x on AWS default VPCs. Blocking RFC-1918 -// unconditionally would reject every legitimate registration. Cloud metadata -// (169.254.0.0/16, fe80::/10), loopback, and TEST-NET ranges stay blocked in -// both modes; they are never a legitimate agent URL. -// -// Both IPv4 and IPv6 are checked. The previous implementation returned false -// for every non-IPv4 input, which meant a registered `[::1]` or `[fe80::…]` -// URL would bypass the SSRF gate entirely. -func isPrivateOrMetadataIP(ip net.IP) bool { - // Always blocked — IPv4 cloud metadata + network-test ranges. - metadataRangesV4 := []string{ - "169.254.0.0/16", // link-local / IMDSv1-v2 - "100.64.0.0/10", // CGNAT — reachable via some VPC configs, not a legit agent URL - "192.0.2.0/24", // TEST-NET-1 - "198.51.100.0/24", // TEST-NET-2 - "203.0.113.0/24", // TEST-NET-3 - } - // Always blocked — IPv6 cloud-metadata / loopback equivalents. - metadataRangesV6 := []string{ - "::1/128", // loopback - "fe80::/10", // link-local (IMDS analogue) - "::ffff:0:0/96", // IPv4-mapped loopback (defence-in-depth; To4() below usually normalises first) - } - // RFC-1918 private — blocked in self-hosted, allowed in SaaS. - rfc1918RangesV4 := []string{ - "10.0.0.0/8", - "172.16.0.0/12", - "192.168.0.0/16", - } - // RFC-4193 ULA — IPv6 analogue of RFC-1918. Same SaaS-mode treatment. - ulaRangesV6 := []string{ - "fc00::/7", - } - - contains := func(cidrs []string, target net.IP) bool { - for _, c := range cidrs { - _, n, err := net.ParseCIDR(c) - if err != nil { - continue - } - if n.Contains(target) { - return true - } - } - return false - } - - // Prefer IPv4 semantics when the input is an IPv4 address encoded in any - // form (raw v4, ::ffff:a.b.c.d, etc.) — To4() normalises all of them. - if ip4 := ip.To4(); ip4 != nil { - if contains(metadataRangesV4, ip4) { - return true - } - if saasMode() { - return false - } - return contains(rfc1918RangesV4, ip4) - } - - // True IPv6 path. - if contains(metadataRangesV6, ip) { - return true - } - if saasMode() { - return false - } - return contains(ulaRangesV6, ip) -} - -// readUsageMap extracts input_tokens / output_tokens from the "usage" key of m. -// Returns (0, 0, false) when the key is absent or contains no non-zero values. 
-func readUsageMap(m map[string]json.RawMessage) (inputTokens, outputTokens int64, ok bool) {
-	rawUsage, has := m["usage"]
-	if !has {
-		return 0, 0, false
-	}
-	var usage struct {
-		InputTokens  int64 `json:"input_tokens"`
-		OutputTokens int64 `json:"output_tokens"`
-	}
-	if err := json.Unmarshal(rawUsage, &usage); err != nil {
-		return 0, 0, false
-	}
-	if usage.InputTokens == 0 && usage.OutputTokens == 0 {
-		return 0, 0, false
-	}
-	return usage.InputTokens, usage.OutputTokens, true
-}
diff --git a/workspace-server/internal/handlers/a2a_proxy_helpers.go b/workspace-server/internal/handlers/a2a_proxy_helpers.go new file mode 100644 index 00000000..1a87071a --- /dev/null +++ b/workspace-server/internal/handlers/a2a_proxy_helpers.go @@ -0,0 +1,419 @@
+package handlers
+
+// a2a_proxy_helpers.go — A2A proxy error handling, activity logging,
+// caller auth validation, token usage tracking, and SSRF safety checks.
+
+import (
+	"context"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"log"
+	"net"
+	"net/http"
+	"net/url"
+	"strconv"
+	"time"
+
+	"github.com/Molecule-AI/molecule-monorepo/platform/internal/db"
+	"github.com/Molecule-AI/molecule-monorepo/platform/internal/wsauth"
+	"github.com/gin-gonic/gin"
+)
+
+// proxyDispatchBuildError is a sentinel wrapper for failures inside
+// http.NewRequestWithContext. handleA2ADispatchError unwraps it to emit the
+// "failed to create proxy request" 500 instead of the standard 502/503 paths.
+type proxyDispatchBuildError struct{ err error }
+
+func (e *proxyDispatchBuildError) Error() string { return e.err.Error() }
+
+// handleA2ADispatchError translates a forward-call failure into a proxyA2AError,
+// runs the reactive container-health check, and (when `logActivity` is true)
+// schedules a detached LogActivity goroutine for the failed attempt.
+func (h *WorkspaceHandler) handleA2ADispatchError(ctx context.Context, workspaceID, callerID string, body []byte, a2aMethod string, err error, durationMs int, logActivity bool) (int, []byte, *proxyA2AError) {
+	// Build-time failure (couldn't even create the http.Request) — return
+	// a 500 without the reactive-health / busy-retry paths.
+	if buildErr, ok := err.(*proxyDispatchBuildError); ok {
+		_ = buildErr
+		return 0, nil, &proxyA2AError{
+			Status:   http.StatusInternalServerError,
+			Response: gin.H{"error": "failed to create proxy request"},
+		}
+	}
+
+	log.Printf("ProxyA2A forward error: %v", err)
+
+	containerDead := h.maybeMarkContainerDead(ctx, workspaceID)
+
+	if logActivity {
+		h.logA2AFailure(ctx, workspaceID, callerID, body, a2aMethod, err, durationMs)
+	}
+	if containerDead {
+		return 0, nil, &proxyA2AError{
+			Status:   http.StatusServiceUnavailable,
+			Response: gin.H{"error": "workspace agent unreachable — container restart triggered", "restarting": true},
+		}
+	}
+	// Container is alive but upstream Do() failed with a timeout/EOF-
+	// shaped error — the agent is most likely mid-synthesis on a
+	// previous request (single-threaded main loop). Surface as 503
+	// Busy with a Retry-After hint so callers can distinguish this
+	// from a real unreachable-agent (502) and retry with backoff.
+	// Issue #110.
+ if isUpstreamBusyError(err) { + return 0, nil, &proxyA2AError{ + Status: http.StatusServiceUnavailable, + Headers: map[string]string{"Retry-After": strconv.Itoa(busyRetryAfterSeconds)}, + Response: gin.H{ + "error": "workspace agent busy — retry after a short backoff", + "busy": true, + "retry_after": busyRetryAfterSeconds, + }, + } + } + return 0, nil, &proxyA2AError{ + Status: http.StatusBadGateway, + Response: gin.H{"error": "failed to reach workspace agent"}, + } +} + +// maybeMarkContainerDead runs the reactive health check after a forward error. +// If the workspace's Docker container is no longer running (and the workspace +// isn't external), it marks the workspace offline, clears Redis state, +// broadcasts WORKSPACE_OFFLINE, and triggers an async restart. Returns true +// when the container was found dead. +func (h *WorkspaceHandler) maybeMarkContainerDead(ctx context.Context, workspaceID string) bool { + var wsRuntime string + db.DB.QueryRowContext(ctx, `SELECT COALESCE(runtime, 'langgraph') FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsRuntime) + if h.provisioner == nil || wsRuntime == "external" { + return false + } + running, inspectErr := h.provisioner.IsRunning(ctx, workspaceID) + if inspectErr != nil { + // Transient Docker-daemon error (timeout, socket EOF, etc.). Post- + // #386, IsRunning returns (true, err) in this case — caller stays + // on the alive path and does not trigger a restart cascade. Log + // so the defect is visible without being destructive. + log.Printf("ProxyA2A: IsRunning for %s returned transient error (assuming alive): %v", workspaceID, inspectErr) + } + if running { + return false + } + log.Printf("ProxyA2A: container for %s is dead — marking offline and triggering restart", workspaceID) + if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = 'offline', updated_at = now() WHERE id = $1 AND status NOT IN ('removed', 'provisioning')`, workspaceID); err != nil { + log.Printf("ProxyA2A: failed to mark workspace %s offline: %v", workspaceID, err) + } + db.ClearWorkspaceKeys(ctx, workspaceID) + h.broadcaster.RecordAndBroadcast(ctx, "WORKSPACE_OFFLINE", workspaceID, map[string]interface{}{}) + go h.RestartByID(workspaceID) + return true +} + +// logA2AFailure records a failed A2A attempt to activity_logs in a detached +// goroutine (the request context may already be done by the time it runs). +func (h *WorkspaceHandler) logA2AFailure(ctx context.Context, workspaceID, callerID string, body []byte, a2aMethod string, err error, durationMs int) { + errMsg := err.Error() + var errWsName string + db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&errWsName) + if errWsName == "" { + errWsName = workspaceID + } + summary := "A2A request to " + errWsName + " failed: " + errMsg + go func(parent context.Context) { + logCtx, cancel := context.WithTimeout(context.WithoutCancel(parent), 30*time.Second) + defer cancel() + LogActivity(logCtx, h.broadcaster, ActivityParams{ + WorkspaceID: workspaceID, + ActivityType: "a2a_receive", + SourceID: nilIfEmpty(callerID), + TargetID: &workspaceID, + Method: &a2aMethod, + Summary: &summary, + RequestBody: json.RawMessage(body), + DurationMs: &durationMs, + Status: "error", + ErrorDetail: &errMsg, + }) + }(ctx) +} + +// logA2ASuccess records a successful A2A round-trip and (for canvas-initiated +// 2xx/3xx responses) broadcasts an A2A_RESPONSE event so the frontend can +// receive the reply without polling. 
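+// An illustrative A2A_RESPONSE payload (field names match the BroadcastOnly
+// call below; values invented for the example):
+//
+//	{"response_body": {"jsonrpc": "2.0", "result": {...}}, "method": "message/send", "duration_ms": 1240}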
+func (h *WorkspaceHandler) logA2ASuccess(ctx context.Context, workspaceID, callerID string, body, respBody []byte, a2aMethod string, statusCode, durationMs int) { + logStatus := "ok" + if statusCode >= 400 { + logStatus = "error" + } + var wsNameForLog string + db.DB.QueryRowContext(ctx, `SELECT name FROM workspaces WHERE id = $1`, workspaceID).Scan(&wsNameForLog) + if wsNameForLog == "" { + wsNameForLog = workspaceID + } + + // #817: track outbound activity on the CALLER so orchestrators can detect + // silent workspaces. Only update when callerID is a real workspace (not + // canvas, not a system caller) and the target returned 2xx/3xx. + if callerID != "" && !isSystemCaller(callerID) && statusCode < 400 { + go func() { + bgCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + if _, err := db.DB.ExecContext(bgCtx, + `UPDATE workspaces SET last_outbound_at = NOW() WHERE id = $1`, callerID); err != nil { + log.Printf("last_outbound_at update failed for %s: %v", callerID, err) + } + }() + } + summary := a2aMethod + " → " + wsNameForLog + go func(parent context.Context) { + logCtx, cancel := context.WithTimeout(context.WithoutCancel(parent), 30*time.Second) + defer cancel() + LogActivity(logCtx, h.broadcaster, ActivityParams{ + WorkspaceID: workspaceID, + ActivityType: "a2a_receive", + SourceID: nilIfEmpty(callerID), + TargetID: &workspaceID, + Method: &a2aMethod, + Summary: &summary, + RequestBody: json.RawMessage(body), + ResponseBody: json.RawMessage(respBody), + DurationMs: &durationMs, + Status: logStatus, + }) + }(ctx) + + if callerID == "" && statusCode < 400 { + h.broadcaster.BroadcastOnly(workspaceID, "A2A_RESPONSE", map[string]interface{}{ + "response_body": json.RawMessage(respBody), + "method": a2aMethod, + "duration_ms": durationMs, + }) + } +} + +func nilIfEmpty(s string) *string { + if s == "" { + return nil + } + return &s +} + +// validateCallerToken enforces the Phase 30.5 auth-token contract on the +// caller of an A2A proxy request. Same lazy-bootstrap shape as +// registry.requireWorkspaceToken: if the caller workspace has any live +// token on file, the Authorization header is mandatory and must match; +// if the caller has zero live tokens, they're grandfathered through +// (their next /registry/register will mint their first token, after +// which this branch never fires again for them). +// +// On auth failure this writes the 401 via c and returns an error so the +// handler aborts without running the proxy. +func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) error { + hasLive, err := wsauth.HasAnyLiveToken(ctx, db.DB, callerID) + if err != nil { + // Fail-open here matches the heartbeat path — A2A caller auth is + // defense-in-depth on top of access-control hierarchy, not the + // sole gate on the secret material. A DB hiccup shouldn't take + // the whole A2A path down. 
+ log.Printf("wsauth: caller HasAnyLiveToken(%s) failed: %v — allowing A2A", callerID, err) + return nil + } + if !hasLive { + return nil // legacy / pre-upgrade caller + } + tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")) + if tok == "" { + c.JSON(http.StatusUnauthorized, gin.H{"error": "missing caller auth token"}) + return errInvalidCallerToken + } + if err := wsauth.ValidateToken(ctx, db.DB, callerID, tok); err != nil { + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid caller auth token"}) + return err + } + return nil +} + +// errInvalidCallerToken is a sentinel for validateCallerToken's "missing +// token" branch so the handler-level guard can detect it without string +// matching (the wsauth errors are typed for the invalid case). +var errInvalidCallerToken = errors.New("missing caller auth token") + +// extractAndUpsertTokenUsage parses LLM usage from a raw A2A response body +// and persists it via upsertTokenUsage. Safe to call in a goroutine — logs +// errors but never panics. ctx must already be detached from the request. +func extractAndUpsertTokenUsage(ctx context.Context, workspaceID string, respBody []byte) { + in, out := parseUsageFromA2AResponse(respBody) + if in > 0 || out > 0 { + upsertTokenUsage(ctx, workspaceID, in, out) + } +} + +// parseUsageFromA2AResponse extracts input_tokens / output_tokens from an A2A +// JSON-RPC response. Inspects two locations in order of preference: +// 1. result.usage — the JSON-RPC 2.0 result envelope from workspace agents. +// 2. usage — top-level, for non-JSON-RPC or direct Anthropic-shaped payloads. +// +// Returns (0, 0) when no recognisable usage data is found. +func parseUsageFromA2AResponse(body []byte) (inputTokens, outputTokens int64) { + if len(body) == 0 { + return 0, 0 + } + var top map[string]json.RawMessage + if err := json.Unmarshal(body, &top); err != nil { + return 0, 0 + } + + // 1. result.usage (JSON-RPC 2.0 wrapper produced by workspace agents). + if rawResult, ok := top["result"]; ok { + var result map[string]json.RawMessage + if err := json.Unmarshal(rawResult, &result); err == nil { + if in, out, ok := readUsageMap(result); ok { + return in, out + } + } + } + + // 2. Fallback: top-level usage (direct Anthropic or non-JSON-RPC response). + if in, out, ok := readUsageMap(top); ok { + return in, out + } + return 0, 0 +} + +// isSafeURL validates that a URL resolves to a publicly-routable address, +// preventing A2A requests from being redirected to internal/cloud-metadata +// infrastructure (SSRF, CWE-918). Workspace URLs come from DB/Redis caches +// so we validate before making any outbound HTTP call. +func isSafeURL(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + // Reject non-HTTP(S) schemes. + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("forbidden scheme: %s (only http/https allowed)", u.Scheme) + } + host := u.Hostname() + if host == "" { + return fmt.Errorf("empty hostname") + } + // Block direct IP addresses. + if ip := net.ParseIP(host); ip != nil { + if ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() { + return fmt.Errorf("forbidden loopback/unspecified IP: %s", ip) + } + if isPrivateOrMetadataIP(ip) { + return fmt.Errorf("forbidden private/metadata IP: %s", ip) + } + return nil + } + // For hostnames, resolve and validate each returned IP. + addrs, err := net.LookupHost(host) + if err != nil { + // DNS resolution failure — block it. Could be an internal hostname. 
+ return fmt.Errorf("DNS resolution blocked for hostname: %s (%v)", host, err) + } + if len(addrs) == 0 { + return fmt.Errorf("DNS returned no addresses for: %s", host) + } + for _, addr := range addrs { + ip := net.ParseIP(addr) + if ip != nil && (ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || isPrivateOrMetadataIP(ip)) { + return fmt.Errorf("hostname %s resolves to forbidden IP: %s", host, ip) + } + } + return nil +} + +// isPrivateOrMetadataIP returns true for cloud-metadata / loopback / link-local +// ranges (always) and RFC-1918 / IPv6 ULA ranges (self-hosted only). +// +// In SaaS cross-EC2 mode (see saasMode() in registry.go) the tenant platform +// and its workspaces share a VPC, so workspaces register with their +// VPC-private IP — typically 172.31.x.x on AWS default VPCs. Blocking RFC-1918 +// unconditionally would reject every legitimate registration. Cloud metadata +// (169.254.0.0/16, fe80::/10), loopback, and TEST-NET ranges stay blocked in +// both modes; they are never a legitimate agent URL. +// +// Both IPv4 and IPv6 are checked. The previous implementation returned false +// for every non-IPv4 input, which meant a registered [::1] or [fe80::…] +// URL would bypass the SSRF gate entirely. +func isPrivateOrMetadataIP(ip net.IP) bool { + // Always blocked — IPv4 cloud metadata + network-test ranges. + metadataRangesV4 := []string{ + "169.254.0.0/16", // link-local / IMDSv1-v2 + "100.64.0.0/10", // CGNAT — reachable via some VPC configs, not a legit agent URL + "192.0.2.0/24", // TEST-NET-1 + "198.51.100.0/24", // TEST-NET-2 + "203.0.113.0/24", // TEST-NET-3 + } + // Always blocked — IPv6 cloud-metadata / loopback equivalents. + metadataRangesV6 := []string{ + "::1/128", // loopback + "fe80::/10", // link-local (IMDS analogue) + "::ffff:0:0/96", // IPv4-mapped loopback (defence-in-depth; To4() below usually normalises first) + } + // RFC-1918 private — blocked in self-hosted, allowed in SaaS. + rfc1918RangesV4 := []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + } + // RFC-4193 ULA — IPv6 analogue of RFC-1918. Same SaaS-mode treatment. + ulaRangesV6 := []string{ + "fc00::/7", + } + + contains := func(cidrs []string, target net.IP) bool { + for _, c := range cidrs { + _, n, err := net.ParseCIDR(c) + if err != nil { + continue + } + if n.Contains(target) { + return true + } + } + return false + } + + // Prefer IPv4 semantics when the input is an IPv4 address encoded in any + // form (raw v4, ::ffff:a.b.c.d, etc.) — To4() normalises all of them. + if ip4 := ip.To4(); ip4 != nil { + if contains(metadataRangesV4, ip4) { + return true + } + if saasMode() { + return false + } + return contains(rfc1918RangesV4, ip4) + } + + // True IPv6 path. + if contains(metadataRangesV6, ip) { + return true + } + if saasMode() { + return false + } + return contains(ulaRangesV6, ip) +} + +// readUsageMap extracts input_tokens / output_tokens from the "usage" key of m. +// Returns (0, 0, false) when the key is absent or contains no non-zero values. 
+func readUsageMap(m map[string]json.RawMessage) (inputTokens, outputTokens int64, ok bool) { + rawUsage, has := m["usage"] + if !has { + return 0, 0, false + } + var usage struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + } + if err := json.Unmarshal(rawUsage, &usage); err != nil { + return 0, 0, false + } + if usage.InputTokens == 0 && usage.OutputTokens == 0 { + return 0, 0, false + } + return usage.InputTokens, usage.OutputTokens, true +} diff --git a/workspace-server/internal/handlers/mcp.go b/workspace-server/internal/handlers/mcp.go index 8d7ac598..bc0b9668 100644 --- a/workspace-server/internal/handlers/mcp.go +++ b/workspace-server/internal/handlers/mcp.go @@ -1,6 +1,10 @@ package handlers -// Package handlers — MCP bridge for opencode integration (#800, #809, #810). +// mcp.go — MCP bridge protocol handling: JSON-RPC types, handler struct, +// tool definitions, HTTP endpoints (Call, Stream), and RPC dispatch. +// Tool implementations live in mcp_tools.go. +// +// MCP bridge for opencode integration (#800, #809, #810). // // Exposes the same 8 A2A tools as workspace/a2a_mcp_server.py but // served directly from the platform over HTTP so CLI runtimes running @@ -20,24 +24,16 @@ package handlers // MOLECULE_MCP_ALLOW_SEND_MESSAGE=true. import ( - "bytes" "context" "database/sql" "encoding/json" "fmt" - "io" - "log" "net/http" "os" - "strings" "time" - "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" "github.com/Molecule-AI/molecule-monorepo/platform/internal/events" - "github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner" - "github.com/Molecule-AI/molecule-monorepo/platform/internal/registry" "github.com/gin-gonic/gin" - "github.com/google/uuid" ) // mcpProtocolVersion is the MCP spec version this server implements. 
@@ -389,546 +385,3 @@ func (h *MCPHandler) dispatch(ctx context.Context, workspaceID, toolName string, return "", fmt.Errorf("unknown tool: %s", toolName) } } - -// ───────────────────────────────────────────────────────────────────────────── -// Tool implementations -// ───────────────────────────────────────────────────────────────────────────── - -func (h *MCPHandler) toolListPeers(ctx context.Context, workspaceID string) (string, error) { - var parentID sql.NullString - err := h.database.QueryRowContext(ctx, - `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID, - ).Scan(&parentID) - if err == sql.ErrNoRows { - return "", fmt.Errorf("workspace not found") - } - if err != nil { - return "", fmt.Errorf("lookup failed: %w", err) - } - - type peer struct { - ID string `json:"id"` - Name string `json:"name"` - Role string `json:"role"` - Status string `json:"status"` - Tier int `json:"tier"` - } - - var peers []peer - - scanPeers := func(rows *sql.Rows) error { - defer rows.Close() - for rows.Next() { - var p peer - if err := rows.Scan(&p.ID, &p.Name, &p.Role, &p.Status, &p.Tier); err != nil { - return err - } - peers = append(peers, p) - } - return rows.Err() - } - - const cols = `SELECT w.id, w.name, COALESCE(w.role,''), w.status, w.tier` - - // Siblings - if parentID.Valid { - rows, err := h.database.QueryContext(ctx, - cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.id != $2 AND w.status != 'removed'`, - parentID.String, workspaceID) - if err == nil { - _ = scanPeers(rows) - } - } else { - rows, err := h.database.QueryContext(ctx, - cols+` FROM workspaces w WHERE w.parent_id IS NULL AND w.id != $1 AND w.status != 'removed'`, - workspaceID) - if err == nil { - _ = scanPeers(rows) - } - } - - // Children - { - rows, err := h.database.QueryContext(ctx, - cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.status != 'removed'`, - workspaceID) - if err == nil { - _ = scanPeers(rows) - } - } - - // Parent - if parentID.Valid { - rows, err := h.database.QueryContext(ctx, - cols+` FROM workspaces w WHERE w.id = $1 AND w.status != 'removed'`, - parentID.String) - if err == nil { - _ = scanPeers(rows) - } - } - - if len(peers) == 0 { - return "No peers found.", nil - } - - b, _ := json.MarshalIndent(peers, "", " ") - return string(b), nil -} - -func (h *MCPHandler) toolGetWorkspaceInfo(ctx context.Context, workspaceID string) (string, error) { - var id, name, role, status string - var tier int - var parentID sql.NullString - - err := h.database.QueryRowContext(ctx, ` - SELECT id, name, COALESCE(role,''), tier, status, parent_id - FROM workspaces WHERE id = $1 - `, workspaceID).Scan(&id, &name, &role, &tier, &status, &parentID) - if err == sql.ErrNoRows { - return "", fmt.Errorf("workspace not found") - } - if err != nil { - return "", fmt.Errorf("lookup failed: %w", err) - } - - info := map[string]interface{}{ - "id": id, - "name": name, - "role": role, - "tier": tier, - "status": status, - } - if parentID.Valid { - info["parent_id"] = parentID.String - } - b, _ := json.MarshalIndent(info, "", " ") - return string(b), nil -} - -func (h *MCPHandler) toolDelegateTask(ctx context.Context, callerID string, args map[string]interface{}, timeout time.Duration) (string, error) { - targetID, _ := args["workspace_id"].(string) - task, _ := args["task"].(string) - if targetID == "" { - return "", fmt.Errorf("workspace_id is required") - } - if task == "" { - return "", fmt.Errorf("task is required") - } - - if !registry.CanCommunicate(callerID, targetID) { - return "", 
fmt.Errorf("workspace %s is not authorised to communicate with %s", callerID, targetID) - } - - agentURL, err := mcpResolveURL(ctx, h.database, targetID) - if err != nil { - return "", err - } - // SSRF defence: reject private/metadata URLs before making outbound call. - if err := isSafeURL(agentURL); err != nil { - return "", fmt.Errorf("invalid workspace URL: %w", err) - } - - a2aBody, err := json.Marshal(map[string]interface{}{ - "jsonrpc": "2.0", - "id": uuid.New().String(), - "method": "message/send", - "params": map[string]interface{}{ - "message": map[string]interface{}{ - "role": "user", - "parts": []map[string]interface{}{{"type": "text", "text": task}}, - "messageId": uuid.New().String(), - }, - }, - }) - if err != nil { - return "", fmt.Errorf("failed to build A2A request: %w", err) - } - - reqCtx, cancel := context.WithTimeout(ctx, timeout) - defer cancel() - - httpReq, err := http.NewRequestWithContext(reqCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody)) - if err != nil { - return "", fmt.Errorf("failed to create request: %w", err) - } - httpReq.Header.Set("Content-Type", "application/json") - // X-Workspace-ID identifies this caller to the A2A proxy. The /workspaces/:id/a2a - // endpoint is intentionally outside WorkspaceAuth (agents do not hold bearer tokens - // to peer workspaces). Access control is enforced by CanCommunicate above, which - // already validated callerID → targetID before this request is constructed. - // callerID was authenticated by WorkspaceAuth on the MCP bridge entry point, - // so this header reflects a verified caller identity, not a spoofable value. - httpReq.Header.Set("X-Workspace-ID", callerID) - - resp, err := http.DefaultClient.Do(httpReq) - if err != nil { - return "", fmt.Errorf("A2A call failed: %w", err) - } - defer func() { _ = resp.Body.Close() }() - - body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) - if err != nil { - return "", fmt.Errorf("failed to read response: %w", err) - } - - return extractA2AText(body), nil -} - -func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string, args map[string]interface{}) (string, error) { - targetID, _ := args["workspace_id"].(string) - task, _ := args["task"].(string) - if targetID == "" { - return "", fmt.Errorf("workspace_id is required") - } - if task == "" { - return "", fmt.Errorf("task is required") - } - - if !registry.CanCommunicate(callerID, targetID) { - return "", fmt.Errorf("workspace %s is not authorised to communicate with %s", callerID, targetID) - } - - taskID := uuid.New().String() - - // Fire and forget in a detached goroutine. Use a background context so - // the call is not cancelled when the HTTP request completes. - go func() { - bgCtx, cancel := context.WithTimeout(context.Background(), mcpAsyncCallTimeout) - defer cancel() - - agentURL, err := mcpResolveURL(bgCtx, h.database, targetID) - if err != nil { - log.Printf("MCPHandler.delegate_task_async: resolve URL for %s: %v", targetID, err) - return - } - // SSRF defence: reject private/metadata URLs before making outbound call. 
- if err := isSafeURL(agentURL); err != nil { - log.Printf("MCPHandler.delegate_task_async: unsafe URL for %s: %v", targetID, err) - return - } - - a2aBody, _ := json.Marshal(map[string]interface{}{ - "jsonrpc": "2.0", - "id": taskID, - "method": "message/send", - "params": map[string]interface{}{ - "message": map[string]interface{}{ - "role": "user", - "parts": []map[string]interface{}{{"type": "text", "text": task}}, - "messageId": uuid.New().String(), - }, - }, - }) - - httpReq, err := http.NewRequestWithContext(bgCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody)) - if err != nil { - log.Printf("MCPHandler.delegate_task_async: create request: %v", err) - return - } - httpReq.Header.Set("Content-Type", "application/json") - httpReq.Header.Set("X-Workspace-ID", callerID) - - resp, err := http.DefaultClient.Do(httpReq) - if err != nil { - log.Printf("MCPHandler.delegate_task_async: A2A call to %s: %v", targetID, err) - return - } - defer func() { _ = resp.Body.Close() }() - // Drain response so the connection can be reused. - _, _ = io.Copy(io.Discard, resp.Body) - }() - - return fmt.Sprintf(`{"task_id":%q,"status":"dispatched","target_id":%q}`, taskID, targetID), nil -} - -func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, args map[string]interface{}) (string, error) { - targetID, _ := args["workspace_id"].(string) - taskID, _ := args["task_id"].(string) - if targetID == "" { - return "", fmt.Errorf("workspace_id is required") - } - if taskID == "" { - return "", fmt.Errorf("task_id is required") - } - - var status, errorDetail sql.NullString - var responseBody []byte - - err := h.database.QueryRowContext(ctx, ` - SELECT status, error_detail, response_body - FROM activity_logs - WHERE workspace_id = $1 - AND target_id = $2 - AND request_body->>'delegation_id' = $3 - ORDER BY created_at DESC - LIMIT 1 - `, callerID, targetID, taskID).Scan(&status, &errorDetail, &responseBody) - if err == sql.ErrNoRows { - return fmt.Sprintf(`{"task_id":%q,"status":"not_found","note":"task not tracked or not yet dispatched"}`, taskID), nil - } - if err != nil { - return "", fmt.Errorf("status lookup failed: %w", err) - } - - result := map[string]interface{}{ - "task_id": taskID, - "status": status.String, - "target_id": targetID, - } - if errorDetail.Valid && errorDetail.String != "" { - result["error"] = errorDetail.String - } - if len(responseBody) > 0 { - result["result"] = extractA2AText(responseBody) - } - b, _ := json.MarshalIndent(result, "", " ") - return string(b), nil -} - -func (h *MCPHandler) toolSendMessageToUser(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) { - message, _ := args["message"].(string) - if message == "" { - return "", fmt.Errorf("message is required") - } - - // Check send_message_to_user is enabled (C3). 
- if os.Getenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") != "true" { - return "", fmt.Errorf("send_message_to_user is not enabled on this MCP bridge (set MOLECULE_MCP_ALLOW_SEND_MESSAGE=true)") - } - - var wsName string - err := h.database.QueryRowContext(ctx, - `SELECT name FROM workspaces WHERE id = $1 AND status != 'removed'`, workspaceID, - ).Scan(&wsName) - if err != nil { - return "", fmt.Errorf("workspace not found") - } - - h.broadcaster.BroadcastOnly(workspaceID, "AGENT_MESSAGE", map[string]interface{}{ - "message": message, - "workspace_id": workspaceID, - "name": wsName, - }) - - return "Message sent.", nil -} - - -func (h *MCPHandler) toolCommitMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) { - content, _ := args["content"].(string) - scope, _ := args["scope"].(string) - if content == "" { - return "", fmt.Errorf("content is required") - } - if scope == "" { - scope = "LOCAL" - } - - // C3: GLOBAL scope is blocked on the MCP bridge. - if scope == "GLOBAL" { - return "", fmt.Errorf("GLOBAL scope is not permitted via the MCP bridge — use LOCAL or TEAM") - } - if scope != "LOCAL" && scope != "TEAM" { - return "", fmt.Errorf("scope must be LOCAL or TEAM") - } - - memoryID := uuid.New().String() - // SAFE-T1201 (#838): scrub known credential patterns before persistence so - // plain-text API keys pulled in via tool responses can't land in the - // memories table (and leak into shared TEAM scope). Reuses redactSecrets - // already shipped for the HTTP path in PR #881 — this was the MCP-bridge - // sibling the original fix missed. Runs on every write regardless of scope. - content, _ = redactSecrets(workspaceID, content) - _, err := h.database.ExecContext(ctx, ` - INSERT INTO agent_memories (id, workspace_id, content, scope, namespace) - VALUES ($1, $2, $3, $4, $5) - `, memoryID, workspaceID, content, scope, workspaceID) - if err != nil { - log.Printf("MCPHandler.commit_memory workspace=%s: %v", workspaceID, err) - return "", fmt.Errorf("failed to save memory") - } - - return fmt.Sprintf(`{"id":%q,"scope":%q}`, memoryID, scope), nil -} - -func (h *MCPHandler) toolRecallMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) { - query, _ := args["query"].(string) - scope, _ := args["scope"].(string) - - // C3: GLOBAL scope is blocked on the MCP bridge. - if scope == "GLOBAL" { - return "", fmt.Errorf("GLOBAL scope is not permitted via the MCP bridge — use LOCAL, TEAM, or empty") - } - - var rows *sql.Rows - var err error - - switch scope { - case "LOCAL": - rows, err = h.database.QueryContext(ctx, ` - SELECT id, content, scope, created_at - FROM agent_memories - WHERE workspace_id = $1 AND scope = 'LOCAL' - AND ($2 = '' OR content ILIKE '%' || $2 || '%') - ORDER BY created_at DESC LIMIT 50 - `, workspaceID, query) - case "TEAM": - // Team scope: parent + all siblings. - rows, err = h.database.QueryContext(ctx, ` - SELECT m.id, m.content, m.scope, m.created_at - FROM agent_memories m - JOIN workspaces w ON w.id = m.workspace_id - WHERE m.scope = 'TEAM' - AND w.status != 'removed' - AND (w.id = $1 OR w.parent_id = (SELECT parent_id FROM workspaces WHERE id = $1 AND parent_id IS NOT NULL)) - AND ($2 = '' OR m.content ILIKE '%' || $2 || '%') - ORDER BY m.created_at DESC LIMIT 50 - `, workspaceID, query) - default: - // Empty scope → LOCAL only for the MCP bridge (GLOBAL excluded per C3). 
- rows, err = h.database.QueryContext(ctx, ` - SELECT id, content, scope, created_at - FROM agent_memories - WHERE workspace_id = $1 AND scope IN ('LOCAL', 'TEAM') - AND ($2 = '' OR content ILIKE '%' || $2 || '%') - ORDER BY created_at DESC LIMIT 50 - `, workspaceID, query) - } - if err != nil { - return "", fmt.Errorf("memory search failed: %w", err) - } - defer rows.Close() - - type memEntry struct { - ID string `json:"id"` - Content string `json:"content"` - Scope string `json:"scope"` - CreatedAt string `json:"created_at"` - } - var results []memEntry - for rows.Next() { - var e memEntry - if err := rows.Scan(&e.ID, &e.Content, &e.Scope, &e.CreatedAt); err != nil { - continue - } - results = append(results, e) - } - if err := rows.Err(); err != nil { - return "", fmt.Errorf("memory scan error: %w", err) - } - - if len(results) == 0 { - return "No memories found.", nil - } - b, _ := json.MarshalIndent(results, "", " ") - return string(b), nil -} - -// isSafeURL and isPrivateOrMetadataIP live in a2a_proxy.go -- same package, -// shared across MCP + A2A proxy call sites. Keeping a single copy avoids -// drift between the two SSRF gates when one is tightened and the other -// isn't. - -// ───────────────────────────────────────────────────────────────────────────── -// Helpers -// ───────────────────────────────────────────────────────────────────────────── - -// mcpResolveURL returns a routable URL for a workspace's A2A server. -// -// Resolution order: -// 1. Docker-internal URL cache (set by provisioner; correct when platform is in Docker) -// 2. Redis URL cache -// 3. DB `url` column fallback, with 127.0.0.1→Docker bridge rewrite when in Docker -// -// SECURITY (F1083 / #1130): all three paths run the returned URL through -// validateAgentURL to block SSRF targets (private IPs, loopback, cloud metadata). -func mcpResolveURL(ctx context.Context, database *sql.DB, workspaceID string) (string, error) { - if platformInDocker { - if url, err := db.GetCachedInternalURL(ctx, workspaceID); err == nil && url != "" { - if err := validateAgentURL(url); err != nil { - return "", fmt.Errorf("workspace %s: forbidden URL from internal cache: %w", workspaceID, err) - } - return url, nil - } - } - if url, err := db.GetCachedURL(ctx, workspaceID); err == nil && url != "" { - if platformInDocker && strings.HasPrefix(url, "http://127.0.0.1:") { - return provisioner.InternalURL(workspaceID), nil - } - if err := validateAgentURL(url); err != nil { - return "", fmt.Errorf("workspace %s: forbidden URL from Redis cache: %w", workspaceID, err) - } - return url, nil - } - - var urlStr sql.NullString - var status string - if err := database.QueryRowContext(ctx, - `SELECT url, status FROM workspaces WHERE id = $1`, workspaceID, - ).Scan(&urlStr, &status); err != nil { - if err == sql.ErrNoRows { - return "", fmt.Errorf("workspace %s not found", workspaceID) - } - return "", fmt.Errorf("workspace lookup failed: %w", err) - } - if !urlStr.Valid || urlStr.String == "" { - return "", fmt.Errorf("workspace %s has no URL (status: %s)", workspaceID, status) - } - if platformInDocker && strings.HasPrefix(urlStr.String, "http://127.0.0.1:") { - return provisioner.InternalURL(workspaceID), nil - } - if err := validateAgentURL(urlStr.String); err != nil { - return "", fmt.Errorf("workspace %s: forbidden URL from DB: %w", workspaceID, err) - } - return urlStr.String, nil -} - -// extractA2AText extracts human-readable text from an A2A JSON-RPC response body. -// Falls back to the raw JSON when no text part can be found. 
-func extractA2AText(body []byte) string { - var resp map[string]interface{} - if err := json.Unmarshal(body, &resp); err != nil { - return string(body) - } - - // Propagate A2A errors. - if errObj, ok := resp["error"].(map[string]interface{}); ok { - if msg, ok := errObj["message"].(string); ok { - return "[error] " + msg - } - } - - result, ok := resp["result"].(map[string]interface{}) - if !ok { - return string(body) - } - - // Format 1: result.artifacts[0].parts[0].text - if artifacts, ok := result["artifacts"].([]interface{}); ok && len(artifacts) > 0 { - if art, ok := artifacts[0].(map[string]interface{}); ok { - if parts, ok := art["parts"].([]interface{}); ok && len(parts) > 0 { - if part, ok := parts[0].(map[string]interface{}); ok { - if text, ok := part["text"].(string); ok && text != "" { - return text - } - } - } - } - } - - // Format 2: result.message.parts[0].text - if msg, ok := result["message"].(map[string]interface{}); ok { - if parts, ok := msg["parts"].([]interface{}); ok && len(parts) > 0 { - if part, ok := parts[0].(map[string]interface{}); ok { - if text, ok := part["text"].(string); ok && text != "" { - return text - } - } - } - } - - // Fallback: marshal result as JSON. - b, _ := json.Marshal(result) - return string(b) -} - diff --git a/workspace-server/internal/handlers/mcp_tools.go b/workspace-server/internal/handlers/mcp_tools.go new file mode 100644 index 00000000..26df4fdd --- /dev/null +++ b/workspace-server/internal/handlers/mcp_tools.go @@ -0,0 +1,635 @@ +package handlers + +// mcp_tools.go — MCP bridge tool implementations. +// Each tool* method handles one A2A tool: list_peers, get_workspace_info, +// delegate_task, delegate_task_async, check_task_status, send_message_to_user, +// commit_memory, recall_memory. Also contains URL resolution, SSRF checks, +// and A2A response parsing helpers. 
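+//
+// For reference, the JSON-RPC body that delegate_task / delegate_task_async
+// send to the target agent has this shape (values illustrative, mirroring
+// the json.Marshal calls below):
+//
+//	{
+//	  "jsonrpc": "2.0",
+//	  "id": "<uuid>",
+//	  "method": "message/send",
+//	  "params": {
+//	    "message": {
+//	      "role": "user",
+//	      "parts": [{"type": "text", "text": "<task>"}],
+//	      "messageId": "<uuid>"
+//	    }
+//	  }
+//	}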
+ +import ( + "bytes" + "context" + "database/sql" + "encoding/json" + "fmt" + "io" + "log" + "net" + "net/http" + "net/url" + "os" + "strings" + "time" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/registry" + "github.com/google/uuid" +) +// ───────────────────────────────────────────────────────────────────────────── +// Tool implementations +// ───────────────────────────────────────────────────────────────────────────── + +func (h *MCPHandler) toolListPeers(ctx context.Context, workspaceID string) (string, error) { + var parentID sql.NullString + err := h.database.QueryRowContext(ctx, + `SELECT parent_id FROM workspaces WHERE id = $1`, workspaceID, + ).Scan(&parentID) + if err == sql.ErrNoRows { + return "", fmt.Errorf("workspace not found") + } + if err != nil { + return "", fmt.Errorf("lookup failed: %w", err) + } + + type peer struct { + ID string `json:"id"` + Name string `json:"name"` + Role string `json:"role"` + Status string `json:"status"` + Tier int `json:"tier"` + } + + var peers []peer + + scanPeers := func(rows *sql.Rows) error { + defer rows.Close() + for rows.Next() { + var p peer + if err := rows.Scan(&p.ID, &p.Name, &p.Role, &p.Status, &p.Tier); err != nil { + return err + } + peers = append(peers, p) + } + return rows.Err() + } + + const cols = `SELECT w.id, w.name, COALESCE(w.role,''), w.status, w.tier` + + // Siblings + if parentID.Valid { + rows, err := h.database.QueryContext(ctx, + cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.id != $2 AND w.status != 'removed'`, + parentID.String, workspaceID) + if err == nil { + _ = scanPeers(rows) + } + } else { + rows, err := h.database.QueryContext(ctx, + cols+` FROM workspaces w WHERE w.parent_id IS NULL AND w.id != $1 AND w.status != 'removed'`, + workspaceID) + if err == nil { + _ = scanPeers(rows) + } + } + + // Children + { + rows, err := h.database.QueryContext(ctx, + cols+` FROM workspaces w WHERE w.parent_id = $1 AND w.status != 'removed'`, + workspaceID) + if err == nil { + _ = scanPeers(rows) + } + } + + // Parent + if parentID.Valid { + rows, err := h.database.QueryContext(ctx, + cols+` FROM workspaces w WHERE w.id = $1 AND w.status != 'removed'`, + parentID.String) + if err == nil { + _ = scanPeers(rows) + } + } + + if len(peers) == 0 { + return "No peers found.", nil + } + + b, _ := json.MarshalIndent(peers, "", " ") + return string(b), nil +} + +func (h *MCPHandler) toolGetWorkspaceInfo(ctx context.Context, workspaceID string) (string, error) { + var id, name, role, status string + var tier int + var parentID sql.NullString + + err := h.database.QueryRowContext(ctx, ` + SELECT id, name, COALESCE(role,''), tier, status, parent_id + FROM workspaces WHERE id = $1 + `, workspaceID).Scan(&id, &name, &role, &tier, &status, &parentID) + if err == sql.ErrNoRows { + return "", fmt.Errorf("workspace not found") + } + if err != nil { + return "", fmt.Errorf("lookup failed: %w", err) + } + + info := map[string]interface{}{ + "id": id, + "name": name, + "role": role, + "tier": tier, + "status": status, + } + if parentID.Valid { + info["parent_id"] = parentID.String + } + b, _ := json.MarshalIndent(info, "", " ") + return string(b), nil +} + +func (h *MCPHandler) toolDelegateTask(ctx context.Context, callerID string, args map[string]interface{}, timeout time.Duration) (string, error) { + targetID, _ := args["workspace_id"].(string) + task, _ := 
args["task"].(string) + if targetID == "" { + return "", fmt.Errorf("workspace_id is required") + } + if task == "" { + return "", fmt.Errorf("task is required") + } + + if !registry.CanCommunicate(callerID, targetID) { + return "", fmt.Errorf("workspace %s is not authorised to communicate with %s", callerID, targetID) + } + + agentURL, err := mcpResolveURL(ctx, h.database, targetID) + if err != nil { + return "", err + } + // SSRF defence: reject private/metadata URLs before making outbound call. + if err := isSafeURL(agentURL); err != nil { + return "", fmt.Errorf("invalid workspace URL: %w", err) + } + + a2aBody, err := json.Marshal(map[string]interface{}{ + "jsonrpc": "2.0", + "id": uuid.New().String(), + "method": "message/send", + "params": map[string]interface{}{ + "message": map[string]interface{}{ + "role": "user", + "parts": []map[string]interface{}{{"type": "text", "text": task}}, + "messageId": uuid.New().String(), + }, + }, + }) + if err != nil { + return "", fmt.Errorf("failed to build A2A request: %w", err) + } + + reqCtx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + httpReq, err := http.NewRequestWithContext(reqCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + httpReq.Header.Set("Content-Type", "application/json") + // X-Workspace-ID identifies this caller to the A2A proxy. The /workspaces/:id/a2a + // endpoint is intentionally outside WorkspaceAuth (agents do not hold bearer tokens + // to peer workspaces). Access control is enforced by CanCommunicate above, which + // already validated callerID → targetID before this request is constructed. + // callerID was authenticated by WorkspaceAuth on the MCP bridge entry point, + // so this header reflects a verified caller identity, not a spoofable value. + httpReq.Header.Set("X-Workspace-ID", callerID) + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + return "", fmt.Errorf("A2A call failed: %w", err) + } + defer func() { _ = resp.Body.Close() }() + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + return extractA2AText(body), nil +} + +func (h *MCPHandler) toolDelegateTaskAsync(ctx context.Context, callerID string, args map[string]interface{}) (string, error) { + targetID, _ := args["workspace_id"].(string) + task, _ := args["task"].(string) + if targetID == "" { + return "", fmt.Errorf("workspace_id is required") + } + if task == "" { + return "", fmt.Errorf("task is required") + } + + if !registry.CanCommunicate(callerID, targetID) { + return "", fmt.Errorf("workspace %s is not authorised to communicate with %s", callerID, targetID) + } + + taskID := uuid.New().String() + + // Fire and forget in a detached goroutine. Use a background context so + // the call is not cancelled when the HTTP request completes. + go func() { + bgCtx, cancel := context.WithTimeout(context.Background(), mcpAsyncCallTimeout) + defer cancel() + + agentURL, err := mcpResolveURL(bgCtx, h.database, targetID) + if err != nil { + log.Printf("MCPHandler.delegate_task_async: resolve URL for %s: %v", targetID, err) + return + } + // SSRF defence: reject private/metadata URLs before making outbound call. 
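+		// In this async path a rejected URL is logged and the task is
+		// silently dropped; the caller already received
+		// {"status":"dispatched"}, so the failure typically only surfaces
+		// as "not_found" from a later check_task_status call.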
+ if err := isSafeURL(agentURL); err != nil { + log.Printf("MCPHandler.delegate_task_async: unsafe URL for %s: %v", targetID, err) + return + } + + a2aBody, _ := json.Marshal(map[string]interface{}{ + "jsonrpc": "2.0", + "id": taskID, + "method": "message/send", + "params": map[string]interface{}{ + "message": map[string]interface{}{ + "role": "user", + "parts": []map[string]interface{}{{"type": "text", "text": task}}, + "messageId": uuid.New().String(), + }, + }, + }) + + httpReq, err := http.NewRequestWithContext(bgCtx, "POST", agentURL+"/a2a", bytes.NewReader(a2aBody)) + if err != nil { + log.Printf("MCPHandler.delegate_task_async: create request: %v", err) + return + } + httpReq.Header.Set("Content-Type", "application/json") + httpReq.Header.Set("X-Workspace-ID", callerID) + + resp, err := http.DefaultClient.Do(httpReq) + if err != nil { + log.Printf("MCPHandler.delegate_task_async: A2A call to %s: %v", targetID, err) + return + } + defer func() { _ = resp.Body.Close() }() + // Drain response so the connection can be reused. + _, _ = io.Copy(io.Discard, resp.Body) + }() + + return fmt.Sprintf(`{"task_id":%q,"status":"dispatched","target_id":%q}`, taskID, targetID), nil +} + +func (h *MCPHandler) toolCheckTaskStatus(ctx context.Context, callerID string, args map[string]interface{}) (string, error) { + targetID, _ := args["workspace_id"].(string) + taskID, _ := args["task_id"].(string) + if targetID == "" { + return "", fmt.Errorf("workspace_id is required") + } + if taskID == "" { + return "", fmt.Errorf("task_id is required") + } + + var status, errorDetail sql.NullString + var responseBody []byte + + err := h.database.QueryRowContext(ctx, ` + SELECT status, error_detail, response_body + FROM activity_logs + WHERE workspace_id = $1 + AND target_id = $2 + AND request_body->>'delegation_id' = $3 + ORDER BY created_at DESC + LIMIT 1 + `, callerID, targetID, taskID).Scan(&status, &errorDetail, &responseBody) + if err == sql.ErrNoRows { + return fmt.Sprintf(`{"task_id":%q,"status":"not_found","note":"task not tracked or not yet dispatched"}`, taskID), nil + } + if err != nil { + return "", fmt.Errorf("status lookup failed: %w", err) + } + + result := map[string]interface{}{ + "task_id": taskID, + "status": status.String, + "target_id": targetID, + } + if errorDetail.Valid && errorDetail.String != "" { + result["error"] = errorDetail.String + } + if len(responseBody) > 0 { + result["result"] = extractA2AText(responseBody) + } + b, _ := json.MarshalIndent(result, "", " ") + return string(b), nil +} + +func (h *MCPHandler) toolSendMessageToUser(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) { + message, _ := args["message"].(string) + if message == "" { + return "", fmt.Errorf("message is required") + } + + // Check send_message_to_user is enabled (C3). 
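+	// The gate is process-wide (platform env), not per-workspace config:
+	// user-facing pushes stay off unless the operator opts in explicitly.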
+ if os.Getenv("MOLECULE_MCP_ALLOW_SEND_MESSAGE") != "true" { + return "", fmt.Errorf("send_message_to_user is not enabled on this MCP bridge (set MOLECULE_MCP_ALLOW_SEND_MESSAGE=true)") + } + + var wsName string + err := h.database.QueryRowContext(ctx, + `SELECT name FROM workspaces WHERE id = $1 AND status != 'removed'`, workspaceID, + ).Scan(&wsName) + if err != nil { + return "", fmt.Errorf("workspace not found") + } + + h.broadcaster.BroadcastOnly(workspaceID, "AGENT_MESSAGE", map[string]interface{}{ + "message": message, + "workspace_id": workspaceID, + "name": wsName, + }) + + return "Message sent.", nil +} + + +func (h *MCPHandler) toolCommitMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) { + content, _ := args["content"].(string) + scope, _ := args["scope"].(string) + if content == "" { + return "", fmt.Errorf("content is required") + } + if scope == "" { + scope = "LOCAL" + } + + // C3: GLOBAL scope is blocked on the MCP bridge. + if scope == "GLOBAL" { + return "", fmt.Errorf("GLOBAL scope is not permitted via the MCP bridge — use LOCAL or TEAM") + } + if scope != "LOCAL" && scope != "TEAM" { + return "", fmt.Errorf("scope must be LOCAL or TEAM") + } + + memoryID := uuid.New().String() + // SAFE-T1201 (#838): scrub known credential patterns before persistence so + // plain-text API keys pulled in via tool responses can't land in the + // memories table (and leak into shared TEAM scope). Reuses redactSecrets + // already shipped for the HTTP path in PR #881 — this was the MCP-bridge + // sibling the original fix missed. Runs on every write regardless of scope. + content, _ = redactSecrets(workspaceID, content) + _, err := h.database.ExecContext(ctx, ` + INSERT INTO agent_memories (id, workspace_id, content, scope, namespace) + VALUES ($1, $2, $3, $4, $5) + `, memoryID, workspaceID, content, scope, workspaceID) + if err != nil { + log.Printf("MCPHandler.commit_memory workspace=%s: %v", workspaceID, err) + return "", fmt.Errorf("failed to save memory") + } + + return fmt.Sprintf(`{"id":%q,"scope":%q}`, memoryID, scope), nil +} + +func (h *MCPHandler) toolRecallMemory(ctx context.Context, workspaceID string, args map[string]interface{}) (string, error) { + query, _ := args["query"].(string) + scope, _ := args["scope"].(string) + + // C3: GLOBAL scope is blocked on the MCP bridge. + if scope == "GLOBAL" { + return "", fmt.Errorf("GLOBAL scope is not permitted via the MCP bridge — use LOCAL, TEAM, or empty") + } + + var rows *sql.Rows + var err error + + switch scope { + case "LOCAL": + rows, err = h.database.QueryContext(ctx, ` + SELECT id, content, scope, created_at + FROM agent_memories + WHERE workspace_id = $1 AND scope = 'LOCAL' + AND ($2 = '' OR content ILIKE '%' || $2 || '%') + ORDER BY created_at DESC LIMIT 50 + `, workspaceID, query) + case "TEAM": + // Team scope: parent + all siblings. + rows, err = h.database.QueryContext(ctx, ` + SELECT m.id, m.content, m.scope, m.created_at + FROM agent_memories m + JOIN workspaces w ON w.id = m.workspace_id + WHERE m.scope = 'TEAM' + AND w.status != 'removed' + AND (w.id = $1 OR w.parent_id = (SELECT parent_id FROM workspaces WHERE id = $1 AND parent_id IS NOT NULL)) + AND ($2 = '' OR m.content ILIKE '%' || $2 || '%') + ORDER BY m.created_at DESC LIMIT 50 + `, workspaceID, query) + default: + // Empty scope → LOCAL only for the MCP bridge (GLOBAL excluded per C3). 
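+	// More precisely: the query below returns this workspace's own LOCAL
+	// and TEAM rows; siblings' TEAM memories require an explicit
+	// scope=TEAM call.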
+ rows, err = h.database.QueryContext(ctx, ` + SELECT id, content, scope, created_at + FROM agent_memories + WHERE workspace_id = $1 AND scope IN ('LOCAL', 'TEAM') + AND ($2 = '' OR content ILIKE '%' || $2 || '%') + ORDER BY created_at DESC LIMIT 50 + `, workspaceID, query) + } + if err != nil { + return "", fmt.Errorf("memory search failed: %w", err) + } + defer rows.Close() + + type memEntry struct { + ID string `json:"id"` + Content string `json:"content"` + Scope string `json:"scope"` + CreatedAt string `json:"created_at"` + } + var results []memEntry + for rows.Next() { + var e memEntry + if err := rows.Scan(&e.ID, &e.Content, &e.Scope, &e.CreatedAt); err != nil { + continue + } + results = append(results, e) + } + if err := rows.Err(); err != nil { + return "", fmt.Errorf("memory scan error: %w", err) + } + + if len(results) == 0 { + return "No memories found.", nil + } + b, _ := json.MarshalIndent(results, "", " ") + return string(b), nil +} + +// isSafeURL validates that a URL resolves to a publicly-routable address, +// preventing A2A requests from being redirected to internal/cloud-metadata +// infrastructure (SSRF, CWE-918). Workspace URLs come from DB/Redis caches +// so we validate before making any outbound HTTP call. +func isSafeURL(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + // Reject non-HTTP(S) schemes. + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("forbidden scheme: %s (only http/https allowed)", u.Scheme) + } + host := u.Hostname() + if host == "" { + return fmt.Errorf("empty hostname") + } + // Block direct IP addresses. + if ip := net.ParseIP(host); ip != nil { + if ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() { + return fmt.Errorf("forbidden loopback/unspecified IP: %s", ip) + } + if isPrivateOrMetadataIP(ip) { + return fmt.Errorf("forbidden private/metadata IP: %s", ip) + } + return nil + } + // For hostnames, resolve and validate each returned IP. + addrs, err := net.LookupHost(host) + if err != nil { + // DNS resolution failure — block it. Could be an internal hostname. + return fmt.Errorf("DNS resolution blocked for hostname: %s (%v)", host, err) + } + if len(addrs) == 0 { + return fmt.Errorf("DNS returned no addresses for: %s", host) + } + for _, addr := range addrs { + ip := net.ParseIP(addr) + if ip != nil && (ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || isPrivateOrMetadataIP(ip)) { + return fmt.Errorf("hostname %s resolves to forbidden IP: %s", host, ip) + } + } + return nil +} + +// isPrivateOrMetadataIP returns true for RFC-1918 private, carrier-grade NAT, +// link-local, and cloud metadata ranges. 
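+// Note: the check is IPv4-only (To4() returns nil for IPv6, so any IPv6
+// address yields false). isSafeURL separately rejects IPv6 loopback and
+// link-local via IsLoopback/IsLinkLocalUnicast, but IPv6 unique-local
+// (fc00::/7) ranges are not covered here; deployments with IPv6-internal
+// services should extend these ranges.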
+func isPrivateOrMetadataIP(ip net.IP) bool {
+	var privateRanges = []net.IPNet{
+		{IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)},
+		{IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)},
+		{IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)},
+		{IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)},
+		{IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)},
+		{IP: net.ParseIP("192.0.2.0"), Mask: net.CIDRMask(24, 32)},
+		{IP: net.ParseIP("198.51.100.0"), Mask: net.CIDRMask(24, 32)},
+		{IP: net.ParseIP("203.0.113.0"), Mask: net.CIDRMask(24, 32)},
+	}
+	ip = ip.To4()
+	if ip == nil {
+		return false
+	}
+	for _, r := range privateRanges {
+		if r.Contains(ip) {
+			return true
+		}
+	}
+	return false
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Helpers
+// ─────────────────────────────────────────────────────────────────────────────
+
+// mcpResolveURL returns a routable URL for a workspace's A2A server.
+//
+// Resolution order:
+// 1. Docker-internal URL cache (set by provisioner; correct when platform is in Docker)
+// 2. Redis URL cache
+// 3. DB `url` column fallback, with 127.0.0.1→Docker bridge rewrite when in Docker
+//
+// SECURITY (F1083 / #1130): returned URLs are run through validateAgentURL
+// to block SSRF targets (private IPs, loopback, cloud metadata). The one
+// exception is the 127.0.0.1→Docker-bridge rewrite, which skips validation
+// because provisioner.InternalURL constructs the URL itself rather than
+// trusting a cached or stored value.
+func mcpResolveURL(ctx context.Context, database *sql.DB, workspaceID string) (string, error) {
+	if platformInDocker {
+		if url, err := db.GetCachedInternalURL(ctx, workspaceID); err == nil && url != "" {
+			if err := validateAgentURL(url); err != nil {
+				return "", fmt.Errorf("workspace %s: forbidden URL from internal cache: %w", workspaceID, err)
+			}
+			return url, nil
+		}
+	}
+	if url, err := db.GetCachedURL(ctx, workspaceID); err == nil && url != "" {
+		if platformInDocker && strings.HasPrefix(url, "http://127.0.0.1:") {
+			return provisioner.InternalURL(workspaceID), nil
+		}
+		if err := validateAgentURL(url); err != nil {
+			return "", fmt.Errorf("workspace %s: forbidden URL from Redis cache: %w", workspaceID, err)
+		}
+		return url, nil
+	}
+
+	var urlStr sql.NullString
+	var status string
+	if err := database.QueryRowContext(ctx,
+		`SELECT url, status FROM workspaces WHERE id = $1`, workspaceID,
+	).Scan(&urlStr, &status); err != nil {
+		if err == sql.ErrNoRows {
+			return "", fmt.Errorf("workspace %s not found", workspaceID)
+		}
+		return "", fmt.Errorf("workspace lookup failed: %w", err)
+	}
+	if !urlStr.Valid || urlStr.String == "" {
+		return "", fmt.Errorf("workspace %s has no URL (status: %s)", workspaceID, status)
+	}
+	if platformInDocker && strings.HasPrefix(urlStr.String, "http://127.0.0.1:") {
+		return provisioner.InternalURL(workspaceID), nil
+	}
+	if err := validateAgentURL(urlStr.String); err != nil {
+		return "", fmt.Errorf("workspace %s: forbidden URL from DB: %w", workspaceID, err)
+	}
+	return urlStr.String, nil
+}
+
+// extractA2AText extracts human-readable text from an A2A JSON-RPC response body.
+// Falls back to the raw JSON when no text part can be found.
+func extractA2AText(body []byte) string {
+	var resp map[string]interface{}
+	if err := json.Unmarshal(body, &resp); err != nil {
+		return string(body)
+	}
+
+	// Propagate A2A errors.
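+	// (Shape summary: errors arrive as {"error":{"message":...}}; successes
+	// carry text at result.artifacts[0].parts[0].text or
+	// result.message.parts[0].text; anything else falls through to the
+	// raw-JSON fallback at the bottom.)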
+ if errObj, ok := resp["error"].(map[string]interface{}); ok { + if msg, ok := errObj["message"].(string); ok { + return "[error] " + msg + } + } + + result, ok := resp["result"].(map[string]interface{}) + if !ok { + return string(body) + } + + // Format 1: result.artifacts[0].parts[0].text + if artifacts, ok := result["artifacts"].([]interface{}); ok && len(artifacts) > 0 { + if art, ok := artifacts[0].(map[string]interface{}); ok { + if parts, ok := art["parts"].([]interface{}); ok && len(parts) > 0 { + if part, ok := parts[0].(map[string]interface{}); ok { + if text, ok := part["text"].(string); ok && text != "" { + return text + } + } + } + } + } + + // Format 2: result.message.parts[0].text + if msg, ok := result["message"].(map[string]interface{}); ok { + if parts, ok := msg["parts"].([]interface{}); ok && len(parts) > 0 { + if part, ok := parts[0].(map[string]interface{}); ok { + if text, ok := part["text"].(string); ok && text != "" { + return text + } + } + } + } + + // Fallback: marshal result as JSON. + b, _ := json.Marshal(result) + return string(b) +} + diff --git a/workspace-server/internal/handlers/org.go b/workspace-server/internal/handlers/org.go index cd59a142..af5ee09a 100644 --- a/workspace-server/internal/handlers/org.go +++ b/workspace-server/internal/handlers/org.go @@ -1,27 +1,21 @@ package handlers +// org.go — core org handler: types, struct, ListTemplates, Import. +// Tree creation logic is in org_import.go; utility helpers in org_helpers.go. + import ( "context" - "encoding/json" "fmt" "log" "net/http" "os" "path/filepath" - "regexp" - "sort" - "strings" - "time" "github.com/Molecule-AI/molecule-monorepo/platform/internal/channels" - "github.com/Molecule-AI/molecule-monorepo/platform/internal/crypto" - "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" "github.com/Molecule-AI/molecule-monorepo/platform/internal/events" "github.com/Molecule-AI/molecule-monorepo/platform/internal/models" "github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner" - "github.com/Molecule-AI/molecule-monorepo/platform/internal/scheduler" "github.com/gin-gonic/gin" - "github.com/google/uuid" "gopkg.in/yaml.v3" ) @@ -353,747 +347,3 @@ func (h *OrgHandler) Import(c *gin.Context) { c.JSON(status, resp) } -// createWorkspaceTree recursively creates a workspace and its children. -// provisionSem limits concurrent Docker container creation (#1084). -func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, defaults OrgDefaults, orgBaseDir string, results *[]map[string]interface{}, provisionSem chan struct{}) error { - // Apply defaults - runtime := ws.Runtime - if runtime == "" { - runtime = defaults.Runtime - } - if runtime == "" { - runtime = "langgraph" - } - model := ws.Model - if model == "" { - model = defaults.Model - } - if model == "" { - if runtime == "claude-code" { - model = "sonnet" - } else { - model = "anthropic:claude-opus-4-7" - } - } - tier := ws.Tier - if tier == 0 { - tier = defaults.Tier - } - if tier == 0 { - tier = 2 - } - - id := uuid.New().String() - awarenessNS := workspaceAwarenessNamespace(id) - - var role interface{} - if ws.Role != "" { - role = ws.Role - } - - // Expand ${VAR} references in workspace_dir against the org's .env files - // before validation. Without this, a template that ships - // `workspace_dir: ${WORKSPACE_DIR}` (so each operator can pick the host - // path to bind-mount) reaches validateWorkspaceDir as the literal - // "${WORKSPACE_DIR}" string and fails with "must be an absolute path". 
- // Other fields (channel config, prompts) already go through expandWithEnv; - // workspace_dir was the last hold-out. - if ws.WorkspaceDir != "" { - ws.WorkspaceDir = expandWithEnv(ws.WorkspaceDir, loadWorkspaceEnv(orgBaseDir, ws.FilesDir)) - } - - // Validate and convert workspace_dir to NULL if empty - var workspaceDir interface{} - if ws.WorkspaceDir != "" { - if err := validateWorkspaceDir(ws.WorkspaceDir); err != nil { - return fmt.Errorf("workspace %s: %w", ws.Name, err) - } - workspaceDir = ws.WorkspaceDir - } - - // #65: validate workspace_access (defaults to "none" when empty). - workspaceAccess := ws.WorkspaceAccess - if workspaceAccess == "" { - workspaceAccess = provisioner.WorkspaceAccessNone - } - if err := provisioner.ValidateWorkspaceAccess(workspaceAccess, ws.WorkspaceDir); err != nil { - return fmt.Errorf("workspace %s: %w", ws.Name, err) - } - - ctx := context.Background() - - // Insert workspace - _, err := db.DB.ExecContext(ctx, ` - INSERT INTO workspaces (id, name, role, tier, runtime, awareness_namespace, status, parent_id, workspace_dir, workspace_access) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) - `, id, ws.Name, role, tier, runtime, awarenessNS, "provisioning", parentID, workspaceDir, workspaceAccess) - if err != nil { - log.Printf("Org import: failed to create %s: %v", ws.Name, err) - return fmt.Errorf("failed to create %s: %w", ws.Name, err) - } - - // Canvas layout with coordinates from YAML - if _, err := db.DB.ExecContext(ctx, `INSERT INTO canvas_layouts (workspace_id, x, y) VALUES ($1, $2, $3)`, id, ws.Canvas.X, ws.Canvas.Y); err != nil { - log.Printf("Org import: canvas layout insert failed for %s: %v", ws.Name, err) - } - - // Broadcast - h.broadcaster.RecordAndBroadcast(ctx, "WORKSPACE_PROVISIONING", id, map[string]interface{}{ - "name": ws.Name, "tier": tier, - }) - - // Seed initial memories from workspace config or defaults (issue #1050). - // Per-workspace initial_memories override defaults; if workspace has none, - // fall back to defaults.initial_memories. - wsMemories := ws.InitialMemories - if len(wsMemories) == 0 { - wsMemories = defaults.InitialMemories - } - seedInitialMemories(ctx, id, wsMemories, awarenessNS) - - // Handle external workspaces - if ws.External { - if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = 'online', url = $1 WHERE id = $2`, ws.URL, id); err != nil { - log.Printf("Org import: external workspace status update failed for %s: %v", ws.Name, err) - } - h.broadcaster.RecordAndBroadcast(ctx, "WORKSPACE_ONLINE", id, map[string]interface{}{ - "name": ws.Name, "external": true, - }) - } else if h.provisioner != nil { - // Provision container - payload := models.CreateWorkspacePayload{ - Name: ws.Name, Tier: tier, Runtime: runtime, Model: model, - WorkspaceDir: ws.WorkspaceDir, - WorkspaceAccess: workspaceAccess, - } - templatePath := "" - if ws.Template != "" { - // `template` comes from the uploaded YAML — treat as untrusted. - // Only accept paths that stay inside h.configsDir. - if tp, err := resolveInsideRoot(h.configsDir, ws.Template); err == nil { - if _, statErr := os.Stat(tp); statErr == nil { - templatePath = tp - } - } - } - if templatePath == "" { - // #241: sanitizeRuntime() allowlists the runtime string so a - // crafted org.yaml cannot use it as a path-traversal oracle. 
- safeRuntime := sanitizeRuntime(runtime) - runtimeDefault := filepath.Join(h.configsDir, safeRuntime+"-default") - if _, err := os.Stat(runtimeDefault); err == nil { - templatePath = runtimeDefault - } - } - - // Always generate default config.yaml (runtime, model, tier, etc.) - configFiles := h.workspace.ensureDefaultConfig(id, payload) - - // Copy files_dir contents on top (system-prompt.md, CLAUDE.md, skills/, etc.) - // Uses templatePath for CopyTemplateToContainer — runs AFTER configFiles are written - if ws.FilesDir != "" && orgBaseDir != "" { - // `files_dir` also comes from untrusted YAML. Join inside orgBaseDir - // (already validated above) and reject anything that escapes. - if filesPath, err := resolveInsideRoot(orgBaseDir, ws.FilesDir); err == nil { - if info, statErr := os.Stat(filesPath); statErr == nil && info.IsDir() { - templatePath = filesPath - } - } - } - - // Pre-install plugins: copy from registry into configFiles as plugins//*. - // Per-workspace plugins UNION with defaults.plugins (issue #68). - // A leading "!" or "-" on a per-workspace entry opts that plugin out. - plugins := mergePlugins(defaults.Plugins, ws.Plugins) - if len(plugins) > 0 { - if configFiles == nil { - configFiles = map[string][]byte{} - } - pluginsBase, _ := filepath.Abs(filepath.Join(h.configsDir, "..", "plugins")) - for _, pluginName := range plugins { - pluginSrc := filepath.Join(pluginsBase, pluginName) - if info, err := os.Stat(pluginSrc); err != nil || !info.IsDir() { - log.Printf("Org import: plugin %s not found at %s, skipping", pluginName, pluginSrc) - continue - } - filepath.Walk(pluginSrc, func(path string, info os.FileInfo, err error) error { - if err != nil || info.IsDir() { - return nil - } - rel, _ := filepath.Rel(pluginSrc, path) - data, readErr := os.ReadFile(path) - if readErr == nil { - configFiles["plugins/"+pluginName+"/"+rel] = data - } - return nil - }) - } - } - - // Render category_routing into config.yaml so the agent can read its routing - // table at runtime without hardcoded role names in prompts (issue #51). - // Per-workspace keys replace defaults per-key (empty list drops the key); - // see mergeCategoryRouting for exact semantics. - routing := mergeCategoryRouting(defaults.CategoryRouting, ws.CategoryRouting) - if len(routing) > 0 { - if configFiles == nil { - configFiles = map[string][]byte{} - } - block, err := renderCategoryRoutingYAML(routing) - if err != nil { - log.Printf("Org import: failed to render category_routing for %s: %v", ws.Name, err) - } else { - configFiles["config.yaml"] = appendYAMLBlock(configFiles["config.yaml"], block) - } - } - - // Resolve initial_prompt — inline wins, then file-ref, then defaults - // (inline → file → defaults.inline → defaults.file). File refs are - // rooted at // per resolvePromptRef semantics. - initialPrompt, err := resolvePromptRef(ws.InitialPrompt, ws.InitialPromptFile, orgBaseDir, ws.FilesDir) - if err != nil { - log.Printf("Org import: failed to resolve initial_prompt for %s: %v", ws.Name, err) - } - if initialPrompt == "" { - // Fall back to defaults. Defaults live at the org root, so they - // resolve with empty filesDir (relative to orgBaseDir). 
- var defaultErr error - initialPrompt, defaultErr = resolvePromptRef(defaults.InitialPrompt, defaults.InitialPromptFile, orgBaseDir, "") - if defaultErr != nil { - log.Printf("Org import: failed to resolve defaults.initial_prompt for %s: %v", ws.Name, defaultErr) - } - } - if initialPrompt != "" { - if configFiles == nil { - configFiles = map[string][]byte{} - } - // Append initial_prompt to config.yaml using YAML block scalar. - // Trim each line to avoid trailing whitespace issues. - trimmed := strings.TrimSpace(initialPrompt) - lines := strings.Split(trimmed, "\n") - for i, line := range lines { - lines[i] = strings.TrimRight(line, " \t") - } - indented := strings.Join(lines, "\n ") - configFiles["config.yaml"] = appendYAMLBlock(configFiles["config.yaml"], fmt.Sprintf("initial_prompt: |\n %s\n", indented)) - log.Printf("Org import: injected initial_prompt (%d chars) into config.yaml for %s", len(trimmed), ws.Name) - } - - // Resolve idle_prompt — same precedence (ws inline → ws file → defaults). - // Inject into config.yaml alongside idle_interval_seconds so the - // workspace's heartbeat loop picks up the idle-reflection cadence on - // boot (see workspace/heartbeat.py + config.py). - idlePrompt, err := resolvePromptRef(ws.IdlePrompt, ws.IdlePromptFile, orgBaseDir, ws.FilesDir) - if err != nil { - log.Printf("Org import: failed to resolve idle_prompt for %s: %v", ws.Name, err) - } - if idlePrompt == "" { - var defaultErr error - idlePrompt, defaultErr = resolvePromptRef(defaults.IdlePrompt, defaults.IdlePromptFile, orgBaseDir, "") - if defaultErr != nil { - log.Printf("Org import: failed to resolve defaults.idle_prompt for %s: %v", ws.Name, defaultErr) - } - } - idleInterval := ws.IdleIntervalSeconds - if idleInterval == 0 { - idleInterval = defaults.IdleIntervalSeconds - } - if idlePrompt != "" { - if configFiles == nil { - configFiles = map[string][]byte{} - } - trimmed := strings.TrimSpace(idlePrompt) - lines := strings.Split(trimmed, "\n") - for i, line := range lines { - lines[i] = strings.TrimRight(line, " \t") - } - indented := strings.Join(lines, "\n ") - // idle_interval_seconds belongs with idle_prompt — empty idle_prompt - // means the idle loop never fires regardless of interval, so we - // only emit interval when there's a body to go with it. - if idleInterval <= 0 { - idleInterval = 600 // same default as workspace/config.py - } - block := fmt.Sprintf("idle_interval_seconds: %d\nidle_prompt: |\n %s\n", idleInterval, indented) - configFiles["config.yaml"] = appendYAMLBlock(configFiles["config.yaml"], block) - log.Printf("Org import: injected idle_prompt (%d chars, interval=%ds) into config.yaml for %s", len(trimmed), idleInterval, ws.Name) - } - - // Inline system_prompt (only if no files_dir provides one) - if ws.SystemPrompt != "" { - if configFiles == nil { - configFiles = map[string][]byte{} - } - configFiles["system-prompt.md"] = []byte(ws.SystemPrompt) - } - - // Inject secrets from .env files as workspace secrets. - // Resolution: workspace .env → org root .env (workspace overrides org root). - // Each line: KEY=VALUE → stored as encrypted workspace secret. - envVars := map[string]string{} - if orgBaseDir != "" { - // 1. Org root .env (shared defaults) - parseEnvFile(filepath.Join(orgBaseDir, ".env"), envVars) - // 2. 
Workspace-specific .env (overrides) - if ws.FilesDir != "" { - parseEnvFile(filepath.Join(orgBaseDir, ws.FilesDir, ".env"), envVars) - } - } - // Store as workspace secrets via DB (encrypted if key is set, raw otherwise) - for key, value := range envVars { - var encrypted []byte - if crypto.IsEnabled() { - var err error - encrypted, err = crypto.Encrypt([]byte(value)) - if err != nil { - log.Printf("Org import: failed to encrypt secret %s for %s: %v", key, ws.Name, err) - continue - } - } else { - encrypted = []byte(value) // store raw when encryption disabled - } - if _, err := db.DB.ExecContext(ctx, ` - INSERT INTO workspace_secrets (workspace_id, key, encrypted_value) - VALUES ($1, $2, $3) - ON CONFLICT (workspace_id, key) DO UPDATE SET encrypted_value = $3, updated_at = now() - `, id, key, encrypted); err != nil { - log.Printf("Org import: failed to insert secret %s for %s: %v", key, ws.Name, err) - } - } - - // #1084: limit concurrent Docker provisioning via semaphore. - provisionSem <- struct{}{} // acquire - go func(wID, tPath string, cFiles map[string][]byte, p models.CreateWorkspacePayload) { - defer func() { <-provisionSem }() // release - h.workspace.provisionWorkspace(wID, tPath, cFiles, p) - }(id, templatePath, configFiles, payload) - } - - // Insert schedules if defined. Resolve each schedule's prompt body from - // either inline `prompt:` or `prompt_file:` (file ref relative to the - // workspace's files_dir). Inline wins; empty prompt after resolution is - // a configuration error (cron with no body would never do anything). - for _, sched := range ws.Schedules { - tz := sched.Timezone - if tz == "" { - tz = "UTC" - } - enabled := true - if sched.Enabled != nil { - enabled = *sched.Enabled - } - prompt, promptErr := resolvePromptRef(sched.Prompt, sched.PromptFile, orgBaseDir, ws.FilesDir) - if promptErr != nil { - log.Printf("Org import: failed to resolve prompt for schedule '%s' on %s: %v — skipping insert", sched.Name, ws.Name, promptErr) - continue - } - if prompt == "" { - log.Printf("Org import: schedule '%s' on %s has empty prompt (neither prompt nor prompt_file set) — skipping insert", sched.Name, ws.Name) - continue - } - // #722: surface the error rather than silently using time.Time{} (zero) - // which lib/pq stores as 0001-01-01 and may confuse the fire query. - nextRun, nextRunErr := scheduler.ComputeNextRun(sched.CronExpr, tz, time.Now()) - if nextRunErr != nil { - log.Printf("Org import: invalid cron expression for schedule '%s' on %s: %v — skipping insert", - sched.Name, ws.Name, nextRunErr) - continue - } - if _, err := db.DB.ExecContext(context.Background(), orgImportScheduleSQL, - id, sched.Name, sched.CronExpr, tz, prompt, enabled, nextRun); err != nil { - log.Printf("Org import: failed to upsert schedule '%s' for %s: %v", sched.Name, ws.Name, err) - } else { - log.Printf("Org import: schedule '%s' (%s, %d chars) upserted for %s (source=template)", sched.Name, sched.CronExpr, len(prompt), ws.Name) - } - } - - // Insert channels if defined (Telegram, Slack, etc.). Config values - // support ${VAR} expansion from .env files. The manager is reloaded - // once at the end of org import (in Import), not per-workspace. - channelEnv := loadWorkspaceEnv(orgBaseDir, ws.FilesDir) - wsChannelsCreated := []string{} - wsChannelsSkipped := []map[string]string{} - // skipChannel records a skipped channel with consistent shape across all reasons. 
- skipChannel := func(channelType, reason string) { - wsChannelsSkipped = append(wsChannelsSkipped, map[string]string{ - "workspace": ws.Name, - "type": channelType, // empty string when type field was missing - "reason": reason, - }) - } - - for _, ch := range ws.Channels { - if ch.Type == "" { - skipChannel("", "empty type") - log.Printf("Org import: skipping channel with empty type for %s", ws.Name) - continue - } - // Validate adapter exists upfront — fail fast instead of inserting orphan rows - adapter, ok := channels.GetAdapter(ch.Type) - if !ok { - skipChannel(ch.Type, "unknown adapter") - log.Printf("Org import: skipping %s channel for %s — no adapter registered", ch.Type, ws.Name) - continue - } - - expandedConfig := make(map[string]interface{}, len(ch.Config)) - missing := []string{} - for k, v := range ch.Config { - expanded := expandWithEnv(v, channelEnv) - if hasUnresolvedVarRef(v, expanded) { - missing = append(missing, v) - } - expandedConfig[k] = expanded - } - if len(missing) > 0 { - skipChannel(ch.Type, fmt.Sprintf("missing env: %v", missing)) - log.Printf("Org import: skipping %s channel for %s — env vars not set: %v", ch.Type, ws.Name, missing) - continue - } - - // Adapter-level config validation - if err := adapter.ValidateConfig(expandedConfig); err != nil { - skipChannel(ch.Type, err.Error()) - log.Printf("Org import: skipping %s channel for %s — invalid config: %v", ch.Type, ws.Name, err) - continue - } - - configJSON, err := json.Marshal(expandedConfig) - if err != nil { - log.Printf("Org import: failed to marshal config for %s channel: %v", ch.Type, err) - continue - } - allowedJSON, err := json.Marshal(ch.AllowedUsers) - if err != nil { - log.Printf("Org import: failed to marshal allowed_users for %s channel: %v", ch.Type, err) - continue - } - enabled := true - if ch.Enabled != nil { - enabled = *ch.Enabled - } - // Idempotent insert — if same workspace+type already exists, update config - if _, err := db.DB.ExecContext(context.Background(), ` - INSERT INTO workspace_channels (workspace_id, channel_type, channel_config, enabled, allowed_users) - VALUES ($1, $2, $3::jsonb, $4, $5::jsonb) - ON CONFLICT (workspace_id, channel_type) DO UPDATE - SET channel_config = EXCLUDED.channel_config, - enabled = EXCLUDED.enabled, - allowed_users = EXCLUDED.allowed_users, - updated_at = now() - `, id, ch.Type, string(configJSON), enabled, string(allowedJSON)); err != nil { - log.Printf("Org import: failed to create %s channel for %s: %v", ch.Type, ws.Name, err) - } else { - wsChannelsCreated = append(wsChannelsCreated, ch.Type) - log.Printf("Org import: %s channel created for %s", ch.Type, ws.Name) - } - } - - resultEntry := map[string]interface{}{ - "id": id, - "name": ws.Name, - "tier": tier, - } - if len(wsChannelsCreated) > 0 { - resultEntry["channels"] = wsChannelsCreated - } - if len(wsChannelsSkipped) > 0 { - resultEntry["channels_skipped"] = wsChannelsSkipped - } - *results = append(*results, resultEntry) - - // Recurse into children. Brief pacing avoids overwhelming Docker when - // creating many containers in sequence; container provisioning runs in - // goroutines so the main createWorkspaceTree returns quickly. 
- for _, child := range ws.Children { - if err := h.createWorkspaceTree(child, &id, defaults, orgBaseDir, results, provisionSem); err != nil { - return err - } - time.Sleep(workspaceCreatePacingMs * time.Millisecond) - } - - return nil -} - -func countWorkspaces(workspaces []OrgWorkspace) int { - count := len(workspaces) - for _, ws := range workspaces { - count += countWorkspaces(ws.Children) - } - return count -} - -// resolvePromptRef reads a prompt body from either an inline string or a -// file ref relative to the workspace's files_dir. Inline always wins when -// both are non-empty (caller-provided inline is more authoritative than a -// file path that may not exist yet during dev loops). -// -// File resolution: -// - `//` when filesDir is non-empty -// - `/` when filesDir is empty (defaults-level refs) -// -// Both paths go through resolveInsideRoot so a crafted fileRef can't escape -// the org template directory via traversal (same defense the files_dir -// copy-step uses). -// -// Returns (resolved body, error). If both inline and fileRef are empty, -// returns ("", nil) — caller decides whether that's a problem. -func resolvePromptRef(inline, fileRef, orgBaseDir, filesDir string) (string, error) { - if inline != "" { - return inline, nil - } - if fileRef == "" { - return "", nil - } - if orgBaseDir == "" { - // Inline-only template (POST /org/import with a raw Template in the - // JSON body, not a dir). File refs can't be resolved — surface the - // problem rather than silently returning empty. - return "", fmt.Errorf("prompt_file %q requires a dir-based org template (no orgBaseDir in inline-template mode)", fileRef) - } - searchRoot := orgBaseDir - if filesDir != "" { - p, err := resolveInsideRoot(orgBaseDir, filesDir) - if err != nil { - return "", fmt.Errorf("invalid files_dir %q: %w", filesDir, err) - } - searchRoot = p - } - abs, err := resolveInsideRoot(searchRoot, fileRef) - if err != nil { - return "", fmt.Errorf("invalid prompt_file %q: %w", fileRef, err) - } - data, err := os.ReadFile(abs) - if err != nil { - return "", fmt.Errorf("read prompt_file %q: %w", fileRef, err) - } - return string(data), nil -} - -// envVarRefPattern matches actual ${VAR} or $VAR references (not literal $). -// Used to detect unresolved placeholders without false positives like "$5". -var envVarRefPattern = regexp.MustCompile(`\$\{?[A-Za-z_][A-Za-z0-9_]*\}?`) - -// hasUnresolvedVarRef returns true if the original string had a ${VAR} or $VAR -// reference that the expanded string didn't fully replace (i.e. the var was unset). -func hasUnresolvedVarRef(original, expanded string) bool { - if !envVarRefPattern.MatchString(original) { - return false // no var refs to resolve - } - // If expansion produced the same string and that string still has refs, unresolved. - // If expansion stripped them to "", also unresolved. - return expanded == "" || envVarRefPattern.MatchString(expanded) -} - -// expandWithEnv expands ${VAR} and $VAR references in s using the env map. -// Falls back to the platform process env if a var isn't in the map. -func expandWithEnv(s string, env map[string]string) string { - return os.Expand(s, func(key string) string { - if v, ok := env[key]; ok { - return v - } - return os.Getenv(key) - }) -} - -// loadWorkspaceEnv reads the org root .env and the workspace-specific .env -// (workspace overrides org root). Used by both secret injection and channel -// config expansion. 
-func loadWorkspaceEnv(orgBaseDir, filesDir string) map[string]string { - envVars := map[string]string{} - if orgBaseDir == "" { - return envVars - } - parseEnvFile(filepath.Join(orgBaseDir, ".env"), envVars) - if filesDir != "" { - parseEnvFile(filepath.Join(orgBaseDir, filesDir, ".env"), envVars) - } - return envVars -} - -// parseEnvFile reads a .env file and adds KEY=VALUE pairs to the map. -// Skips comments (#) and empty lines. Values can be quoted. -func parseEnvFile(path string, out map[string]string) { - data, err := os.ReadFile(path) - if err != nil { - return - } - for _, line := range strings.Split(string(data), "\n") { - line = strings.TrimSpace(line) - if line == "" || strings.HasPrefix(line, "#") { - continue - } - parts := strings.SplitN(line, "=", 2) - if len(parts) != 2 { - continue - } - key := strings.TrimSpace(parts[0]) - value := strings.TrimSpace(parts[1]) - // Strip surrounding quotes - if len(value) >= 2 && ((value[0] == '"' && value[len(value)-1] == '"') || (value[0] == '\'' && value[len(value)-1] == '\'')) { - value = value[1 : len(value)-1] - } - if key != "" && value != "" { - out[key] = value - } - } -} - -// mergeCategoryRouting unions defaults.category_routing with per-workspace -// category_routing. Workspace-level keys override the default's value for that -// key (the role list is replaced wholesale, not unioned per-key, so a workspace -// can narrow a category — e.g. "infra: [DevOps Only]"). Empty role lists drop -// the category entirely. See issue #51. -func mergeCategoryRouting(defaultRouting, wsRouting map[string][]string) map[string][]string { - out := map[string][]string{} - for k, v := range defaultRouting { - if k == "" || len(v) == 0 { - continue - } - cp := make([]string, len(v)) - copy(cp, v) - out[k] = cp - } - for k, v := range wsRouting { - if k == "" { - continue - } - if len(v) == 0 { - // Empty list = explicit "drop this category for this workspace" - delete(out, k) - continue - } - cp := make([]string, len(v)) - copy(cp, v) - out[k] = cp - } - return out -} - -// renderCategoryRoutingYAML emits a deterministic YAML block of the form: -// -// category_routing: -// security: [Backend Engineer, DevOps] -// ui: [Frontend Engineer] -// -// Keys are sorted for stable, test-friendly output. Uses yaml.Node + yaml.Marshal -// so role names containing YAML-reserved characters (colons, quotes, unicode line -// separators, etc.) are escaped by the YAML library — no ad-hoc quoting. -func renderCategoryRoutingYAML(routing map[string][]string) (string, error) { - if len(routing) == 0 { - return "", nil - } - keys := make([]string, 0, len(routing)) - for k := range routing { - if k == "" { - continue - } - keys = append(keys, k) - } - sort.Strings(keys) - - inner := &yaml.Node{Kind: yaml.MappingNode} - for _, k := range keys { - keyNode := &yaml.Node{Kind: yaml.ScalarNode, Value: k} - valNode := &yaml.Node{Kind: yaml.SequenceNode, Style: yaml.FlowStyle} - for _, role := range routing[k] { - valNode.Content = append(valNode.Content, &yaml.Node{Kind: yaml.ScalarNode, Value: role}) - } - inner.Content = append(inner.Content, keyNode, valNode) - } - doc := &yaml.Node{Kind: yaml.MappingNode} - doc.Content = []*yaml.Node{ - {Kind: yaml.ScalarNode, Value: "category_routing"}, - inner, - } - out, err := yaml.Marshal(doc) - if err != nil { - return "", err - } - return string(out), nil -} - -// appendYAMLBlock concatenates a YAML fragment to an existing buffer, guaranteeing -// a newline boundary between them. 
Upstream code writes config.yaml in fragments -// (base template → category_routing → initial_prompt) and the base isn't -// guaranteed to end in \n, which would merge the last line into the next block. -func appendYAMLBlock(existing []byte, block string) []byte { - if len(existing) > 0 && existing[len(existing)-1] != '\n' { - existing = append(existing, '\n') - } - return append(existing, []byte(block)...) -} - -// mergePlugins returns the union of defaults and per-workspace plugin lists -// (deduplicated, defaults first). A per-workspace entry starting with "!" or -// "-" opts that plugin OUT of the union. See issue #68. -func mergePlugins(defaultPlugins, wsPlugins []string) []string { - seen := map[string]bool{} - out := make([]string, 0, len(defaultPlugins)+len(wsPlugins)) - for _, p := range defaultPlugins { - if p == "" || seen[p] { - continue - } - seen[p] = true - out = append(out, p) - } - for _, p := range wsPlugins { - if p == "" { - continue - } - if strings.HasPrefix(p, "!") || strings.HasPrefix(p, "-") { - target := strings.TrimLeft(p, "!-") - if target == "" { - continue - } - if seen[target] { - delete(seen, target) - filtered := out[:0] - for _, existing := range out { - if existing != target { - filtered = append(filtered, existing) - } - } - out = filtered - } - continue - } - if !seen[p] { - seen[p] = true - out = append(out, p) - } - } - return out -} - -// resolveInsideRoot joins `userPath` onto `root` and ensures the lexically -// cleaned result stays inside root. Rejects absolute paths outright and -// anything containing ".." that would escape the root. -// -// Both arguments are resolved to absolute paths via filepath.Abs before the -// prefix check so a root passed as a relative path still works correctly. -// Follows Go's standard pattern for SSRF-class path sanitization; using -// strings.HasPrefix on an absolute-path pair plus the separator guard rejects -// sibling directories that share a prefix (e.g. "/foo" vs "/foobar"). -func resolveInsideRoot(root, userPath string) (string, error) { - if userPath == "" { - return "", fmt.Errorf("path is empty") - } - if filepath.IsAbs(userPath) { - return "", fmt.Errorf("absolute paths are not allowed") - } - absRoot, err := filepath.Abs(root) - if err != nil { - return "", fmt.Errorf("root abs: %w", err) - } - joined := filepath.Join(absRoot, userPath) - absJoined, err := filepath.Abs(joined) - if err != nil { - return "", fmt.Errorf("joined abs: %w", err) - } - // Allow exact-root match (rare but valid) and any descendant. - if absJoined != absRoot && !strings.HasPrefix(absJoined, absRoot+string(filepath.Separator)) { - return "", fmt.Errorf("path escapes root") - } - return absJoined, nil -} diff --git a/workspace-server/internal/handlers/org_helpers.go b/workspace-server/internal/handlers/org_helpers.go new file mode 100644 index 00000000..f84baf3d --- /dev/null +++ b/workspace-server/internal/handlers/org_helpers.go @@ -0,0 +1,290 @@ +package handlers + +// org_helpers.go — utility functions for org template processing. +// Prompt resolution, env file parsing, category routing, plugin merging, +// path sanitization. + +import ( + "fmt" + "os" + "path/filepath" + "regexp" + "sort" + "strings" + + "gopkg.in/yaml.v3" +) +// resolvePromptRef reads a prompt body from either an inline string or a +// file ref relative to the workspace's files_dir. Inline always wins when +// both are non-empty (caller-provided inline is more authoritative than a +// file path that may not exist yet during dev loops). 
+//
+// File resolution:
+// - `<orgBaseDir>/<filesDir>/<fileRef>` when filesDir is non-empty
+// - `<orgBaseDir>/<fileRef>` when filesDir is empty (defaults-level refs)
+//
+// Both paths go through resolveInsideRoot so a crafted fileRef can't escape
+// the org template directory via traversal (same defense the files_dir
+// copy-step uses).
+//
+// Returns (resolved body, error). If both inline and fileRef are empty,
+// returns ("", nil) — caller decides whether that's a problem.
+func resolvePromptRef(inline, fileRef, orgBaseDir, filesDir string) (string, error) {
+	if inline != "" {
+		return inline, nil
+	}
+	if fileRef == "" {
+		return "", nil
+	}
+	if orgBaseDir == "" {
+		// Inline-only template (POST /org/import with a raw Template in the
+		// JSON body, not a dir). File refs can't be resolved — surface the
+		// problem rather than silently returning empty.
+		return "", fmt.Errorf("prompt_file %q requires a dir-based org template (no orgBaseDir in inline-template mode)", fileRef)
+	}
+	searchRoot := orgBaseDir
+	if filesDir != "" {
+		p, err := resolveInsideRoot(orgBaseDir, filesDir)
+		if err != nil {
+			return "", fmt.Errorf("invalid files_dir %q: %w", filesDir, err)
+		}
+		searchRoot = p
+	}
+	abs, err := resolveInsideRoot(searchRoot, fileRef)
+	if err != nil {
+		return "", fmt.Errorf("invalid prompt_file %q: %w", fileRef, err)
+	}
+	data, err := os.ReadFile(abs)
+	if err != nil {
+		return "", fmt.Errorf("read prompt_file %q: %w", fileRef, err)
+	}
+	return string(data), nil
+}
+
+// envVarRefPattern matches actual ${VAR} or $VAR references (not literal $).
+// Used to detect unresolved placeholders without false positives like "$5".
+var envVarRefPattern = regexp.MustCompile(`\$\{?[A-Za-z_][A-Za-z0-9_]*\}?`)
+
+// hasUnresolvedVarRef returns true if the original string had a ${VAR} or $VAR
+// reference that the expanded string didn't fully replace (i.e. the var was unset).
+func hasUnresolvedVarRef(original, expanded string) bool {
+	if !envVarRefPattern.MatchString(original) {
+		return false // no var refs to resolve
+	}
+	// If expansion produced the same string and that string still has refs, unresolved.
+	// If expansion stripped them to "", also unresolved.
+	// Known limitation: a ref stripped mid-string (e.g. "$FOO bar" expanding
+	// to " bar" when FOO is unset) is not caught; the heuristic only covers
+	// the all-refs and empty-result cases.
+	return expanded == "" || envVarRefPattern.MatchString(expanded)
+}
+
+// expandWithEnv expands ${VAR} and $VAR references in s using the env map.
+// Falls back to the platform process env if a var isn't in the map.
+func expandWithEnv(s string, env map[string]string) string {
+	return os.Expand(s, func(key string) string {
+		if v, ok := env[key]; ok {
+			return v
+		}
+		return os.Getenv(key)
+	})
+}
+
+// loadWorkspaceEnv reads the org root .env and the workspace-specific .env
+// (workspace overrides org root). Used by both secret injection and channel
+// config expansion.
+func loadWorkspaceEnv(orgBaseDir, filesDir string) map[string]string {
+	envVars := map[string]string{}
+	if orgBaseDir == "" {
+		return envVars
+	}
+	parseEnvFile(filepath.Join(orgBaseDir, ".env"), envVars)
+	if filesDir != "" {
+		parseEnvFile(filepath.Join(orgBaseDir, filesDir, ".env"), envVars)
+	}
+	return envVars
+}
+
+// parseEnvFile reads a .env file and adds KEY=VALUE pairs to the map.
+// Skips comments (#) and empty lines. Values can be quoted.
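+//
+// For example (hypothetical file contents), these lines:
+//
+//	# comment, skipped
+//	API_KEY="s3cr3t"
+//	EMPTY=
+//
+// yield map[string]string{"API_KEY": "s3cr3t"}; the surrounding quotes are
+// stripped and EMPTY is dropped because blank values are skipped.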
+func parseEnvFile(path string, out map[string]string) { + data, err := os.ReadFile(path) + if err != nil { + return + } + for _, line := range strings.Split(string(data), "\n") { + line = strings.TrimSpace(line) + if line == "" || strings.HasPrefix(line, "#") { + continue + } + parts := strings.SplitN(line, "=", 2) + if len(parts) != 2 { + continue + } + key := strings.TrimSpace(parts[0]) + value := strings.TrimSpace(parts[1]) + // Strip surrounding quotes + if len(value) >= 2 && ((value[0] == '"' && value[len(value)-1] == '"') || (value[0] == '\'' && value[len(value)-1] == '\'')) { + value = value[1 : len(value)-1] + } + if key != "" && value != "" { + out[key] = value + } + } +} + +// mergeCategoryRouting unions defaults.category_routing with per-workspace +// category_routing. Workspace-level keys override the default's value for that +// key (the role list is replaced wholesale, not unioned per-key, so a workspace +// can narrow a category — e.g. "infra: [DevOps Only]"). Empty role lists drop +// the category entirely. See issue #51. +func mergeCategoryRouting(defaultRouting, wsRouting map[string][]string) map[string][]string { + out := map[string][]string{} + for k, v := range defaultRouting { + if k == "" || len(v) == 0 { + continue + } + cp := make([]string, len(v)) + copy(cp, v) + out[k] = cp + } + for k, v := range wsRouting { + if k == "" { + continue + } + if len(v) == 0 { + // Empty list = explicit "drop this category for this workspace" + delete(out, k) + continue + } + cp := make([]string, len(v)) + copy(cp, v) + out[k] = cp + } + return out +} + +// renderCategoryRoutingYAML emits a deterministic YAML block of the form: +// +// category_routing: +// security: [Backend Engineer, DevOps] +// ui: [Frontend Engineer] +// +// Keys are sorted for stable, test-friendly output. Uses yaml.Node + yaml.Marshal +// so role names containing YAML-reserved characters (colons, quotes, unicode line +// separators, etc.) are escaped by the YAML library — no ad-hoc quoting. +func renderCategoryRoutingYAML(routing map[string][]string) (string, error) { + if len(routing) == 0 { + return "", nil + } + keys := make([]string, 0, len(routing)) + for k := range routing { + if k == "" { + continue + } + keys = append(keys, k) + } + sort.Strings(keys) + + inner := &yaml.Node{Kind: yaml.MappingNode} + for _, k := range keys { + keyNode := &yaml.Node{Kind: yaml.ScalarNode, Value: k} + valNode := &yaml.Node{Kind: yaml.SequenceNode, Style: yaml.FlowStyle} + for _, role := range routing[k] { + valNode.Content = append(valNode.Content, &yaml.Node{Kind: yaml.ScalarNode, Value: role}) + } + inner.Content = append(inner.Content, keyNode, valNode) + } + doc := &yaml.Node{Kind: yaml.MappingNode} + doc.Content = []*yaml.Node{ + {Kind: yaml.ScalarNode, Value: "category_routing"}, + inner, + } + out, err := yaml.Marshal(doc) + if err != nil { + return "", err + } + return string(out), nil +} + +// appendYAMLBlock concatenates a YAML fragment to an existing buffer, guaranteeing +// a newline boundary between them. Upstream code writes config.yaml in fragments +// (base template → category_routing → initial_prompt) and the base isn't +// guaranteed to end in \n, which would merge the last line into the next block. +func appendYAMLBlock(existing []byte, block string) []byte { + if len(existing) > 0 && existing[len(existing)-1] != '\n' { + existing = append(existing, '\n') + } + return append(existing, []byte(block)...) 
+}
+
+// mergePlugins returns the union of defaults and per-workspace plugin lists
+// (deduplicated, defaults first). A per-workspace entry starting with "!" or
+// "-" opts that plugin OUT of the union. See issue #68.
+func mergePlugins(defaultPlugins, wsPlugins []string) []string {
+	seen := map[string]bool{}
+	out := make([]string, 0, len(defaultPlugins)+len(wsPlugins))
+	for _, p := range defaultPlugins {
+		if p == "" || seen[p] {
+			continue
+		}
+		seen[p] = true
+		out = append(out, p)
+	}
+	for _, p := range wsPlugins {
+		if p == "" {
+			continue
+		}
+		if strings.HasPrefix(p, "!") || strings.HasPrefix(p, "-") {
+			target := strings.TrimLeft(p, "!-")
+			if target == "" {
+				continue
+			}
+			if seen[target] {
+				delete(seen, target)
+				filtered := out[:0]
+				for _, existing := range out {
+					if existing != target {
+						filtered = append(filtered, existing)
+					}
+				}
+				out = filtered
+			}
+			continue
+		}
+		if !seen[p] {
+			seen[p] = true
+			out = append(out, p)
+		}
+	}
+	return out
+}
+
+// resolveInsideRoot joins `userPath` onto `root` and ensures the lexically
+// cleaned result stays inside root. Rejects absolute paths outright and
+// anything containing ".." that would escape the root.
+//
+// Both arguments are resolved to absolute paths via filepath.Abs before the
+// prefix check so a root passed as a relative path still works correctly.
+// Follows Go's standard pattern for path-traversal sanitization (CWE-22);
+// using strings.HasPrefix on an absolute-path pair plus the separator guard
+// rejects sibling directories that share a prefix (e.g. "/foo" vs "/foobar").
+func resolveInsideRoot(root, userPath string) (string, error) {
+	if userPath == "" {
+		return "", fmt.Errorf("path is empty")
+	}
+	if filepath.IsAbs(userPath) {
+		return "", fmt.Errorf("absolute paths are not allowed")
+	}
+	absRoot, err := filepath.Abs(root)
+	if err != nil {
+		return "", fmt.Errorf("root abs: %w", err)
+	}
+	joined := filepath.Join(absRoot, userPath)
+	absJoined, err := filepath.Abs(joined)
+	if err != nil {
+		return "", fmt.Errorf("joined abs: %w", err)
+	}
+	// Allow exact-root match (rare but valid) and any descendant.
+	if absJoined != absRoot && !strings.HasPrefix(absJoined, absRoot+string(filepath.Separator)) {
+		return "", fmt.Errorf("path escapes root")
+	}
+	return absJoined, nil
+}
diff --git a/workspace-server/internal/handlers/org_import.go b/workspace-server/internal/handlers/org_import.go
new file mode 100644
index 00000000..442f5836
--- /dev/null
+++ b/workspace-server/internal/handlers/org_import.go
@@ -0,0 +1,490 @@
+package handlers
+
+// org_import.go — workspace tree creation during org template import.
+// Contains createWorkspaceTree (recursive provisioning) and countWorkspaces.
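+//
+// createWorkspaceTree walks the YAML tree depth-first. For each node it:
+// inserts the DB row → seeds memories → resolves prompts/plugins/routing
+// into config files → kicks off container provisioning (async,
+// semaphore-gated) → upserts schedules and channels → recurses into children.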
+ +import ( + "context" + "encoding/json" + "fmt" + "log" + "os" + "path/filepath" + "strings" + "time" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/channels" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/crypto" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/models" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/scheduler" + "github.com/google/uuid" +) +func (h *OrgHandler) createWorkspaceTree(ws OrgWorkspace, parentID *string, defaults OrgDefaults, orgBaseDir string, results *[]map[string]interface{}, provisionSem chan struct{}) error { + // Apply defaults + runtime := ws.Runtime + if runtime == "" { + runtime = defaults.Runtime + } + if runtime == "" { + runtime = "langgraph" + } + model := ws.Model + if model == "" { + model = defaults.Model + } + if model == "" { + if runtime == "claude-code" { + model = "sonnet" + } else { + model = "anthropic:claude-opus-4-7" + } + } + tier := ws.Tier + if tier == 0 { + tier = defaults.Tier + } + if tier == 0 { + tier = 2 + } + + id := uuid.New().String() + awarenessNS := workspaceAwarenessNamespace(id) + + var role interface{} + if ws.Role != "" { + role = ws.Role + } + + // Expand ${VAR} references in workspace_dir against the org's .env files + // before validation. Without this, a template that ships + // `workspace_dir: ${WORKSPACE_DIR}` (so each operator can pick the host + // path to bind-mount) reaches validateWorkspaceDir as the literal + // "${WORKSPACE_DIR}" string and fails with "must be an absolute path". + // Other fields (channel config, prompts) already go through expandWithEnv; + // workspace_dir was the last hold-out. + if ws.WorkspaceDir != "" { + ws.WorkspaceDir = expandWithEnv(ws.WorkspaceDir, loadWorkspaceEnv(orgBaseDir, ws.FilesDir)) + } + + // Validate and convert workspace_dir to NULL if empty + var workspaceDir interface{} + if ws.WorkspaceDir != "" { + if err := validateWorkspaceDir(ws.WorkspaceDir); err != nil { + return fmt.Errorf("workspace %s: %w", ws.Name, err) + } + workspaceDir = ws.WorkspaceDir + } + + // #65: validate workspace_access (defaults to "none" when empty). 
+ workspaceAccess := ws.WorkspaceAccess + if workspaceAccess == "" { + workspaceAccess = provisioner.WorkspaceAccessNone + } + if err := provisioner.ValidateWorkspaceAccess(workspaceAccess, ws.WorkspaceDir); err != nil { + return fmt.Errorf("workspace %s: %w", ws.Name, err) + } + + ctx := context.Background() + + // Insert workspace + _, err := db.DB.ExecContext(ctx, ` + INSERT INTO workspaces (id, name, role, tier, runtime, awareness_namespace, status, parent_id, workspace_dir, workspace_access) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + `, id, ws.Name, role, tier, runtime, awarenessNS, "provisioning", parentID, workspaceDir, workspaceAccess) + if err != nil { + log.Printf("Org import: failed to create %s: %v", ws.Name, err) + return fmt.Errorf("failed to create %s: %w", ws.Name, err) + } + + // Canvas layout with coordinates from YAML + if _, err := db.DB.ExecContext(ctx, `INSERT INTO canvas_layouts (workspace_id, x, y) VALUES ($1, $2, $3)`, id, ws.Canvas.X, ws.Canvas.Y); err != nil { + log.Printf("Org import: canvas layout insert failed for %s: %v", ws.Name, err) + } + + // Broadcast + h.broadcaster.RecordAndBroadcast(ctx, "WORKSPACE_PROVISIONING", id, map[string]interface{}{ + "name": ws.Name, "tier": tier, + }) + + // Seed initial memories from workspace config or defaults (issue #1050). + // Per-workspace initial_memories override defaults; if workspace has none, + // fall back to defaults.initial_memories. + wsMemories := ws.InitialMemories + if len(wsMemories) == 0 { + wsMemories = defaults.InitialMemories + } + seedInitialMemories(ctx, id, wsMemories, awarenessNS) + + // Handle external workspaces + if ws.External { + if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET status = 'online', url = $1 WHERE id = $2`, ws.URL, id); err != nil { + log.Printf("Org import: external workspace status update failed for %s: %v", ws.Name, err) + } + h.broadcaster.RecordAndBroadcast(ctx, "WORKSPACE_ONLINE", id, map[string]interface{}{ + "name": ws.Name, "external": true, + }) + } else if h.provisioner != nil { + // Provision container + payload := models.CreateWorkspacePayload{ + Name: ws.Name, Tier: tier, Runtime: runtime, Model: model, + WorkspaceDir: ws.WorkspaceDir, + WorkspaceAccess: workspaceAccess, + } + templatePath := "" + if ws.Template != "" { + // `template` comes from the uploaded YAML — treat as untrusted. + // Only accept paths that stay inside h.configsDir. + if tp, err := resolveInsideRoot(h.configsDir, ws.Template); err == nil { + if _, statErr := os.Stat(tp); statErr == nil { + templatePath = tp + } + } + } + if templatePath == "" { + // #241: sanitizeRuntime() allowlists the runtime string so a + // crafted org.yaml cannot use it as a path-traversal oracle. + safeRuntime := sanitizeRuntime(runtime) + runtimeDefault := filepath.Join(h.configsDir, safeRuntime+"-default") + if _, err := os.Stat(runtimeDefault); err == nil { + templatePath = runtimeDefault + } + } + + // Always generate default config.yaml (runtime, model, tier, etc.) + configFiles := h.workspace.ensureDefaultConfig(id, payload) + + // Copy files_dir contents on top (system-prompt.md, CLAUDE.md, skills/, etc.) + // Uses templatePath for CopyTemplateToContainer — runs AFTER configFiles are written + if ws.FilesDir != "" && orgBaseDir != "" { + // `files_dir` also comes from untrusted YAML. Join inside orgBaseDir + // (already validated above) and reject anything that escapes. 
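+			// e.g. a crafted `files_dir: ../../../etc` (hypothetical attack value)
+			// fails the resolveInsideRoot check below, so templatePath keeps its
+			// earlier value and nothing outside the template dir is copied.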
+			if filesPath, err := resolveInsideRoot(orgBaseDir, ws.FilesDir); err == nil {
+				if info, statErr := os.Stat(filesPath); statErr == nil && info.IsDir() {
+					templatePath = filesPath
+				}
+			}
+		}
+
+		// Pre-install plugins: copy from registry into configFiles as plugins/<name>/*.
+		// Per-workspace plugins UNION with defaults.plugins (issue #68).
+		// A leading "!" or "-" on a per-workspace entry opts that plugin out.
+		plugins := mergePlugins(defaults.Plugins, ws.Plugins)
+		if len(plugins) > 0 {
+			if configFiles == nil {
+				configFiles = map[string][]byte{}
+			}
+			pluginsBase, _ := filepath.Abs(filepath.Join(h.configsDir, "..", "plugins"))
+			for _, pluginName := range plugins {
+				pluginSrc := filepath.Join(pluginsBase, pluginName)
+				if info, err := os.Stat(pluginSrc); err != nil || !info.IsDir() {
+					log.Printf("Org import: plugin %s not found at %s, skipping", pluginName, pluginSrc)
+					continue
+				}
+				filepath.Walk(pluginSrc, func(path string, info os.FileInfo, err error) error {
+					if err != nil || info.IsDir() {
+						return nil
+					}
+					rel, _ := filepath.Rel(pluginSrc, path)
+					data, readErr := os.ReadFile(path)
+					if readErr == nil {
+						configFiles["plugins/"+pluginName+"/"+rel] = data
+					}
+					return nil
+				})
+			}
+		}
+
+		// Render category_routing into config.yaml so the agent can read its routing
+		// table at runtime without hardcoded role names in prompts (issue #51).
+		// Per-workspace keys replace defaults per-key (empty list drops the key);
+		// see mergeCategoryRouting for exact semantics.
+		routing := mergeCategoryRouting(defaults.CategoryRouting, ws.CategoryRouting)
+		if len(routing) > 0 {
+			if configFiles == nil {
+				configFiles = map[string][]byte{}
+			}
+			block, err := renderCategoryRoutingYAML(routing)
+			if err != nil {
+				log.Printf("Org import: failed to render category_routing for %s: %v", ws.Name, err)
+			} else {
+				configFiles["config.yaml"] = appendYAMLBlock(configFiles["config.yaml"], block)
+			}
+		}
+
+		// Resolve initial_prompt — inline wins, then file-ref, then defaults
+		// (inline → file → defaults.inline → defaults.file). File refs are
+		// rooted at <orgBaseDir>/<filesDir>/ per resolvePromptRef semantics.
+		initialPrompt, err := resolvePromptRef(ws.InitialPrompt, ws.InitialPromptFile, orgBaseDir, ws.FilesDir)
+		if err != nil {
+			log.Printf("Org import: failed to resolve initial_prompt for %s: %v", ws.Name, err)
+		}
+		if initialPrompt == "" {
+			// Fall back to defaults. Defaults live at the org root, so they
+			// resolve with empty filesDir (relative to orgBaseDir).
+			var defaultErr error
+			initialPrompt, defaultErr = resolvePromptRef(defaults.InitialPrompt, defaults.InitialPromptFile, orgBaseDir, "")
+			if defaultErr != nil {
+				log.Printf("Org import: failed to resolve defaults.initial_prompt for %s: %v", ws.Name, defaultErr)
+			}
+		}
+		if initialPrompt != "" {
+			if configFiles == nil {
+				configFiles = map[string][]byte{}
+			}
+			// Append initial_prompt to config.yaml using YAML block scalar.
+			// Trim each line to avoid trailing whitespace issues.
+			trimmed := strings.TrimSpace(initialPrompt)
+			lines := strings.Split(trimmed, "\n")
+			for i, line := range lines {
+				lines[i] = strings.TrimRight(line, " \t")
+			}
+			indented := strings.Join(lines, "\n  ")
+			configFiles["config.yaml"] = appendYAMLBlock(configFiles["config.yaml"], fmt.Sprintf("initial_prompt: |\n  %s\n", indented))
+			log.Printf("Org import: injected initial_prompt (%d chars) into config.yaml for %s", len(trimmed), ws.Name)
+		}
+
+		// Resolve idle_prompt — same precedence (ws inline → ws file → defaults).
+ // Inject into config.yaml alongside idle_interval_seconds so the + // workspace's heartbeat loop picks up the idle-reflection cadence on + // boot (see workspace/heartbeat.py + config.py). + idlePrompt, err := resolvePromptRef(ws.IdlePrompt, ws.IdlePromptFile, orgBaseDir, ws.FilesDir) + if err != nil { + log.Printf("Org import: failed to resolve idle_prompt for %s: %v", ws.Name, err) + } + if idlePrompt == "" { + var defaultErr error + idlePrompt, defaultErr = resolvePromptRef(defaults.IdlePrompt, defaults.IdlePromptFile, orgBaseDir, "") + if defaultErr != nil { + log.Printf("Org import: failed to resolve defaults.idle_prompt for %s: %v", ws.Name, defaultErr) + } + } + idleInterval := ws.IdleIntervalSeconds + if idleInterval == 0 { + idleInterval = defaults.IdleIntervalSeconds + } + if idlePrompt != "" { + if configFiles == nil { + configFiles = map[string][]byte{} + } + trimmed := strings.TrimSpace(idlePrompt) + lines := strings.Split(trimmed, "\n") + for i, line := range lines { + lines[i] = strings.TrimRight(line, " \t") + } + indented := strings.Join(lines, "\n ") + // idle_interval_seconds belongs with idle_prompt — empty idle_prompt + // means the idle loop never fires regardless of interval, so we + // only emit interval when there's a body to go with it. + if idleInterval <= 0 { + idleInterval = 600 // same default as workspace/config.py + } + block := fmt.Sprintf("idle_interval_seconds: %d\nidle_prompt: |\n %s\n", idleInterval, indented) + configFiles["config.yaml"] = appendYAMLBlock(configFiles["config.yaml"], block) + log.Printf("Org import: injected idle_prompt (%d chars, interval=%ds) into config.yaml for %s", len(trimmed), idleInterval, ws.Name) + } + + // Inline system_prompt (only if no files_dir provides one) + if ws.SystemPrompt != "" { + if configFiles == nil { + configFiles = map[string][]byte{} + } + configFiles["system-prompt.md"] = []byte(ws.SystemPrompt) + } + + // Inject secrets from .env files as workspace secrets. + // Resolution: workspace .env → org root .env (workspace overrides org root). + // Each line: KEY=VALUE → stored as encrypted workspace secret. + envVars := map[string]string{} + if orgBaseDir != "" { + // 1. Org root .env (shared defaults) + parseEnvFile(filepath.Join(orgBaseDir, ".env"), envVars) + // 2. Workspace-specific .env (overrides) + if ws.FilesDir != "" { + parseEnvFile(filepath.Join(orgBaseDir, ws.FilesDir, ".env"), envVars) + } + } + // Store as workspace secrets via DB (encrypted if key is set, raw otherwise) + for key, value := range envVars { + var encrypted []byte + if crypto.IsEnabled() { + var err error + encrypted, err = crypto.Encrypt([]byte(value)) + if err != nil { + log.Printf("Org import: failed to encrypt secret %s for %s: %v", key, ws.Name, err) + continue + } + } else { + encrypted = []byte(value) // store raw when encryption disabled + } + if _, err := db.DB.ExecContext(ctx, ` + INSERT INTO workspace_secrets (workspace_id, key, encrypted_value) + VALUES ($1, $2, $3) + ON CONFLICT (workspace_id, key) DO UPDATE SET encrypted_value = $3, updated_at = now() + `, id, key, encrypted); err != nil { + log.Printf("Org import: failed to insert secret %s for %s: %v", key, ws.Name, err) + } + } + + // #1084: limit concurrent Docker provisioning via semaphore. 
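+	// The buffered channel acts as a counting semaphore: the send below
+	// blocks once cap(provisionSem) goroutines hold slots, so at most that
+	// many containers provision concurrently. The capacity is chosen where
+	// the channel is created, presumably the Import entry point.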
+ provisionSem <- struct{}{} // acquire + go func(wID, tPath string, cFiles map[string][]byte, p models.CreateWorkspacePayload) { + defer func() { <-provisionSem }() // release + h.workspace.provisionWorkspace(wID, tPath, cFiles, p) + }(id, templatePath, configFiles, payload) + } + + // Insert schedules if defined. Resolve each schedule's prompt body from + // either inline `prompt:` or `prompt_file:` (file ref relative to the + // workspace's files_dir). Inline wins; empty prompt after resolution is + // a configuration error (cron with no body would never do anything). + for _, sched := range ws.Schedules { + tz := sched.Timezone + if tz == "" { + tz = "UTC" + } + enabled := true + if sched.Enabled != nil { + enabled = *sched.Enabled + } + prompt, promptErr := resolvePromptRef(sched.Prompt, sched.PromptFile, orgBaseDir, ws.FilesDir) + if promptErr != nil { + log.Printf("Org import: failed to resolve prompt for schedule '%s' on %s: %v — skipping insert", sched.Name, ws.Name, promptErr) + continue + } + if prompt == "" { + log.Printf("Org import: schedule '%s' on %s has empty prompt (neither prompt nor prompt_file set) — skipping insert", sched.Name, ws.Name) + continue + } + // #722: surface the error rather than silently using time.Time{} (zero) + // which lib/pq stores as 0001-01-01 and may confuse the fire query. + nextRun, nextRunErr := scheduler.ComputeNextRun(sched.CronExpr, tz, time.Now()) + if nextRunErr != nil { + log.Printf("Org import: invalid cron expression for schedule '%s' on %s: %v — skipping insert", + sched.Name, ws.Name, nextRunErr) + continue + } + if _, err := db.DB.ExecContext(context.Background(), orgImportScheduleSQL, + id, sched.Name, sched.CronExpr, tz, prompt, enabled, nextRun); err != nil { + log.Printf("Org import: failed to upsert schedule '%s' for %s: %v", sched.Name, ws.Name, err) + } else { + log.Printf("Org import: schedule '%s' (%s, %d chars) upserted for %s (source=template)", sched.Name, sched.CronExpr, len(prompt), ws.Name) + } + } + + // Insert channels if defined (Telegram, Slack, etc.). Config values + // support ${VAR} expansion from .env files. The manager is reloaded + // once at the end of org import (in Import), not per-workspace. + channelEnv := loadWorkspaceEnv(orgBaseDir, ws.FilesDir) + wsChannelsCreated := []string{} + wsChannelsSkipped := []map[string]string{} + // skipChannel records a skipped channel with consistent shape across all reasons. 
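+	// A skipped entry surfaces in the import result as, e.g. (hypothetical
+	// values; shape matches the map literal below):
+	//   {"workspace": "support-bot", "type": "telegram", "reason": "missing env: [${TELEGRAM_TOKEN}]"}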
+ skipChannel := func(channelType, reason string) { + wsChannelsSkipped = append(wsChannelsSkipped, map[string]string{ + "workspace": ws.Name, + "type": channelType, // empty string when type field was missing + "reason": reason, + }) + } + + for _, ch := range ws.Channels { + if ch.Type == "" { + skipChannel("", "empty type") + log.Printf("Org import: skipping channel with empty type for %s", ws.Name) + continue + } + // Validate adapter exists upfront — fail fast instead of inserting orphan rows + adapter, ok := channels.GetAdapter(ch.Type) + if !ok { + skipChannel(ch.Type, "unknown adapter") + log.Printf("Org import: skipping %s channel for %s — no adapter registered", ch.Type, ws.Name) + continue + } + + expandedConfig := make(map[string]interface{}, len(ch.Config)) + missing := []string{} + for k, v := range ch.Config { + expanded := expandWithEnv(v, channelEnv) + if hasUnresolvedVarRef(v, expanded) { + missing = append(missing, v) + } + expandedConfig[k] = expanded + } + if len(missing) > 0 { + skipChannel(ch.Type, fmt.Sprintf("missing env: %v", missing)) + log.Printf("Org import: skipping %s channel for %s — env vars not set: %v", ch.Type, ws.Name, missing) + continue + } + + // Adapter-level config validation + if err := adapter.ValidateConfig(expandedConfig); err != nil { + skipChannel(ch.Type, err.Error()) + log.Printf("Org import: skipping %s channel for %s — invalid config: %v", ch.Type, ws.Name, err) + continue + } + + configJSON, err := json.Marshal(expandedConfig) + if err != nil { + log.Printf("Org import: failed to marshal config for %s channel: %v", ch.Type, err) + continue + } + allowedJSON, err := json.Marshal(ch.AllowedUsers) + if err != nil { + log.Printf("Org import: failed to marshal allowed_users for %s channel: %v", ch.Type, err) + continue + } + enabled := true + if ch.Enabled != nil { + enabled = *ch.Enabled + } + // Idempotent insert — if same workspace+type already exists, update config + if _, err := db.DB.ExecContext(context.Background(), ` + INSERT INTO workspace_channels (workspace_id, channel_type, channel_config, enabled, allowed_users) + VALUES ($1, $2, $3::jsonb, $4, $5::jsonb) + ON CONFLICT (workspace_id, channel_type) DO UPDATE + SET channel_config = EXCLUDED.channel_config, + enabled = EXCLUDED.enabled, + allowed_users = EXCLUDED.allowed_users, + updated_at = now() + `, id, ch.Type, string(configJSON), enabled, string(allowedJSON)); err != nil { + log.Printf("Org import: failed to create %s channel for %s: %v", ch.Type, ws.Name, err) + } else { + wsChannelsCreated = append(wsChannelsCreated, ch.Type) + log.Printf("Org import: %s channel created for %s", ch.Type, ws.Name) + } + } + + resultEntry := map[string]interface{}{ + "id": id, + "name": ws.Name, + "tier": tier, + } + if len(wsChannelsCreated) > 0 { + resultEntry["channels"] = wsChannelsCreated + } + if len(wsChannelsSkipped) > 0 { + resultEntry["channels_skipped"] = wsChannelsSkipped + } + *results = append(*results, resultEntry) + + // Recurse into children. Brief pacing avoids overwhelming Docker when + // creating many containers in sequence; container provisioning runs in + // goroutines so the main createWorkspaceTree returns quickly. 
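+	// Each child receives this workspace's id as its parentID, so the DB
+	// hierarchy mirrors the YAML nesting exactly.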
+ for _, child := range ws.Children { + if err := h.createWorkspaceTree(child, &id, defaults, orgBaseDir, results, provisionSem); err != nil { + return err + } + time.Sleep(workspaceCreatePacingMs * time.Millisecond) + } + + return nil +} + +func countWorkspaces(workspaces []OrgWorkspace) int { + count := len(workspaces) + for _, ws := range workspaces { + count += countWorkspaces(ws.Children) + } + return count +} diff --git a/workspace-server/internal/handlers/ssrf.go b/workspace-server/internal/handlers/ssrf.go new file mode 100644 index 00000000..09bb2774 --- /dev/null +++ b/workspace-server/internal/handlers/ssrf.go @@ -0,0 +1,90 @@ +package handlers + +import ( + "fmt" + "net" + "net/url" + "path/filepath" + "strings" +) + +// isSafeURL validates that a URL resolves to a publicly-routable address, +// preventing A2A requests from being redirected to internal/cloud-metadata +// infrastructure (SSRF, CWE-918). Workspace URLs come from DB/Redis caches +// so we validate before making any outbound HTTP call. +func isSafeURL(rawURL string) error { + u, err := url.Parse(rawURL) + if err != nil { + return fmt.Errorf("invalid URL: %w", err) + } + // Reject non-HTTP(S) schemes. + if u.Scheme != "http" && u.Scheme != "https" { + return fmt.Errorf("forbidden scheme: %s (only http/https allowed)", u.Scheme) + } + host := u.Hostname() + if host == "" { + return fmt.Errorf("empty hostname") + } + // Block direct IP addresses. + if ip := net.ParseIP(host); ip != nil { + if ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() { + return fmt.Errorf("forbidden loopback/unspecified IP: %s", ip) + } + if isPrivateOrMetadataIP(ip) { + return fmt.Errorf("forbidden private/metadata IP: %s", ip) + } + return nil + } + // For hostnames, resolve and validate each returned IP. + addrs, err := net.LookupHost(host) + if err != nil { + // DNS resolution failure — block it. Could be an internal hostname. + return fmt.Errorf("DNS resolution blocked for hostname: %s (%v)", host, err) + } + if len(addrs) == 0 { + return fmt.Errorf("DNS returned no addresses for: %s", host) + } + for _, addr := range addrs { + ip := net.ParseIP(addr) + if ip != nil && (ip.IsLoopback() || ip.IsUnspecified() || ip.IsLinkLocalUnicast() || isPrivateOrMetadataIP(ip)) { + return fmt.Errorf("hostname %s resolves to forbidden IP: %s", host, ip) + } + } + return nil +} + +// isPrivateOrMetadataIP returns true for RFC-1918 private, carrier-grade NAT, +// link-local, and cloud metadata ranges. +func isPrivateOrMetadataIP(ip net.IP) bool { + var privateRanges = []net.IPNet{ + {IP: net.ParseIP("10.0.0.0"), Mask: net.CIDRMask(8, 32)}, + {IP: net.ParseIP("172.16.0.0"), Mask: net.CIDRMask(12, 32)}, + {IP: net.ParseIP("192.168.0.0"), Mask: net.CIDRMask(16, 32)}, + {IP: net.ParseIP("169.254.0.0"), Mask: net.CIDRMask(16, 32)}, + {IP: net.ParseIP("100.64.0.0"), Mask: net.CIDRMask(10, 32)}, + {IP: net.ParseIP("192.0.2.0"), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("198.51.100.0"), Mask: net.CIDRMask(24, 32)}, + {IP: net.ParseIP("203.0.113.0"), Mask: net.CIDRMask(24, 32)}, + } + ip = ip.To4() + if ip == nil { + return false + } + for _, r := range privateRanges { + if r.Contains(ip) { + return true + } + } + return false +} + +// validateRelPath checks that a file path is relative and does not escape +// the destination via absolute paths or ".." traversal. Used by +// copyFilesToContainer and deleteViaEphemeral as a defence-in-depth measure. 
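+//
+// e.g. "notes/report.md" passes; "/etc/passwd" (absolute) and
+// "notes/../../secrets" (Clean → "../secrets", still contains "..")
+// are both rejected.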
+func validateRelPath(filePath string) error { + clean := filepath.Clean(filePath) + if filepath.IsAbs(clean) || strings.Contains(clean, "..") { + return fmt.Errorf("path traversal or absolute path not allowed: %s", filePath) + } + return nil +} \ No newline at end of file diff --git a/workspace-server/internal/handlers/workspace.go b/workspace-server/internal/handlers/workspace.go index 8b534f70..6af680f1 100644 --- a/workspace-server/internal/handlers/workspace.go +++ b/workspace-server/internal/handlers/workspace.go @@ -1,5 +1,9 @@ package handlers +// workspace.go — WorkspaceHandler struct, constructor, Create, List, Get, +// and the shared scanWorkspaceRow helper. State/Update/Delete and validators +// live in workspace_crud.go. + import ( "context" "database/sql" @@ -16,10 +20,8 @@ import ( "github.com/Molecule-AI/molecule-monorepo/platform/internal/events" "github.com/Molecule-AI/molecule-monorepo/platform/internal/models" "github.com/Molecule-AI/molecule-monorepo/platform/internal/provisioner" - "github.com/Molecule-AI/molecule-monorepo/platform/internal/wsauth" "github.com/Molecule-AI/molecule-monorepo/platform/pkg/provisionhook" "github.com/gin-gonic/gin" - "github.com/lib/pq" "github.com/google/uuid" ) @@ -303,7 +305,7 @@ func scanWorkspaceRow(rows interface { Scan(dest ...interface{}) error }) (map[string]interface{}, error) { var id, name, role, status, url, sampleError, currentTask, runtime, workspaceDir string - var tier, activeTasks, uptimeSeconds int + var tier, activeTasks, maxConcurrentTasks, uptimeSeconds int var errorRate, x, y float64 var collapsed bool var parentID *string @@ -312,7 +314,7 @@ func scanWorkspaceRow(rows interface { var monthlySpend int64 err := rows.Scan(&id, &name, &role, &tier, &status, &agentCard, &url, - &parentID, &activeTasks, &errorRate, &sampleError, &uptimeSeconds, + &parentID, &activeTasks, &maxConcurrentTasks, &errorRate, &sampleError, &uptimeSeconds, ¤tTask, &runtime, &workspaceDir, &x, &y, &collapsed, &budgetLimit, &monthlySpend) if err != nil { @@ -326,8 +328,9 @@ func scanWorkspaceRow(rows interface { "status": status, "url": url, "parent_id": parentID, - "active_tasks": activeTasks, - "last_error_rate": errorRate, + "active_tasks": activeTasks, + "max_concurrent_tasks": maxConcurrentTasks, + "last_error_rate": errorRate, "last_sample_error": sampleError, "uptime_seconds": uptimeSeconds, "current_task": currentTask, @@ -366,7 +369,8 @@ func scanWorkspaceRow(rows interface { const workspaceListQuery = ` SELECT w.id, w.name, COALESCE(w.role, ''), w.tier, w.status, COALESCE(w.agent_card, 'null'::jsonb), COALESCE(w.url, ''), - w.parent_id, w.active_tasks, w.last_error_rate, + w.parent_id, w.active_tasks, COALESCE(w.max_concurrent_tasks, 1), + w.last_error_rate, COALESCE(w.last_sample_error, ''), w.uptime_seconds, COALESCE(w.current_task, ''), COALESCE(w.runtime, 'langgraph'), COALESCE(w.workspace_dir, ''), @@ -418,7 +422,8 @@ func (h *WorkspaceHandler) Get(c *gin.Context) { row := db.DB.QueryRowContext(c.Request.Context(), ` SELECT w.id, w.name, COALESCE(w.role, ''), w.tier, w.status, COALESCE(w.agent_card, 'null'::jsonb), COALESCE(w.url, ''), - w.parent_id, w.active_tasks, w.last_error_rate, + w.parent_id, w.active_tasks, COALESCE(w.max_concurrent_tasks, 1), + w.last_error_rate, COALESCE(w.last_sample_error, ''), w.uptime_seconds, COALESCE(w.current_task, ''), COALESCE(w.runtime, 'langgraph'), COALESCE(w.workspace_dir, ''), @@ -462,473 +467,3 @@ func (h *WorkspaceHandler) Get(c *gin.Context) { c.JSON(http.StatusOK, ws) } - -// State 
handles GET /workspaces/:id/state — minimal status payload for -// remote-agent polling (Phase 30.4). Returns `{status, paused, deleted, -// workspace_id}` so a remote agent can detect pause/resume/delete -// without needing WebSocket reachability from the platform. -// -// Auth: Phase 30.1 bearer token required when the workspace has any -// live token on file; legacy workspaces grandfathered. Uses the same -// fail-closed posture as secrets.Values — polling this cadence with -// unauth'd callers would be a trivial DoS / workspace-status-scanner -// otherwise. -// -// The endpoint is deliberately NOT merged with GET /workspaces/:id: -// that handler is optimized for canvas (returns config, agent_card, -// position, …) and is unauthenticated by design. State is the -// agent-machinery polling path — tight, token-gated, cache-friendly. -func (h *WorkspaceHandler) State(c *gin.Context) { - workspaceID := c.Param("id") - ctx := c.Request.Context() - - // Auth gate — same shape as secrets.Values (Phase 30.2). Fail-closed - // on DB errors because the caller is about to poll this at ~60s - // cadence; letting unauth'd callers through on a hiccup turns this - // into a workspace-status scanner. - hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID) - if hlErr != nil { - log.Printf("wsauth: HasAnyLiveToken(%s) failed for workspace.State: %v", workspaceID, hlErr) - c.JSON(http.StatusInternalServerError, gin.H{"error": "auth check failed"}) - return - } - if hasLive { - tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization")) - if tok == "" { - c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"}) - return - } - if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil { - c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"}) - return - } - } - - var status string - err := db.DB.QueryRowContext(ctx, ` - SELECT status - FROM workspaces - WHERE id = $1 - `, workspaceID).Scan(&status) - if err == sql.ErrNoRows { - // A deleted workspace row no longer exists — remote agent should - // interpret 404 as "shut yourself down" (our pause path uses - // status='removed' but keeps the row; a 404 here means the - // workspace was hard-deleted out from under the agent). - c.JSON(http.StatusNotFound, gin.H{ - "workspace_id": workspaceID, - "deleted": true, - }) - return - } - if err != nil { - log.Printf("workspace.State query error for %s: %v", workspaceID, err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"}) - return - } - - // Two delete paths: hard-delete (sql.ErrNoRows above → 404) AND - // soft-delete (status='removed' → also return 404 here so the SDK - // doesn't have to remember "is it 200 with deleted=true OR 404 with - // deleted=true?"). Same shape, same status code, same flag set. - if status == "removed" { - c.JSON(http.StatusNotFound, gin.H{ - "workspace_id": workspaceID, - "status": "removed", - "deleted": true, - }) - return - } - - c.JSON(http.StatusOK, gin.H{ - "workspace_id": workspaceID, - "status": status, - "paused": status == "paused", - "deleted": false, - }) -} - -// sensitiveUpdateFields documents fields that carry elevated risk — kept as -// an explicit list for code readability and future audits. Auth is now fully -// enforced at the router layer (WorkspaceAuth middleware, #680 IDOR fix); -// this map is no longer used for in-handler gate logic but is preserved to -// surface the risk classification clearly. 
-// -// budget_limit is intentionally NOT here — the dedicated PATCH -// /workspaces/:id/budget (AdminAuth) is the only write path (#611). -var sensitiveUpdateFields = map[string]struct{}{ - "tier": {}, - "parent_id": {}, - "runtime": {}, - "workspace_dir": {}, -} - -// Update handles PATCH /workspaces/:id -func (h *WorkspaceHandler) Update(c *gin.Context) { - id := c.Param("id") - - // #687: reject non-UUID IDs before hitting the DB. - if err := validateWorkspaceID(id); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace ID"}) - return - } - - var body map[string]interface{} - if err := c.ShouldBindJSON(&body); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) - return - } - - // #685/#688: validate string fields for length and injection safety. - strField := func(key string) string { - if v, ok := body[key]; ok { - if s, ok := v.(string); ok { - return s - } - } - return "" - } - if err := validateWorkspaceFields( - strField("name"), strField("role"), "" /*model not patchable*/, strField("runtime"), - ); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace fields"}) - return - } - - ctx := c.Request.Context() - - // Auth is fully enforced at the router layer (WorkspaceAuth middleware, #680). - // WorkspaceAuth validates that the caller holds a valid bearer token for this - // specific workspace — no additional auth gate is needed here. The - // sensitiveUpdateFields map above documents the risk classification for - // auditors but is no longer used as a runtime gate. - - // #120: guard — return 404 for nonexistent workspace IDs instead of - // silently applying zero-row UPDATEs and returning 200. - var exists bool - if err := db.DB.QueryRowContext(ctx, - `SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`, id, - ).Scan(&exists); err != nil || !exists { - c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"}) - return - } - - if name, ok := body["name"]; ok { - if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET name = $2, updated_at = now() WHERE id = $1`, id, name); err != nil { - log.Printf("Update name error for %s: %v", id, err) - } - } - if role, ok := body["role"]; ok { - if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET role = $2, updated_at = now() WHERE id = $1`, id, role); err != nil { - log.Printf("Update role error for %s: %v", id, err) - } - } - if tier, ok := body["tier"]; ok { - if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET tier = $2, updated_at = now() WHERE id = $1`, id, tier); err != nil { - log.Printf("Update tier error for %s: %v", id, err) - } - } - if parentID, ok := body["parent_id"]; ok { - if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET parent_id = $2, updated_at = now() WHERE id = $1`, id, parentID); err != nil { - log.Printf("Update parent_id error for %s: %v", id, err) - } - } - if runtime, ok := body["runtime"]; ok { - if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET runtime = $2, updated_at = now() WHERE id = $1`, id, runtime); err != nil { - log.Printf("Update runtime error for %s: %v", id, err) - } - } - needsRestart := false - if wsDir, ok := body["workspace_dir"]; ok { - // Allow null to clear workspace_dir - if wsDir != nil { - if dirStr, isStr := wsDir.(string); isStr && dirStr != "" { - if err := validateWorkspaceDir(dirStr); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace directory"}) - return - } - } - } - if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET 
workspace_dir = $2, updated_at = now() WHERE id = $1`, id, wsDir); err != nil { - log.Printf("Update workspace_dir error for %s: %v", id, err) - } - needsRestart = true - } - // NOTE: budget_limit is intentionally NOT handled here. The dedicated - // PATCH /workspaces/:id/budget (AdminAuth) is the only write path. - // This endpoint uses ValidateAnyToken — any enrolled workspace bearer - // could otherwise self-clear its own spending ceiling. (#611 Security Auditor) - - // Update canvas position if both x and y provided - if x, xOk := body["x"]; xOk { - if y, yOk := body["y"]; yOk { - if _, err := db.DB.ExecContext(ctx, ` - INSERT INTO canvas_layouts (workspace_id, x, y) - VALUES ($1, $2, $3) - ON CONFLICT (workspace_id) DO UPDATE SET x = EXCLUDED.x, y = EXCLUDED.y - `, id, x, y); err != nil { - log.Printf("Update position error for %s: %v", id, err) - } - } - } - - resp := gin.H{"status": "updated"} - if needsRestart { - resp["needs_restart"] = true - } - c.JSON(http.StatusOK, resp) -} - -// validateWorkspaceDir checks that a workspace_dir path is safe to bind-mount. -func validateWorkspaceDir(dir string) error { - if !filepath.IsAbs(dir) { - return fmt.Errorf("workspace_dir must be an absolute path") - } - if strings.Contains(dir, "..") { - return fmt.Errorf("workspace_dir must not contain '..'") - } - // Reject system-critical paths - clean := filepath.Clean(dir) - for _, blocked := range []string{"/etc", "/var", "/proc", "/sys", "/dev", "/boot", "/sbin", "/bin", "/lib", "/usr"} { - if clean == blocked || strings.HasPrefix(clean, blocked+"/") { - return fmt.Errorf("workspace_dir must not be a system path (%s)", blocked) - } - } - return nil -} - -// Delete handles DELETE /workspaces/:id -// If the workspace has children (is a team), cascade deletes all sub-workspaces. -// Use ?confirm=true to actually delete (otherwise returns children list for confirmation). -func (h *WorkspaceHandler) Delete(c *gin.Context) { - id := c.Param("id") - ctx := c.Request.Context() - confirm := c.Query("confirm") == "true" - - // #687: reject non-UUID IDs before hitting the DB. - if err := validateWorkspaceID(id); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace ID"}) - return - } - - // Check for children - rows, err := db.DB.QueryContext(ctx, - `SELECT id, name FROM workspaces WHERE parent_id = $1 AND status != 'removed'`, id) - if err != nil { - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to check children"}) - return - } - defer rows.Close() - - var children []map[string]string - for rows.Next() { - var childID, childName string - if rows.Scan(&childID, &childName) == nil { - children = append(children, map[string]string{"id": childID, "name": childName}) - } - } - if err := rows.Err(); err != nil { - log.Printf("Delete: child rows error: %v", err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to check children"}) - return - } - - // If has children and not confirmed, return children list for confirmation. - // Uses HTTP 409 Conflict (not 200) so `curl --fail`, `fetch().ok`, and any - // client that treats HTTP 4xx as an error surfaces the confirmation - // requirement. Body shape unchanged so the canvas UI's parser keeps - // working. Fixes #88. - if len(children) > 0 && !confirm { - c.JSON(http.StatusConflict, gin.H{ - "status": "confirmation_required", - "message": "This workspace has sub-workspaces. 
Delete with ?confirm=true to cascade delete.", - "children": children, - "children_count": len(children), - }) - return - } - - // Cascade delete: collect ALL descendants (not just direct children) via - // recursive CTE, then stop each container and remove each volume. - // Previous bug: only direct children's containers were stopped, leaving - // grandchildren as orphan running containers after a cascade delete. - descendantIDs := []string{} - if len(children) > 0 { - descRows, err := db.DB.QueryContext(ctx, ` - WITH RECURSIVE descendants AS ( - SELECT id FROM workspaces WHERE parent_id = $1 AND status != 'removed' - UNION ALL - SELECT w.id FROM workspaces w JOIN descendants d ON w.parent_id = d.id WHERE w.status != 'removed' - ) - SELECT id FROM descendants - `, id) - if err != nil { - log.Printf("Delete: descendant query error for %s: %v", id, err) - } else { - for descRows.Next() { - var descID string - if descRows.Scan(&descID) == nil { - descendantIDs = append(descendantIDs, descID) - } - } - descRows.Close() - } - } - - // #73 fix: mark rows 'removed' in the DB FIRST, BEFORE stopping containers - // or removing volumes. Previously the sequence was stop → update-status, - // which left a gap where: - // - the container's last pre-teardown heartbeat could resurrect the row - // via the register-handler UPSERT (now also guarded in #73) - // - the liveness monitor could observe 'online' status + expired Redis - // TTL and trigger RestartByID, recreating a container we're trying - // to destroy - // Marking 'removed' first makes both of those paths no-op via their - // existing `status NOT IN ('removed', ...)` guards. - allIDs := append([]string{id}, descendantIDs...) - if _, err := db.DB.ExecContext(ctx, - `UPDATE workspaces SET status = 'removed', updated_at = now() WHERE id = ANY($1::uuid[])`, - pq.Array(allIDs)); err != nil { - log.Printf("Delete status update error for %s: %v", id, err) - } - if _, err := db.DB.ExecContext(ctx, - `DELETE FROM canvas_layouts WHERE workspace_id = ANY($1::uuid[])`, - pq.Array(allIDs)); err != nil { - log.Printf("Delete canvas_layouts error for %s: %v", id, err) - } - // Revoke all auth tokens for the deleted workspaces. Once the workspace is - // gone its tokens are meaningless; leaving them alive would keep - // HasAnyLiveTokenGlobal = true even after the platform is otherwise empty, - // which prevents AdminAuth from returning to fail-open and breaks the E2E - // test's count-zero assertion (and local re-run cleanup). - if _, err := db.DB.ExecContext(ctx, - `UPDATE workspace_auth_tokens SET revoked_at = now() - WHERE workspace_id = ANY($1::uuid[]) AND revoked_at IS NULL`, - pq.Array(allIDs)); err != nil { - log.Printf("Delete token revocation error for %s: %v", id, err) - } -// #1027: cascade-disable all schedules for the deleted workspaces so - // the scheduler never fires a cron into a removed container. - if _, err := db.DB.ExecContext(ctx, - `UPDATE workspace_schedules SET enabled = false, updated_at = now() - WHERE workspace_id = ANY($1::uuid[]) AND enabled = true`, - pq.Array(allIDs)); err != nil { - log.Printf("Delete schedule disable error for %s: %v", id, err) - } - - // Now stop containers + remove volumes for all descendants (any depth). - // Any concurrent heartbeat / registration / liveness-triggered restart - // will see status='removed' and bail out early. 
- for _, descID := range descendantIDs { - if h.provisioner != nil { - h.provisioner.Stop(ctx, descID) - if err := h.provisioner.RemoveVolume(ctx, descID); err != nil { - log.Printf("Delete descendant %s volume removal warning: %v", descID, err) - } - } - db.ClearWorkspaceKeys(ctx, descID) - h.broadcaster.RecordAndBroadcast(ctx, "WORKSPACE_REMOVED", descID, map[string]interface{}{}) - } - - // Stop + remove volume for the workspace itself - if h.provisioner != nil { - h.provisioner.Stop(ctx, id) - if err := h.provisioner.RemoveVolume(ctx, id); err != nil { - log.Printf("Delete %s volume removal warning: %v", id, err) - } - } - db.ClearWorkspaceKeys(ctx, id) - - h.broadcaster.RecordAndBroadcast(ctx, "WORKSPACE_REMOVED", id, map[string]interface{}{ - "cascade_deleted": len(descendantIDs), - }) - - // Hard purge: cascade delete all FK data and remove the DB row entirely (#1087) - if c.Query("purge") == "true" { - purgeIDs := pq.Array(allIDs) - // Order matters: delete from leaf tables first, then workspace row - for _, table := range []string{ - "agent_memories", "activity_logs", "workspace_secrets", - "workspace_channels", "workspace_config", "workspace_memory", - "workspace_token_usage", "approval_requests", "audit_events", - "workflow_checkpoints", "workspace_artifacts", "agents", - "workspace_auth_tokens", "workspace_schedules", "canvas_layouts", - } { - if _, err := db.DB.ExecContext(ctx, - fmt.Sprintf("DELETE FROM %s WHERE workspace_id = ANY($1::uuid[])", table), - purgeIDs); err != nil { - log.Printf("Purge %s error for %v: %v", table, allIDs, err) - } - } - // Null out parent_id / forwarded_to references - db.DB.ExecContext(ctx, "UPDATE workspaces SET parent_id = NULL WHERE parent_id = ANY($1::uuid[])", purgeIDs) - db.DB.ExecContext(ctx, "UPDATE workspaces SET forwarded_to = NULL WHERE forwarded_to = ANY($1::uuid[])", purgeIDs) - // Hard delete the workspace row - if _, err := db.DB.ExecContext(ctx, "DELETE FROM workspaces WHERE id = ANY($1::uuid[])", purgeIDs); err != nil { - log.Printf("Purge workspace row error for %v: %v", allIDs, err) - c.JSON(http.StatusInternalServerError, gin.H{"error": "purge failed"}) - return - } - c.JSON(http.StatusOK, gin.H{"status": "purged", "cascade_deleted": len(descendantIDs)}) - return - } - - c.JSON(http.StatusOK, gin.H{"status": "removed", "cascade_deleted": len(descendantIDs)}) -} - -// validateWorkspaceID returns an error when id is not a valid UUID. -// #687: prevents 500s from Postgres when a garbage string (e.g. ../../etc/passwd) -// is passed as the :id path parameter. -func validateWorkspaceID(id string) error { - if _, err := uuid.Parse(id); err != nil { - return fmt.Errorf("invalid workspace id") - } - return nil -} - -// yamlSpecialChars is the set of YAML-special characters banned from workspace -// name and role. Newlines are handled separately below (same error message for -// all four fields); these additional characters target YAML block indicators, -// flow-sequence/mapping delimiters, and shell-expansion metacharacters that -// yamlQuote does NOT escape inside a double-quoted scalar (#685). -const yamlSpecialChars = "{}[]|>*&!" - -// validateWorkspaceFields enforces maximum field lengths and rejects characters -// that could enable YAML-injection in downstream provisioning paths. -// #685 (defence-in-depth over yamlQuote — newline + YAML-special chars in name/role), -// #688 (max field lengths). -func validateWorkspaceFields(name, role, model, runtime string) error { - // All four fields: reject newline / carriage-return. 
- for _, f := range []struct{ label, val string }{ - {"name", name}, - {"role", role}, - {"model", model}, - {"runtime", runtime}, - } { - if strings.ContainsAny(f.val, "\n\r") { - return fmt.Errorf("%s must not contain newline characters", f.label) - } - } - // name and role only: reject YAML-special characters (#685). - for _, f := range []struct{ label, val string }{ - {"name", name}, - {"role", role}, - } { - if strings.ContainsAny(f.val, yamlSpecialChars) { - return fmt.Errorf("%s contains invalid characters", f.label) - } - } - if len(name) > 255 { - return fmt.Errorf("name must be at most 255 characters") - } - if len(role) > 1000 { - return fmt.Errorf("role must be at most 1000 characters") - } - if len(model) > 100 { - return fmt.Errorf("model must be at most 100 characters") - } - if len(runtime) > 100 { - return fmt.Errorf("runtime must be at most 100 characters") - } - return nil -} diff --git a/workspace-server/internal/handlers/workspace_crud.go b/workspace-server/internal/handlers/workspace_crud.go new file mode 100644 index 00000000..741ac5c2 --- /dev/null +++ b/workspace-server/internal/handlers/workspace_crud.go @@ -0,0 +1,489 @@ +package handlers + +// workspace_crud.go — workspace state queries, updates, deletion, and +// field validation. Covers State (polling endpoint), Update (PATCH), +// Delete (cascade + purge), and input validation helpers. + +import ( + "database/sql" + "fmt" + "log" + "net/http" + "path/filepath" + "strings" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" + "github.com/Molecule-AI/molecule-monorepo/platform/internal/wsauth" + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/lib/pq" +) +// State handles GET /workspaces/:id/state — minimal status payload for +// remote-agent polling (Phase 30.4). Returns `{status, paused, deleted, +// workspace_id}` so a remote agent can detect pause/resume/delete +// without needing WebSocket reachability from the platform. +// +// Auth: Phase 30.1 bearer token required when the workspace has any +// live token on file; legacy workspaces grandfathered. Uses the same +// fail-closed posture as secrets.Values — polling this cadence with +// unauth'd callers would be a trivial DoS / workspace-status-scanner +// otherwise. +// +// The endpoint is deliberately NOT merged with GET /workspaces/:id: +// that handler is optimized for canvas (returns config, agent_card, +// position, …) and is unauthenticated by design. State is the +// agent-machinery polling path — tight, token-gated, cache-friendly. +func (h *WorkspaceHandler) State(c *gin.Context) { + workspaceID := c.Param("id") + ctx := c.Request.Context() + + // Auth gate — same shape as secrets.Values (Phase 30.2). Fail-closed + // on DB errors because the caller is about to poll this at ~60s + // cadence; letting unauth'd callers through on a hiccup turns this + // into a workspace-status scanner. 
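+	//
+	// Sketch of an authorized poll (hypothetical host, id, and token):
+	//
+	//	curl http://localhost:8080/workspaces/<id>/state \
+	//	  -H "Authorization: Bearer <token>"
+	//	→ 200 {"workspace_id":"<id>","status":"online","paused":false,"deleted":false}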
+	hasLive, hlErr := wsauth.HasAnyLiveToken(ctx, db.DB, workspaceID)
+	if hlErr != nil {
+		log.Printf("wsauth: HasAnyLiveToken(%s) failed for workspace.State: %v", workspaceID, hlErr)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "auth check failed"})
+		return
+	}
+	if hasLive {
+		tok := wsauth.BearerTokenFromHeader(c.GetHeader("Authorization"))
+		if tok == "" {
+			c.JSON(http.StatusUnauthorized, gin.H{"error": "missing workspace auth token"})
+			return
+		}
+		if err := wsauth.ValidateToken(ctx, db.DB, workspaceID, tok); err != nil {
+			c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid workspace auth token"})
+			return
+		}
+	}
+
+	var status string
+	err := db.DB.QueryRowContext(ctx, `
+		SELECT status
+		FROM workspaces
+		WHERE id = $1
+	`, workspaceID).Scan(&status)
+	if err == sql.ErrNoRows {
+		// A deleted workspace row no longer exists — remote agent should
+		// interpret 404 as "shut yourself down" (our soft-delete path uses
+		// status='removed' but keeps the row; a 404 here means the
+		// workspace was hard-deleted out from under the agent).
+		c.JSON(http.StatusNotFound, gin.H{
+			"workspace_id": workspaceID,
+			"deleted":      true,
+		})
+		return
+	}
+	if err != nil {
+		log.Printf("workspace.State query error for %s: %v", workspaceID, err)
+		c.JSON(http.StatusInternalServerError, gin.H{"error": "query failed"})
+		return
+	}
+
+	// Two delete paths: hard-delete (sql.ErrNoRows above → 404) AND
+	// soft-delete (status='removed' → also return 404 here so the SDK
+	// doesn't have to remember "is it 200 with deleted=true OR 404 with
+	// deleted=true?"). Same shape, same status code, same flag set.
+	if status == "removed" {
+		c.JSON(http.StatusNotFound, gin.H{
+			"workspace_id": workspaceID,
+			"status":       "removed",
+			"deleted":      true,
+		})
+		return
+	}
+
+	c.JSON(http.StatusOK, gin.H{
+		"workspace_id": workspaceID,
+		"status":       status,
+		"paused":       status == "paused",
+		"deleted":      false,
+	})
+}
+
+// sensitiveUpdateFields documents fields that carry elevated risk — kept as
+// an explicit list for code readability and future audits. Auth is now fully
+// enforced at the router layer (WorkspaceAuth middleware, #680 IDOR fix);
+// this map is no longer used for in-handler gate logic but is preserved to
+// surface the risk classification clearly.
+//
+// budget_limit is intentionally NOT here — the dedicated PATCH
+// /workspaces/:id/budget (AdminAuth) is the only write path (#611).
+var sensitiveUpdateFields = map[string]struct{}{
+	"tier":          {},
+	"parent_id":     {},
+	"runtime":       {},
+	"workspace_dir": {},
+}
+
+// Update handles PATCH /workspaces/:id
+func (h *WorkspaceHandler) Update(c *gin.Context) {
+	id := c.Param("id")
+
+	// #687: reject non-UUID IDs before hitting the DB.
+	if err := validateWorkspaceID(id); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace ID"})
+		return
+	}
+
+	var body map[string]interface{}
+	if err := c.ShouldBindJSON(&body); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"})
+		return
+	}
+
+	// #685/#688: validate string fields for length and injection safety.
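+	// e.g. a PATCH body of {"name": "ops{prod}"} (hypothetical value) is
+	// rejected: "{" is in the yamlSpecialChars set that
+	// validateWorkspaceFields bans for name and role.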
+	strField := func(key string) string {
+		if v, ok := body[key]; ok {
+			if s, ok := v.(string); ok {
+				return s
+			}
+		}
+		return ""
+	}
+	if err := validateWorkspaceFields(
+		strField("name"), strField("role"), "" /*model not patchable*/, strField("runtime"),
+	); err != nil {
+		c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace fields"})
+		return
+	}
+
+	ctx := c.Request.Context()
+
+	// Auth is fully enforced at the router layer (WorkspaceAuth middleware, #680).
+	// WorkspaceAuth validates that the caller holds a valid bearer token for this
+	// specific workspace — no additional auth gate is needed here. The
+	// sensitiveUpdateFields map above documents the risk classification for
+	// auditors but is no longer used as a runtime gate.
+
+	// #120: guard — return 404 for nonexistent workspace IDs instead of
+	// silently applying zero-row UPDATEs and returning 200.
+	var exists bool
+	if err := db.DB.QueryRowContext(ctx,
+		`SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`, id,
+	).Scan(&exists); err != nil || !exists {
+		c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"})
+		return
+	}
+
+	if name, ok := body["name"]; ok {
+		if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET name = $2, updated_at = now() WHERE id = $1`, id, name); err != nil {
+			log.Printf("Update name error for %s: %v", id, err)
+		}
+	}
+	if role, ok := body["role"]; ok {
+		if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET role = $2, updated_at = now() WHERE id = $1`, id, role); err != nil {
+			log.Printf("Update role error for %s: %v", id, err)
+		}
+	}
+	if tier, ok := body["tier"]; ok {
+		if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET tier = $2, updated_at = now() WHERE id = $1`, id, tier); err != nil {
+			log.Printf("Update tier error for %s: %v", id, err)
+		}
+	}
+	if parentID, ok := body["parent_id"]; ok {
+		if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET parent_id = $2, updated_at = now() WHERE id = $1`, id, parentID); err != nil {
+			log.Printf("Update parent_id error for %s: %v", id, err)
+		}
+	}
+	if runtime, ok := body["runtime"]; ok {
+		if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET runtime = $2, updated_at = now() WHERE id = $1`, id, runtime); err != nil {
+			log.Printf("Update runtime error for %s: %v", id, err)
+		}
+	}
+	needsRestart := false
+	if wsDir, ok := body["workspace_dir"]; ok {
+		// Allow null to clear workspace_dir
+		if wsDir != nil {
+			if dirStr, isStr := wsDir.(string); isStr && dirStr != "" {
+				if err := validateWorkspaceDir(dirStr); err != nil {
+					c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace directory"})
+					return
+				}
+			}
+		}
+		if _, err := db.DB.ExecContext(ctx, `UPDATE workspaces SET workspace_dir = $2, updated_at = now() WHERE id = $1`, id, wsDir); err != nil {
+			log.Printf("Update workspace_dir error for %s: %v", id, err)
+		}
+		needsRestart = true
+	}
+	// NOTE: budget_limit is intentionally NOT handled here. The dedicated
+	// PATCH /workspaces/:id/budget (AdminAuth) is the only write path.
+	// This endpoint is reachable with a workspace's own bearer token —
+	// the workspace could otherwise self-clear its own spending ceiling.
(#611 Security Auditor) + + // Update canvas position if both x and y provided + if x, xOk := body["x"]; xOk { + if y, yOk := body["y"]; yOk { + if _, err := db.DB.ExecContext(ctx, ` + INSERT INTO canvas_layouts (workspace_id, x, y) + VALUES ($1, $2, $3) + ON CONFLICT (workspace_id) DO UPDATE SET x = EXCLUDED.x, y = EXCLUDED.y + `, id, x, y); err != nil { + log.Printf("Update position error for %s: %v", id, err) + } + } + } + + resp := gin.H{"status": "updated"} + if needsRestart { + resp["needs_restart"] = true + } + c.JSON(http.StatusOK, resp) +} + +// validateWorkspaceDir checks that a workspace_dir path is safe to bind-mount. +func validateWorkspaceDir(dir string) error { + if !filepath.IsAbs(dir) { + return fmt.Errorf("workspace_dir must be an absolute path") + } + if strings.Contains(dir, "..") { + return fmt.Errorf("workspace_dir must not contain '..'") + } + // Reject system-critical paths + clean := filepath.Clean(dir) + for _, blocked := range []string{"/etc", "/var", "/proc", "/sys", "/dev", "/boot", "/sbin", "/bin", "/lib", "/usr"} { + if clean == blocked || strings.HasPrefix(clean, blocked+"/") { + return fmt.Errorf("workspace_dir must not be a system path (%s)", blocked) + } + } + return nil +} + +// Delete handles DELETE /workspaces/:id +// If the workspace has children (is a team), cascade deletes all sub-workspaces. +// Use ?confirm=true to actually delete (otherwise returns children list for confirmation). +func (h *WorkspaceHandler) Delete(c *gin.Context) { + id := c.Param("id") + ctx := c.Request.Context() + confirm := c.Query("confirm") == "true" + + // #687: reject non-UUID IDs before hitting the DB. + if err := validateWorkspaceID(id); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid workspace ID"}) + return + } + + // Check for children + rows, err := db.DB.QueryContext(ctx, + `SELECT id, name FROM workspaces WHERE parent_id = $1 AND status != 'removed'`, id) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to check children"}) + return + } + defer rows.Close() + + var children []map[string]string + for rows.Next() { + var childID, childName string + if rows.Scan(&childID, &childName) == nil { + children = append(children, map[string]string{"id": childID, "name": childName}) + } + } + if err := rows.Err(); err != nil { + log.Printf("Delete: child rows error: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to check children"}) + return + } + + // If has children and not confirmed, return children list for confirmation. + // Uses HTTP 409 Conflict (not 200) so `curl --fail`, `fetch().ok`, and any + // client that treats HTTP 4xx as an error surfaces the confirmation + // requirement. Body shape unchanged so the canvas UI's parser keeps + // working. Fixes #88. + if len(children) > 0 && !confirm { + c.JSON(http.StatusConflict, gin.H{ + "status": "confirmation_required", + "message": "This workspace has sub-workspaces. Delete with ?confirm=true to cascade delete.", + "children": children, + "children_count": len(children), + }) + return + } + + // Cascade delete: collect ALL descendants (not just direct children) via + // recursive CTE, then stop each container and remove each volume. + // Previous bug: only direct children's containers were stopped, leaving + // grandchildren as orphan running containers after a cascade delete. 
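+	// e.g. deleting team T with child C and grandchild G: the recursive CTE
+	// below yields {C, G}, so G's container is stopped and its volume removed
+	// too; the old direct-children query would have left G running.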
+	descendantIDs := []string{}
+	if len(children) > 0 {
+		descRows, err := db.DB.QueryContext(ctx, `
+			WITH RECURSIVE descendants AS (
+				SELECT id FROM workspaces WHERE parent_id = $1 AND status != 'removed'
+				UNION ALL
+				SELECT w.id FROM workspaces w JOIN descendants d ON w.parent_id = d.id WHERE w.status != 'removed'
+			)
+			SELECT id FROM descendants
+		`, id)
+		if err != nil {
+			log.Printf("Delete: descendant query error for %s: %v", id, err)
+		} else {
+			for descRows.Next() {
+				var descID string
+				if descRows.Scan(&descID) == nil {
+					descendantIDs = append(descendantIDs, descID)
+				}
+			}
+			descRows.Close()
+		}
+	}
+
+	// #73 fix: mark rows 'removed' in the DB FIRST, BEFORE stopping containers
+	// or removing volumes. Previously the sequence was stop → update-status,
+	// which left a gap where:
+	//   - the container's last pre-teardown heartbeat could resurrect the row
+	//     via the register-handler UPSERT (now also guarded in #73)
+	//   - the liveness monitor could observe 'online' status + expired Redis
+	//     TTL and trigger RestartByID, recreating a container we're trying
+	//     to destroy
+	// Marking 'removed' first makes both of those paths no-op via their
+	// existing `status NOT IN ('removed', ...)` guards.
+	allIDs := append([]string{id}, descendantIDs...)
+	if _, err := db.DB.ExecContext(ctx,
+		`UPDATE workspaces SET status = 'removed', updated_at = now() WHERE id = ANY($1::uuid[])`,
+		pq.Array(allIDs)); err != nil {
+		log.Printf("Delete status update error for %s: %v", id, err)
+	}
+	if _, err := db.DB.ExecContext(ctx,
+		`DELETE FROM canvas_layouts WHERE workspace_id = ANY($1::uuid[])`,
+		pq.Array(allIDs)); err != nil {
+		log.Printf("Delete canvas_layouts error for %s: %v", id, err)
+	}
+	// Revoke all auth tokens for the deleted workspaces. Once the workspace is
+	// gone its tokens are meaningless; leaving them alive would keep
+	// HasAnyLiveTokenGlobal = true even after the platform is otherwise empty,
+	// which prevents AdminAuth from returning to fail-open and breaks the E2E
+	// test's count-zero assertion (and local re-run cleanup).
+	if _, err := db.DB.ExecContext(ctx,
+		`UPDATE workspace_auth_tokens SET revoked_at = now()
+		 WHERE workspace_id = ANY($1::uuid[]) AND revoked_at IS NULL`,
+		pq.Array(allIDs)); err != nil {
+		log.Printf("Delete token revocation error for %s: %v", id, err)
+	}
+	// #1027: cascade-disable all schedules for the deleted workspaces so
+	// the scheduler never fires a cron into a removed container.
+	if _, err := db.DB.ExecContext(ctx,
+		`UPDATE workspace_schedules SET enabled = false, updated_at = now()
+		 WHERE workspace_id = ANY($1::uuid[]) AND enabled = true`,
+		pq.Array(allIDs)); err != nil {
+		log.Printf("Delete schedule disable error for %s: %v", id, err)
+	}
+
+	// Now stop containers + remove volumes for all descendants (any depth).
+	// Any concurrent heartbeat / registration / liveness-triggered restart
+	// will see status='removed' and bail out early.
+ for _, descID := range descendantIDs { + if h.provisioner != nil { + h.provisioner.Stop(ctx, descID) + if err := h.provisioner.RemoveVolume(ctx, descID); err != nil { + log.Printf("Delete descendant %s volume removal warning: %v", descID, err) + } + } + db.ClearWorkspaceKeys(ctx, descID) + h.broadcaster.RecordAndBroadcast(ctx, "WORKSPACE_REMOVED", descID, map[string]interface{}{}) + } + + // Stop + remove volume for the workspace itself + if h.provisioner != nil { + h.provisioner.Stop(ctx, id) + if err := h.provisioner.RemoveVolume(ctx, id); err != nil { + log.Printf("Delete %s volume removal warning: %v", id, err) + } + } + db.ClearWorkspaceKeys(ctx, id) + + h.broadcaster.RecordAndBroadcast(ctx, "WORKSPACE_REMOVED", id, map[string]interface{}{ + "cascade_deleted": len(descendantIDs), + }) + + // Hard purge: cascade delete all FK data and remove the DB row entirely (#1087) + if c.Query("purge") == "true" { + purgeIDs := pq.Array(allIDs) + // Order matters: delete from leaf tables first, then workspace row + for _, table := range []string{ + "agent_memories", "activity_logs", "workspace_secrets", + "workspace_channels", "workspace_config", "workspace_memory", + "workspace_token_usage", "approval_requests", "audit_events", + "workflow_checkpoints", "workspace_artifacts", "agents", + "workspace_auth_tokens", "workspace_schedules", "canvas_layouts", + } { + if _, err := db.DB.ExecContext(ctx, + fmt.Sprintf("DELETE FROM %s WHERE workspace_id = ANY($1::uuid[])", table), + purgeIDs); err != nil { + log.Printf("Purge %s error for %v: %v", table, allIDs, err) + } + } + // Null out parent_id / forwarded_to references + db.DB.ExecContext(ctx, "UPDATE workspaces SET parent_id = NULL WHERE parent_id = ANY($1::uuid[])", purgeIDs) + db.DB.ExecContext(ctx, "UPDATE workspaces SET forwarded_to = NULL WHERE forwarded_to = ANY($1::uuid[])", purgeIDs) + // Hard delete the workspace row + if _, err := db.DB.ExecContext(ctx, "DELETE FROM workspaces WHERE id = ANY($1::uuid[])", purgeIDs); err != nil { + log.Printf("Purge workspace row error for %v: %v", allIDs, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "purge failed"}) + return + } + c.JSON(http.StatusOK, gin.H{"status": "purged", "cascade_deleted": len(descendantIDs)}) + return + } + + c.JSON(http.StatusOK, gin.H{"status": "removed", "cascade_deleted": len(descendantIDs)}) +} + +// validateWorkspaceID returns an error when id is not a valid UUID. +// #687: prevents 500s from Postgres when a garbage string (e.g. ../../etc/passwd) +// is passed as the :id path parameter. +func validateWorkspaceID(id string) error { + if _, err := uuid.Parse(id); err != nil { + return fmt.Errorf("invalid workspace id") + } + return nil +} + +// yamlSpecialChars is the set of YAML-special characters banned from workspace +// name and role. Newlines are handled separately below (same error message for +// all four fields); these additional characters target YAML block indicators, +// flow-sequence/mapping delimiters, and shell-expansion metacharacters that +// yamlQuote does NOT escape inside a double-quoted scalar (#685). +const yamlSpecialChars = "{}[]|>*&!" + +// validateWorkspaceFields enforces maximum field lengths and rejects characters +// that could enable YAML-injection in downstream provisioning paths. +// #685 (defence-in-depth over yamlQuote — newline + YAML-special chars in name/role), +// #688 (max field lengths). +func validateWorkspaceFields(name, role, model, runtime string) error { + // All four fields: reject newline / carriage-return. 
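+	// (Illustrative: a name like "agent\nrole: admin" would otherwise open
+	// a new YAML key when interpolated into provisioning config, the #685
+	// defence-in-depth case this guards.)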
+ for _, f := range []struct{ label, val string }{ + {"name", name}, + {"role", role}, + {"model", model}, + {"runtime", runtime}, + } { + if strings.ContainsAny(f.val, "\n\r") { + return fmt.Errorf("%s must not contain newline characters", f.label) + } + } + // name and role only: reject YAML-special characters (#685). + for _, f := range []struct{ label, val string }{ + {"name", name}, + {"role", role}, + } { + if strings.ContainsAny(f.val, yamlSpecialChars) { + return fmt.Errorf("%s contains invalid characters", f.label) + } + } + if len(name) > 255 { + return fmt.Errorf("name must be at most 255 characters") + } + if len(role) > 1000 { + return fmt.Errorf("role must be at most 1000 characters") + } + if len(model) > 100 { + return fmt.Errorf("model must be at most 100 characters") + } + if len(runtime) > 100 { + return fmt.Errorf("runtime must be at most 100 characters") + } + return nil +} diff --git a/workspace-server/internal/handlers/workspace_provision_test.go b/workspace-server/internal/handlers/workspace_provision_test.go index b1f5f12e..8feb4403 100644 --- a/workspace-server/internal/handlers/workspace_provision_test.go +++ b/workspace-server/internal/handlers/workspace_provision_test.go @@ -906,6 +906,8 @@ func containsStr(s, substr string) bool { // truncates content at maxMemoryContentLength before INSERT. Regression // test for the error-sanitization / memory-seed contract. func TestSeedInitialMemories_Truncation(t *testing.T) { + mock := setupTestDB(t) + largeContent := string(make([]byte, 100_001)) copy([]byte(largeContent), "X") // fill with "X" so test is deterministic @@ -1095,7 +1097,7 @@ func TestProvisionWorkspace_NoInternalErrorsInBroadcast(t *testing.T) { mock.ExpectQuery(`SELECT key, encrypted_value, encryption_version FROM global_secrets`). WillReturnError(errInternalDB) - broadcaster := &captureBroadcaster{} + broadcaster := &captureBroadcaster{broadcaster: events.NewBroadcaster(nil)} handler := &WorkspaceHandler{ broadcaster: broadcaster, provisioner: &provisioner.Provisioner{}, @@ -1143,7 +1145,7 @@ func TestProvisionWorkspaceCP_NoInternalErrorsInBroadcast(t *testing.T) { mock.ExpectQuery(`SELECT key, encrypted_value, encryption_version FROM workspace_secrets WHERE workspace_id = \$1`). 
WillReturnRows(sqlmock.NewRows([]string{"key", "encrypted_value", "encryption_version"})) - broadcaster := &captureBroadcaster{} + broadcaster := &captureBroadcaster{broadcaster: events.NewBroadcaster(nil)} registry := &mockEnvMutator{returnErr: errInternalDB} handler := &WorkspaceHandler{ broadcaster: broadcaster, diff --git a/workspace-server/internal/models/workspace.go b/workspace-server/internal/models/workspace.go index ff8ad0be..26061a1f 100644 --- a/workspace-server/internal/models/workspace.go +++ b/workspace-server/internal/models/workspace.go @@ -22,6 +22,7 @@ type Workspace struct { LastErrorRate float64 `json:"last_error_rate" db:"last_error_rate"` LastSampleError sql.NullString `json:"last_sample_error" db:"last_sample_error"` ActiveTasks int `json:"active_tasks" db:"active_tasks"` + MaxConcurrentTasks int `json:"max_concurrent_tasks" db:"max_concurrent_tasks"` UptimeSeconds int `json:"uptime_seconds" db:"uptime_seconds"` CreatedAt time.Time `json:"created_at" db:"created_at"` UpdatedAt time.Time `json:"updated_at" db:"updated_at"` diff --git a/workspace-server/internal/orgtoken/tokens_test.go b/workspace-server/internal/orgtoken/tokens_test.go index 9f51f46a..e3bee7e7 100644 --- a/workspace-server/internal/orgtoken/tokens_test.go +++ b/workspace-server/internal/orgtoken/tokens_test.go @@ -79,7 +79,7 @@ func TestValidate_HappyPath(t *testing.T) { WithArgs("tok-live"). WillReturnResult(sqlmock.NewResult(0, 1)) - id, prefix, orgID, err := Validate(context.Background(), db, plaintext) + id, prefix, _, err := Validate(context.Background(), db, plaintext) if err != nil { t.Fatalf("Validate: %v", err) } diff --git a/workspace-server/internal/scheduler/scheduler.go b/workspace-server/internal/scheduler/scheduler.go index 4fa12880..4ae82247 100644 --- a/workspace-server/internal/scheduler/scheduler.go +++ b/workspace-server/internal/scheduler/scheduler.go @@ -267,32 +267,36 @@ func (s *Scheduler) fireSchedule(ctx context.Context, sched scheduleRow) { // This replaces the #115 "skip when busy" pattern which caused crons // to permanently miss when workspaces were perpetually busy from the // Orchestrator pulse delegation chain (~30% message drop rate on Dev Lead). + // Check workspace capacity — fire when active_tasks < max_concurrent_tasks. + // Default max is 1 (backward compatible). Workspaces can override via config + // to allow concurrent task processing (e.g. leaders handling A2A while cron runs). 
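+	// e.g. a leader with max_concurrent_tasks=2 and one A2A delegation in
+	// flight (active_tasks=1) still fires its cron on time; a second
+	// in-flight task defers it, and after the 2-minute wait below it skips.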
var activeTasks int + var maxConcurrent int if err := db.DB.QueryRowContext(ctx, - `SELECT COALESCE(active_tasks, 0) FROM workspaces WHERE id = $1`, + `SELECT COALESCE(active_tasks, 0), COALESCE(max_concurrent_tasks, 1) FROM workspaces WHERE id = $1`, sched.WorkspaceID, - ).Scan(&activeTasks); err == nil && activeTasks > 0 { - log.Printf("Scheduler: '%s' workspace %s busy (active_tasks=%d), deferring up to 2 min", - sched.Name, short(sched.WorkspaceID, 12), activeTasks) + ).Scan(&activeTasks, &maxConcurrent); err == nil && activeTasks >= maxConcurrent { + log.Printf("Scheduler: '%s' workspace %s at capacity (active_tasks=%d, max=%d), deferring up to 2 min", + sched.Name, short(sched.WorkspaceID, 12), activeTasks, maxConcurrent) // Poll every 10s for up to 2 minutes waited := false for i := 0; i < 12; i++ { time.Sleep(10 * time.Second) if err := db.DB.QueryRowContext(ctx, - `SELECT COALESCE(active_tasks, 0) FROM workspaces WHERE id = $1`, + `SELECT COALESCE(active_tasks, 0), COALESCE(max_concurrent_tasks, 1) FROM workspaces WHERE id = $1`, sched.WorkspaceID, - ).Scan(&activeTasks); err != nil || activeTasks == 0 { + ).Scan(&activeTasks, &maxConcurrent); err != nil || activeTasks < maxConcurrent { waited = true break } } - if !waited && activeTasks > 0 { - log.Printf("Scheduler: skipping '%s' on busy workspace %s after 2 min wait (active_tasks=%d)", - sched.Name, short(sched.WorkspaceID, 12), activeTasks) + if !waited && activeTasks >= maxConcurrent { + log.Printf("Scheduler: skipping '%s' on busy workspace %s after 2 min wait (active_tasks=%d, max=%d)", + sched.Name, short(sched.WorkspaceID, 12), activeTasks, maxConcurrent) s.recordSkipped(ctx, sched, activeTasks) return } - log.Printf("Scheduler: '%s' workspace %s now idle after deferral, firing", + log.Printf("Scheduler: '%s' workspace %s has capacity after deferral, firing", sched.Name, short(sched.WorkspaceID, 12)) } diff --git a/workspace-server/migrations/037_max_concurrent_tasks.down.sql b/workspace-server/migrations/037_max_concurrent_tasks.down.sql new file mode 100644 index 00000000..d5274526 --- /dev/null +++ b/workspace-server/migrations/037_max_concurrent_tasks.down.sql @@ -0,0 +1 @@ +ALTER TABLE workspaces DROP COLUMN IF EXISTS max_concurrent_tasks; diff --git a/workspace-server/migrations/037_max_concurrent_tasks.up.sql b/workspace-server/migrations/037_max_concurrent_tasks.up.sql new file mode 100644 index 00000000..644ea18c --- /dev/null +++ b/workspace-server/migrations/037_max_concurrent_tasks.up.sql @@ -0,0 +1,5 @@ +-- Per-workspace concurrency limit (#1408). +-- Default 1 preserves current behavior (single-task). Leaders can be +-- configured with higher values to accept A2A delegations while a cron runs. +ALTER TABLE workspaces + ADD COLUMN IF NOT EXISTS max_concurrent_tasks INTEGER NOT NULL DEFAULT 1; diff --git a/workspace/a2a_tools.py b/workspace/a2a_tools.py index b3ffbbfd..04633209 100644 --- a/workspace/a2a_tools.py +++ b/workspace/a2a_tools.py @@ -3,6 +3,7 @@ Imports shared client functions and constants from a2a_client. """ +import hashlib import json import uuid @@ -124,11 +125,16 @@ async def tool_delegate_task_async(workspace_id: str, task: str) -> str: if not workspace_id or not task: return "Error: workspace_id and task are required" + # Idempotency key: SHA-256 of (workspace_id, task) so that a restarted agent + # firing the same delegation gets the same key and the platform returns the + # existing delegation_id instead of creating a duplicate. Fixes #1456. 
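+    # Illustrative: every retry of ("ws-research", "summarize inbox") hashes
+    # to the same 32-hex-char key, so the platform can dedupe a resend after
+    # an agent crash-restart (workspace/task values here are hypothetical).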
+ idem_key = hashlib.sha256(f"{workspace_id}:{task}".encode()).hexdigest()[:32] + try: async with httpx.AsyncClient(timeout=10.0) as client: resp = await client.post( f"{PLATFORM_URL}/workspaces/{WORKSPACE_ID}/delegate", - json={"target_id": workspace_id, "task": task}, + json={"target_id": workspace_id, "task": task, "idempotency_key": idem_key}, headers=_auth_headers_for_heartbeat(), ) if resp.status_code == 202: diff --git a/workspace/executor_helpers.py b/workspace/executor_helpers.py index 0d6e2d85..f40fa6b7 100644 --- a/workspace/executor_helpers.py +++ b/workspace/executor_helpers.py @@ -199,14 +199,26 @@ def read_delegation_results() -> str: # ======================================================================== async def set_current_task(heartbeat: "HeartbeatLoop | None", task: str) -> None: - """Update current task on heartbeat and push immediately via platform API.""" + """Update current task on heartbeat and push immediately via platform API. + + Uses increment/decrement instead of binary 0/1 so agents can track + multiple concurrent tasks (#1408). Pushes immediately on both + increment and decrement to avoid phantom-busy (#1372). + """ if heartbeat is not None: - heartbeat.current_task = task - heartbeat.active_tasks = 1 if task else 0 + if task: + heartbeat.active_tasks = getattr(heartbeat, "active_tasks", 0) + 1 + heartbeat.current_task = task + else: + heartbeat.active_tasks = max(0, getattr(heartbeat, "active_tasks", 0) - 1) + if heartbeat.active_tasks == 0: + heartbeat.current_task = "" workspace_id = os.environ.get("WORKSPACE_ID", "") platform_url = os.environ.get("PLATFORM_URL", "") if not (workspace_id and platform_url): return + active = getattr(heartbeat, "active_tasks", 0) if heartbeat is not None else (1 if task else 0) + cur_task = getattr(heartbeat, "current_task", task or "") if heartbeat is not None else (task or "") try: try: from platform_auth import auth_headers as _auth @@ -217,8 +229,8 @@ async def set_current_task(heartbeat: "HeartbeatLoop | None", task: str) -> None f"{platform_url}/registry/heartbeat", json={ "workspace_id": workspace_id, - "current_task": task, - "active_tasks": 1 if task else 0, + "current_task": cur_task, + "active_tasks": active, "error_rate": 0, "sample_error": "", "uptime_seconds": 0, diff --git a/workspace/shared_runtime.py b/workspace/shared_runtime.py index a3838664..dba05700 100644 --- a/workspace/shared_runtime.py +++ b/workspace/shared_runtime.py @@ -153,20 +153,21 @@ def brief_task(text: str, limit: int = 60) -> str: async def set_current_task(heartbeat: Any, task: str) -> None: """Update current task on heartbeat and push immediately to platform. - The heartbeat loop only fires every 30s, so quick tasks would finish - before the canvas ever sees them. Setting a task pushes immediately. - Clearing a task only updates the heartbeat object — the next heartbeat - cycle will broadcast the clear, keeping the task visible longer. + Uses increment/decrement instead of binary 0/1 so agents can track + multiple concurrent tasks (e.g. a cron running while an A2A delegation + arrives). The counter never goes below 0. + + Pushes immediately on BOTH increment and decrement to avoid phantom-busy + (#1372) where active_tasks=1 persisted in the platform DB indefinitely. 
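+
+    Illustrative sequence (counter starts at 0; `hb` and the task
+    names are hypothetical):
+        await set_current_task(hb, "cron: digest")  # active_tasks 0 → 1
+        await set_current_task(hb, "a2a: review")   # active_tasks 1 → 2
+        await set_current_task(hb, "")              # active_tasks 2 → 1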
""" if heartbeat: - heartbeat.current_task = task - heartbeat.active_tasks = 1 if task else 0 - - # Only push immediately when SETTING a task (not clearing) - # Clearing is handled by the next heartbeat cycle, which keeps - # the task visible on the canvas for quick A2A responses - if not task: - return + if task: + heartbeat.active_tasks = getattr(heartbeat, "active_tasks", 0) + 1 + heartbeat.current_task = task + else: + heartbeat.active_tasks = max(0, getattr(heartbeat, "active_tasks", 0) - 1) + if heartbeat.active_tasks == 0: + heartbeat.current_task = "" import os workspace_id = os.environ.get("WORKSPACE_ID", "") @@ -174,13 +175,15 @@ async def set_current_task(heartbeat: Any, task: str) -> None: if workspace_id and platform_url: try: import httpx + active = getattr(heartbeat, "active_tasks", 0) if heartbeat else (1 if task else 0) + cur_task = getattr(heartbeat, "current_task", task or "") if heartbeat else (task or "") async with httpx.AsyncClient(timeout=3.0) as client: await client.post( f"{platform_url}/registry/heartbeat", json={ "workspace_id": workspace_id, - "current_task": task, - "active_tasks": 1, + "current_task": cur_task, + "active_tasks": active, "error_rate": 0, "sample_error": "", "uptime_seconds": 0,