diff --git a/platform/internal/handlers/a2a_proxy.go b/platform/internal/handlers/a2a_proxy.go index f6ec9ce4..307c3311 100644 --- a/platform/internal/handlers/a2a_proxy.go +++ b/platform/internal/handlers/a2a_proxy.go @@ -251,6 +251,12 @@ func (h *WorkspaceHandler) proxyA2ARequest(ctx context.Context, workspaceID stri if logActivity { h.logA2ASuccess(ctx, workspaceID, callerID, body, respBody, a2aMethod, resp.StatusCode, durationMs) } + + // Track LLM token usage for cost transparency (#593). + // Fires in a detached goroutine so token accounting never adds latency + // to the critical A2A path. + go extractAndUpsertTokenUsage(context.WithoutCancel(ctx), workspaceID, respBody) + return resp.StatusCode, respBody, nil } @@ -577,3 +583,65 @@ func validateCallerToken(ctx context.Context, c *gin.Context, callerID string) e // token" branch so the handler-level guard can detect it without string // matching (the wsauth errors are typed for the invalid case). var errInvalidCallerToken = errors.New("missing caller auth token") + +// extractAndUpsertTokenUsage parses LLM usage from a raw A2A response body +// and persists it via upsertTokenUsage, which logs any DB error rather than +// returning it. ctx must already be detached from the request. +func extractAndUpsertTokenUsage(ctx context.Context, workspaceID string, respBody []byte) { + in, out := parseUsageFromA2AResponse(respBody) + if in > 0 || out > 0 { + upsertTokenUsage(ctx, workspaceID, in, out) + } +} + +// parseUsageFromA2AResponse extracts input_tokens / output_tokens from an A2A +// JSON-RPC response. Inspects two locations in order of preference: +// 1. result.usage — the JSON-RPC 2.0 result envelope from workspace agents. +// 2. usage — top-level, for non-JSON-RPC or direct Anthropic-shaped payloads. +// +// Returns (0, 0) when no recognisable usage data is found.
+func parseUsageFromA2AResponse(body []byte) (inputTokens, outputTokens int64) { + if len(body) == 0 { + return 0, 0 + } + var top map[string]json.RawMessage + if err := json.Unmarshal(body, &top); err != nil { + return 0, 0 + } + + // 1. result.usage (JSON-RPC 2.0 wrapper produced by workspace agents). + if rawResult, ok := top["result"]; ok { + var result map[string]json.RawMessage + if err := json.Unmarshal(rawResult, &result); err == nil { + if in, out, ok := readUsageMap(result); ok { + return in, out + } + } + } + + // 2. Fallback: top-level usage (direct Anthropic or non-JSON-RPC response). + if in, out, ok := readUsageMap(top); ok { + return in, out + } + return 0, 0 +} + +// readUsageMap extracts input_tokens / output_tokens from the "usage" key of m. +// Returns (0, 0, false) when the key is absent or contains no non-zero values. +func readUsageMap(m map[string]json.RawMessage) (inputTokens, outputTokens int64, ok bool) { + rawUsage, has := m["usage"] + if !has { + return 0, 0, false + } + var usage struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + } + if err := json.Unmarshal(rawUsage, &usage); err != nil { + return 0, 0, false + } + if usage.InputTokens == 0 && usage.OutputTokens == 0 { + return 0, 0, false + } + return usage.InputTokens, usage.OutputTokens, true +} diff --git a/platform/internal/handlers/workspace_metrics.go b/platform/internal/handlers/workspace_metrics.go new file mode 100644 index 00000000..db6400a3 --- /dev/null +++ b/platform/internal/handlers/workspace_metrics.go @@ -0,0 +1,125 @@ +package handlers + +import ( + "context" + "database/sql" + "fmt" + "log" + "net/http" + "time" + + "github.com/Molecule-AI/molecule-monorepo/platform/internal/db" + "github.com/gin-gonic/gin" +) + +// Pricing constants — Claude Sonnet default rates (USD per token). +// Callers with different models should override via env vars in a future phase. 
+const ( + tokenCostPerInputToken = 0.000003 // $3 / 1M input tokens + tokenCostPerOutputToken = 0.000015 // $15 / 1M output tokens +) + +// MetricsHandler serves GET /workspaces/:id/metrics. +type MetricsHandler struct{} + +// NewMetricsHandler returns a MetricsHandler. +func NewMetricsHandler() *MetricsHandler { return &MetricsHandler{} } + +// GetMetrics handles GET /workspaces/:id/metrics. +// +// Returns aggregated LLM token usage for the current UTC day. +// Auth: WorkspaceAuth middleware (bearer token bound to :id). +// +// Response: +// +// { +// "input_tokens": , +// "output_tokens": , +// "total_calls": , +// "estimated_cost_usd": "0.000000", +// "period_start": "2026-04-17T00:00:00Z", +// "period_end": "2026-04-18T00:00:00Z" +// } +func (h *MetricsHandler) GetMetrics(c *gin.Context) { + workspaceID := c.Param("id") + ctx := c.Request.Context() + + // Verify workspace exists — 404 before touching usage table. + var wsExists bool + if err := db.DB.QueryRowContext(ctx, + `SELECT EXISTS(SELECT 1 FROM workspaces WHERE id = $1)`, + workspaceID, + ).Scan(&wsExists); err != nil { + log.Printf("metrics: workspace check failed for %s: %v", workspaceID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to verify workspace"}) + return + } + if !wsExists { + c.JSON(http.StatusNotFound, gin.H{"error": "workspace not found"}) + return + } + + periodStart := todayUTC() + periodEnd := periodStart.Add(24 * time.Hour) + + var inputTokens, outputTokens int64 + var callCount int64 + var estimatedCost float64 + + err := db.DB.QueryRowContext(ctx, ` + SELECT + COALESCE(SUM(input_tokens), 0), + COALESCE(SUM(output_tokens), 0), + COALESCE(SUM(call_count), 0), + COALESCE(SUM(estimated_cost_usd), 0) + FROM workspace_token_usage + WHERE workspace_id = $1 + AND period_start = $2 + `, workspaceID, periodStart).Scan(&inputTokens, &outputTokens, &callCount, &estimatedCost) + if err != nil && err != sql.ErrNoRows { + log.Printf("metrics: query failed for workspace %s: 
 %v", workspaceID, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to fetch metrics"}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "input_tokens": inputTokens, + "output_tokens": outputTokens, + "total_calls": callCount, + "estimated_cost_usd": fmt.Sprintf("%.6f", estimatedCost), + "period_start": periodStart.Format(time.RFC3339), + "period_end": periodEnd.Format(time.RFC3339), + }) +} + +// todayUTC returns the start of the current UTC day (midnight). +func todayUTC() time.Time { + now := time.Now().UTC() + return time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, time.UTC) +} + +// upsertTokenUsage accumulates input/output token counts for workspaceID's +// current UTC day. Cost is estimated using the default per-token pricing +// constants. Always call in a detached goroutine — never block the A2A path. +func upsertTokenUsage(ctx context.Context, workspaceID string, inputTokens, outputTokens int64) { + if inputTokens == 0 && outputTokens == 0 { + return + } + periodStart := todayUTC() + cost := float64(inputTokens)*tokenCostPerInputToken + float64(outputTokens)*tokenCostPerOutputToken + + _, err := db.DB.ExecContext(ctx, ` + INSERT INTO workspace_token_usage + (workspace_id, period_start, input_tokens, output_tokens, call_count, estimated_cost_usd, updated_at) + VALUES ($1, $2, $3, $4, 1, $5, NOW()) + ON CONFLICT (workspace_id, period_start) DO UPDATE SET + input_tokens = workspace_token_usage.input_tokens + EXCLUDED.input_tokens, + output_tokens = workspace_token_usage.output_tokens + EXCLUDED.output_tokens, + call_count = workspace_token_usage.call_count + EXCLUDED.call_count, + estimated_cost_usd = workspace_token_usage.estimated_cost_usd + EXCLUDED.estimated_cost_usd, + updated_at = NOW() + `, workspaceID, periodStart, inputTokens, outputTokens, cost) + if err != nil { + log.Printf("upsertTokenUsage: failed for workspace %s: %v", workspaceID, err) + } +} diff --git a/platform/internal/handlers/workspace_metrics_test.go 
b/platform/internal/handlers/workspace_metrics_test.go new file mode 100644 index 00000000..63e64d49 --- /dev/null +++ b/platform/internal/handlers/workspace_metrics_test.go @@ -0,0 +1,262 @@ +package handlers + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + + "github.com/DATA-DOG/go-sqlmock" + "github.com/gin-gonic/gin" +) + +// usageColumns matches the SELECT in GetMetrics. +var usageColumns = []string{ + "sum_input_tokens", "sum_output_tokens", "sum_call_count", "sum_cost", +} + +// expectWorkspaceExistsMetrics queues the EXISTS check in GetMetrics. +func expectWorkspaceExistsMetrics(mock sqlmock.Sqlmock, workspaceID string, exists bool) { + mock.ExpectQuery(`SELECT EXISTS`). + WithArgs(workspaceID). + WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(exists)) +} + +// TestGetMetrics_HappyPath verifies the handler returns correct aggregated data. +func TestGetMetrics_HappyPath(t *testing.T) { + mock := setupTestDB(t) + + expectWorkspaceExistsMetrics(mock, "ws-1", true) + + // Simulate one row with usage data. + mock.ExpectQuery(`SELECT\s+COALESCE\(SUM\(input_tokens\)`). + WithArgs("ws-1", sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows(usageColumns). 
+ AddRow(int64(1500), int64(300), int64(5), float64(0.009))) + + h := NewMetricsHandler() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-1"}} + c.Request = httptest.NewRequest("GET", "/workspaces/ws-1/metrics", nil) + + h.GetMetrics(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp struct { + InputTokens int64 `json:"input_tokens"` + OutputTokens int64 `json:"output_tokens"` + TotalCalls int64 `json:"total_calls"` + EstimatedCost string `json:"estimated_cost_usd"` + PeriodStart string `json:"period_start"` + PeriodEnd string `json:"period_end"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("invalid JSON: %v\n%s", err, w.Body.String()) + } + + if resp.InputTokens != 1500 { + t.Errorf("expected input_tokens=1500, got %d", resp.InputTokens) + } + if resp.OutputTokens != 300 { + t.Errorf("expected output_tokens=300, got %d", resp.OutputTokens) + } + if resp.TotalCalls != 5 { + t.Errorf("expected total_calls=5, got %d", resp.TotalCalls) + } + if resp.EstimatedCost == "" { + t.Error("expected non-empty estimated_cost_usd") + } + if resp.PeriodStart == "" { + t.Error("expected non-empty period_start") + } + if resp.PeriodEnd == "" { + t.Error("expected non-empty period_end") + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet DB expectations: %v", err) + } +} + +// TestGetMetrics_WorkspaceNotFound verifies a 404 when workspace is absent. 
+func TestGetMetrics_WorkspaceNotFound(t *testing.T) { + mock := setupTestDB(t) + expectWorkspaceExistsMetrics(mock, "ghost", false) + + h := NewMetricsHandler() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ghost"}} + c.Request = httptest.NewRequest("GET", "/workspaces/ghost/metrics", nil) + + h.GetMetrics(c) + + if w.Code != http.StatusNotFound { + t.Fatalf("expected 404, got %d: %s", w.Code, w.Body.String()) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet DB expectations: %v", err) + } +} + +// TestGetMetrics_EmptyPeriod verifies the handler returns zeros when no usage exists yet. +func TestGetMetrics_EmptyPeriod(t *testing.T) { + mock := setupTestDB(t) + expectWorkspaceExistsMetrics(mock, "ws-new", true) + + // COALESCE returns 0 for each column when no rows match. + mock.ExpectQuery(`SELECT\s+COALESCE\(SUM\(input_tokens\)`). + WithArgs("ws-new", sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows(usageColumns). + AddRow(int64(0), int64(0), int64(0), float64(0))) + + h := NewMetricsHandler() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-new"}} + c.Request = httptest.NewRequest("GET", "/workspaces/ws-new/metrics", nil) + + h.GetMetrics(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + // Verify period_start and period_end are present and distinct. 
+ ps, _ := resp["period_start"].(string) + pe, _ := resp["period_end"].(string) + if ps == "" || pe == "" { + t.Errorf("expected non-empty period_start/period_end, got %q / %q", ps, pe) + } + if ps == pe { + t.Errorf("period_start and period_end must differ, both are %q", ps) + } + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("unmet DB expectations: %v", err) + } +} + +// TestGetMetrics_CostFormat verifies estimated_cost_usd is formatted to 6 decimal places. +func TestGetMetrics_CostFormat(t *testing.T) { + mock := setupTestDB(t) + expectWorkspaceExistsMetrics(mock, "ws-1", true) + + mock.ExpectQuery(`SELECT\s+COALESCE\(SUM\(input_tokens\)`). + WithArgs("ws-1", sqlmock.AnyArg()). + WillReturnRows(sqlmock.NewRows(usageColumns). + AddRow(int64(1000000), int64(0), int64(1), float64(3.0))) + + h := NewMetricsHandler() + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + c.Params = gin.Params{{Key: "id", Value: "ws-1"}} + c.Request = httptest.NewRequest("GET", "/workspaces/ws-1/metrics", nil) + + h.GetMetrics(c) + + if w.Code != http.StatusOK { + t.Fatalf("expected 200, got %d: %s", w.Code, w.Body.String()) + } + var resp map[string]interface{} + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("invalid JSON: %v", err) + } + cost, _ := resp["estimated_cost_usd"].(string) + if len(cost) < 8 { + // "3.000000" is 8 chars minimum + t.Errorf("expected at least 8-char cost string, got %q", cost) + } +} + +// ---- parseUsageFromA2AResponse tests ---- + +func TestParseUsage_JSONRPCResultEnvelope(t *testing.T) { + body := []byte(`{ + "jsonrpc": "2.0", + "id": "abc", + "result": { + "usage": { + "input_tokens": 100, + "output_tokens": 50 + } + } + }`) + in, out := parseUsageFromA2AResponse(body) + if in != 100 { + t.Errorf("expected input_tokens=100, got %d", in) + } + if out != 50 { + t.Errorf("expected output_tokens=50, got %d", out) + } +} + +func TestParseUsage_TopLevelUsage(t *testing.T) { + body := []byte(`{ + "usage": 
{ + "input_tokens": 200, + "output_tokens": 75 + } + }`) + in, out := parseUsageFromA2AResponse(body) + if in != 200 { + t.Errorf("expected input_tokens=200, got %d", in) + } + if out != 75 { + t.Errorf("expected output_tokens=75, got %d", out) + } +} + +func TestParseUsage_NoUsageField(t *testing.T) { + body := []byte(`{"jsonrpc":"2.0","id":"x","result":{"message":"hello"}}`) + in, out := parseUsageFromA2AResponse(body) + if in != 0 || out != 0 { + t.Errorf("expected (0, 0) with no usage field, got (%d, %d)", in, out) + } +} + +func TestParseUsage_ZeroTokensIgnored(t *testing.T) { + body := []byte(`{"result":{"usage":{"input_tokens":0,"output_tokens":0}}}`) + in, out := parseUsageFromA2AResponse(body) + if in != 0 || out != 0 { + t.Errorf("expected (0, 0) for zero tokens, got (%d, %d)", in, out) + } +} + +func TestParseUsage_EmptyBody(t *testing.T) { + in, out := parseUsageFromA2AResponse([]byte{}) + if in != 0 || out != 0 { + t.Errorf("expected (0, 0) for empty body, got (%d, %d)", in, out) + } +} + +func TestParseUsage_InvalidJSON(t *testing.T) { + in, out := parseUsageFromA2AResponse([]byte("not json")) + if in != 0 || out != 0 { + t.Errorf("expected (0, 0) for invalid JSON, got (%d, %d)", in, out) + } +} + +func TestParseUsage_NestedResultPreferredOverTopLevel(t *testing.T) { + // result.usage should be preferred over top-level usage. 
+ body := []byte(`{ + "usage": {"input_tokens": 999, "output_tokens": 999}, + "result": { + "usage": {"input_tokens": 42, "output_tokens": 21} + } + }`) + in, out := parseUsageFromA2AResponse(body) + if in != 42 { + t.Errorf("expected result.usage.input_tokens=42, got %d", in) + } + if out != 21 { + t.Errorf("expected result.usage.output_tokens=21, got %d", out) + } +} diff --git a/platform/internal/router/router.go b/platform/internal/router/router.go index a2fbbf59..e126728f 100644 --- a/platform/internal/router/router.go +++ b/platform/internal/router/router.go @@ -279,6 +279,11 @@ func Setup(hub *ws.Hub, broadcaster *events.Broadcaster, prov *provisioner.Provi wsAuth.PUT("/secrets", sech.Set) wsAuth.DELETE("/secrets/:key", sech.Delete) wsAuth.GET("/model", sech.GetModel) + + // Token usage metrics — cost transparency (#593). + // WorkspaceAuth middleware (on wsAuth) binds the bearer to :id. + mtrh := handlers.NewMetricsHandler() + wsAuth.GET("/metrics", mtrh.GetMetrics) } // Global secrets — /settings/secrets is the canonical path; /admin/secrets kept for backward compat. diff --git a/platform/migrations/026_workspace_token_usage.down.sql b/platform/migrations/026_workspace_token_usage.down.sql new file mode 100644 index 00000000..91a963d3 --- /dev/null +++ b/platform/migrations/026_workspace_token_usage.down.sql @@ -0,0 +1 @@ +DROP TABLE IF EXISTS workspace_token_usage; diff --git a/platform/migrations/026_workspace_token_usage.up.sql b/platform/migrations/026_workspace_token_usage.up.sql new file mode 100644 index 00000000..acec2090 --- /dev/null +++ b/platform/migrations/026_workspace_token_usage.up.sql @@ -0,0 +1,17 @@ +-- Per-workspace LLM token usage tracking (#593 — canvas cost transparency). +-- Stores UTC-day aggregates upserted by the A2A proxy after each LLM call. +-- estimated_cost_usd is computed server-side using fixed per-model rates +-- (default: Claude Sonnet input $3/1M, output $15/1M). 
+CREATE TABLE IF NOT EXISTS workspace_token_usage ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + workspace_id TEXT NOT NULL REFERENCES workspaces(id) ON DELETE CASCADE, + period_start TIMESTAMPTZ NOT NULL, + input_tokens BIGINT NOT NULL DEFAULT 0, + output_tokens BIGINT NOT NULL DEFAULT 0, + call_count INTEGER NOT NULL DEFAULT 0, + estimated_cost_usd NUMERIC(12,6) NOT NULL DEFAULT 0, + updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW() +); + +CREATE UNIQUE INDEX IF NOT EXISTS workspace_token_usage_ws_period + ON workspace_token_usage(workspace_id, period_start);