From df972a85e22aa117a3d256ab738d8af4a1c092b5 Mon Sep 17 00:00:00 2001 From: "Hongming (CTO)" Date: Sun, 31 May 2026 12:38:13 -0700 Subject: [PATCH] fix(workspace-server): central codex OAuth refresher (single-owner, anti-burn) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Multiple codex workspaces share ONE ChatGPT-Pro OAuth token (global_secrets key CODEX_AUTH_JSON). OpenAI's refresh_token is single-use, so letting each per-agent codex app-server refresh on its own 401 burned the shared seed within seconds (a refresh storm → token_invalidated + "refresh token already used"). This adds a single platform-side owner of the refresh: - internal/codexauth/refresher.go: one background goroutine, structurally single-flight (one goroutine + package mutex). Reads the global CODEX_AUTH_JSON, decodes the access_token JWT exp, and only within a safety margin of expiry POSTs the refresh_token ONCE per due cycle, then re-encrypts and writes the rotated blob back to global_secrets. Inert when the secret is absent; on a permanent failure (invalid_grant / "already used") it logs once and does NOT hot-loop. Billing-mode resolution + byok are untouched. - cmd/server/main.go: wired under supervised.RunWithRecover like the other background sweeps. Pairs with the codex template's codex_auth_sync.sh (GET-only re-sync; per-agent OAuth POST disabled) so workspaces only consume the current token and never rotate it themselves. 
Co-Authored-By: Claude Opus 4.8 (1M context) --- workspace-server/cmd/server/main.go | 15 + .../internal/codexauth/refresher.go | 463 ++++++++++++++++++ .../internal/codexauth/refresher_test.go | 425 ++++++++++++++++ 3 files changed, 903 insertions(+) create mode 100644 workspace-server/internal/codexauth/refresher.go create mode 100644 workspace-server/internal/codexauth/refresher_test.go diff --git a/workspace-server/cmd/server/main.go b/workspace-server/cmd/server/main.go index 04dff6d57..b79a61d6b 100644 --- a/workspace-server/cmd/server/main.go +++ b/workspace-server/cmd/server/main.go @@ -36,6 +36,7 @@ import ( "time" "git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/channels" + "git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/codexauth" "git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/crypto" "git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/db" "git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/events" @@ -334,6 +335,20 @@ func main() { pendinguploads.StartSweeper(c, pendinguploads.NewPostgres(db.DB), 0) }) + // Codex shared-OAuth central refresher — the SINGLE owner of the rotating + // refresh_token for the global codex (ChatGPT/Codex subscription) credential + // (global_secrets key CODEX_AUTH_JSON). Multiple codex workspaces share ONE + // ChatGPT-Pro OAuth token; OpenAI's refresh_token is single-use, so letting + // each per-agent app-server refresh on its own 401 burned the seed within + // seconds (a refresh storm). This goroutine is structurally single-flight + // (one goroutine + a package mutex), refreshes only within a safety margin + // of expiry, POSTs the refresh_token at most once per due cycle, and writes + // the rotated blob back — workspaces now only GET the current token (see the + // codex template's codex_auth_sync.sh). INERT when no CODEX_AUTH_JSON exists. 
+ go supervised.RunWithRecover(ctx, "codex-auth-refresher", func(c context.Context) { + codexauth.StartCodexAuthRefresher(c, db.DB) + }) + // Provision-timeout sweep — flips workspaces that have been stuck in // status='provisioning' past the timeout window to 'failed' and emits // WORKSPACE_PROVISION_TIMEOUT. Without this the UI banner is cosmetic diff --git a/workspace-server/internal/codexauth/refresher.go b/workspace-server/internal/codexauth/refresher.go new file mode 100644 index 000000000..c81e55253 --- /dev/null +++ b/workspace-server/internal/codexauth/refresher.go @@ -0,0 +1,463 @@ +// Package codexauth owns the SINGLE, platform-side refresh of the global +// codex (ChatGPT/Codex subscription) OAuth credential stored in the +// global_secrets table under key CODEX_AUTH_JSON. +// +// THE PROBLEM IT FIXES (agents-team prod, 2026-05-31) +// +// Multiple codex workspaces share ONE ChatGPT-Pro OAuth token (the global +// secret CODEX_AUTH_JSON). OpenAI's refresh_token is SINGLE-USE: every refresh +// rotates it and invalidates the prior one. When each per-agent codex +// app-server refreshed independently on a 401, the siblings' in-flight tokens +// were invalidated within seconds — a refresh storm that burned the seed and +// wedged every codex agent. +// +// THE FIX (two halves; this is the core half) +// +// 1. The per-workspace codex app-server NO LONGER refreshes (the template's +// OAuth POST is gated off by default — see the codex template's +// codex_auth_sync.sh / CODEX_AUTH_REFRESH_OWNER gate). Workspaces only ever +// GET the current token and write it to auth.json. +// 2. ONE owner refreshes the rotating refresh_token: this background goroutine +// in the platform. It is structurally single-flight (one goroutine + a +// package mutex), refreshes ONLY when the access_token is within a safety +// margin of expiry, POSTs the refresh_token at most ONCE per due cycle, and +// writes the rotated blob back to global_secrets. 
On a permanent failure +// (the seed was already burned by an out-of-band login) it logs ONCE and +// backs off — it never hot-loops a dead refresh_token. +// +// Billing-mode resolution and the byok strip are UNTOUCHED by this package. +package codexauth + +import ( + "context" + "database/sql" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "log" + "net/http" + "strings" + "sync" + "time" + + "git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/crypto" +) + +const ( + // CodexAuthSecretKey is the global_secrets key holding the shared codex + // ChatGPT/Codex subscription OAuth blob (auth.json contents). + CodexAuthSecretKey = "CODEX_AUTH_JSON" + + // oauthTokenURL is OpenAI's OAuth token endpoint. The ONLY endpoint this + // package ever POSTs to, and only for a due refresh. + oauthTokenURL = "https://auth.openai.com/oauth/token" + + // codexOAuthClientID is the public Codex CLI OAuth client id (the same id + // the codex CLI sends). Not a secret. + codexOAuthClientID = "app_EMoamEEZ73f0CkXaXp7hrann" + + // refreshSafetyMargin is how far ahead of access_token expiry a refresh is + // considered DUE. A token expiring within this window is refreshed now; one + // expiring later is left untouched (skip-when-fresh). Generous so a slow + // tick can never let the shared token lapse for the fleet. + refreshSafetyMargin = 15 * time.Minute + + // defaultInterval is how often the loop wakes to check due-ness. The check + // is cheap (decrypt + JWT exp parse) and only POSTs when actually due. + defaultInterval = 5 * time.Minute + + // permanentFailureBackoff is how long the loop waits after a PERMANENT + // refresh failure (invalid_grant / "refresh token already used"). The seed + // is burned until a human re-seeds a fresh login; there is nothing to retry, + // so we back off hard rather than hammer the dead token. + permanentFailureBackoff = 1 * time.Hour +) + +// SecretStore is the minimal global_secrets surface the refresher needs. 
The +// production implementation (postgresStore) is backed by *sql.DB; tests inject +// a fake. It is deliberately tiny — read one key, write one key — so the test +// double is trivial and the refresher never reaches for the package-global DB. +type SecretStore interface { + // Get returns the decrypted secret value and true, or ("", false) when the + // key is absent. A non-nil error is a real read failure (not absence). + Get(ctx context.Context, key string) (value string, found bool, err error) + // Put encrypts and upserts value under key, bumping the row's updated_at + // (the "last_refresh" timestamp). It is the rotated-blob write-back. + Put(ctx context.Context, key, value string) error +} + +// httpDoer is the http client seam (real *http.Client in prod, fake transport +// in tests). Tests NEVER hit the network. +type httpDoer interface { + Do(req *http.Request) (*http.Response, error) +} + +// refresher is the single-owner refresh engine. The package-level mutex makes +// the refresh structurally single-flight: even if two refreshOnce calls raced +// (they cannot in prod — one goroutine drives it — but a test or a future +// caller might), only one POSTs at a time, and the access-token freshness +// re-check inside the lock means the second sees a freshly-rotated token and +// skips. One goroutine + this mutex = single-flight by construction. +type refresher struct { + store SecretStore + client httpDoer + now func() time.Time + + // permanentlyFailed records that the current seed's refresh_token was + // rejected as already-used/invalid. While set, refreshOnce is INERT (it + // will not re-POST the dead token) until the secret value CHANGES (a human + // re-seed), detected by comparing the stored blob. This is the anti-storm + // latch — it lives on the struct, not globally, so it resets if the seed is + // replaced out of band. 
+ failedSeed string // the auth-json blob that failed; "" = no known failure +} + +// mu serializes refreshOnce across the process. Package-level so the +// single-flight guarantee holds regardless of how many refresher values exist +// (in prod there is exactly one). +// +// NOTE(review): a sync.Mutex only serializes callers inside THIS process. If +// more than one workspace-server process can run against the same +// global_secrets table, each would start its own refresher and they could both +// POST the same refresh_token — confirm the deployment is single-instance or +// add a database-level lock. +var mu sync.Mutex + +// oauthTokens is the token trio inside auth.json (and the OAuth response). +// id_token is optional: it is omitted from the encoded form when empty. +type oauthTokens struct { + AccessToken string `json:"access_token"` + RefreshToken string `json:"refresh_token"` + IDToken string `json:"id_token,omitempty"` +} + +// StartCodexAuthRefresher builds the production refresher (a postgresStore over +// db, an http.Client with a 30s timeout, time.Now as the clock) and runs its +// tick loop on the CALLING goroutine. It does not spawn a goroutine itself and +// it BLOCKS until ctx is cancelled, so call it from a dedicated goroutine — +// main.go wraps it in `go supervised.RunWithRecover` like the other Start* +// sweeps. +// +// db must be non-nil: postgresStore calls db.QueryRowContext on the very first +// cycle. Tests drive refreshOnce directly with a fake SecretStore and never +// call this function. While CODEX_AUTH_JSON is absent every cycle is a cheap +// read that logs the inert notice and returns — no POST, no write-back. +func StartCodexAuthRefresher(ctx context.Context, db *sql.DB) { + r := &refresher{ + store: &postgresStore{db: db}, + client: &http.Client{Timeout: 30 * time.Second}, + now: time.Now, + } + r.run(ctx, defaultInterval) +} + +// run is the tick loop. It calls refreshOnce immediately, then again after each +// wait: interval normally, permanentFailureBackoff after a cycle that reported +// a permanent failure (never a tight retry of a burned token). It returns only +// when ctx is cancelled. +func (r *refresher) run(ctx context.Context, interval time.Duration) { + // Check once promptly on boot, then on the interval. + for { + wait := interval + if perm := r.refreshOnce(ctx); perm { + // Permanent failure this cycle — the seed is burned. Back off hard; + // a human must re-seed. We keep ticking (a re-seed CHANGES the blob, + // which clears the latch) but slowly. 
+ wait = permanentFailureBackoff + } + + // A fresh timer per iteration; stopped explicitly on the cancel path. + timer := time.NewTimer(wait) + select { + case <-ctx.Done(): + timer.Stop() + log.Printf("codexauth: context done; stopping refresher") + return + case <-timer.C: + } + } +} + +// refreshOnce performs ONE due-check + at most one refresh POST. It returns +// permanentFailure=true iff the refresh_token was permanently rejected this +// cycle (the caller backs off). Every other outcome returns false: read error, +// inert (no secret), latched (this exact blob already failed), unparseable blob, +// missing refresh_token, fresh-skip, rotated, transient error, and a failed +// merge or write-back. Because the latched path returns false, the long backoff +// applies only to the cycle that first saw the rejection; later latched cycles +// run on the normal interval and cost one read. +// +// It is single-flight: the package mutex is held for the whole read→decide→ +// POST→write-back so two callers cannot both POST the (single-use) refresh_token. +func (r *refresher) refreshOnce(ctx context.Context) (permanentFailure bool) { + mu.Lock() + defer mu.Unlock() + + blob, found, err := r.store.Get(ctx, CodexAuthSecretKey) + if err != nil { + log.Printf("codexauth: read CODEX_AUTH_JSON failed: %v (skipping this cycle)", err) + return false + } + if !found || strings.TrimSpace(blob) == "" { + // INERT: no shared codex seed in this deployment. Cheap no-op. Note the + // line below is logged on EVERY cycle while the secret is absent or + // blank, not once. + log.Printf("codexauth: no CODEX_AUTH_JSON in global_secrets — refresher inert") + // A previously-failed seed that has since been DELETED clears the latch. + r.failedSeed = "" + return false + } + + // Anti-storm latch: if THIS exact blob already failed permanently, do not + // re-POST its dead refresh_token. A re-seed changes the blob and clears it. + // The comparison is on the whole decrypted blob string, so any byte change + // to the stored value counts as a re-seed. + if r.failedSeed != "" && r.failedSeed == blob { + return false + } + if r.failedSeed != "" && r.failedSeed != blob { + // The seed changed out of band (human re-login) — give it a fresh chance. 
+ r.failedSeed = "" + } + + tokens, err := parseTokens(blob) + if err != nil { + log.Printf("codexauth: CODEX_AUTH_JSON is not parseable codex auth json: %v (skipping)", err) + return false + } + if tokens.RefreshToken == "" { + log.Printf("codexauth: CODEX_AUTH_JSON carries no refresh_token (skipping)") + return false + } + + // Skip-when-fresh: only refresh within the safety margin of expiry. A blob + // with an unparseable/absent access_token exp is treated as DUE (better to + // refresh a token we cannot date than let the fleet lapse). + exp, haveExp := jwtExp(tokens.AccessToken) + if haveExp { + remaining := exp.Sub(r.now()) + if remaining > refreshSafetyMargin { + // Fresh — nothing to do. No POST. + return false + } + } + + // DUE: POST the refresh_token ONCE. + newTokens, perm, err := r.doRefresh(ctx, tokens.RefreshToken) + if err != nil { + if perm { + // Permanent: the seed is burned. Latch it so we don't re-POST, log + // ONCE, and DO NOT write anything back. + log.Printf("codexauth: PERMANENT refresh failure (refresh_token rejected): %v — "+ + "NOT writing back; the shared CODEX_AUTH_JSON seed is burned and must be re-seeded "+ + "via a fresh codex login. Backing off.", err) + r.failedSeed = blob + return true + } + // Transient (network/5xx): no write-back, retry next cycle (no backoff). + log.Printf("codexauth: transient refresh error: %v (will retry next cycle)", err) + return false + } + + // Success: merge the rotated trio into the blob (preserving every other + // field) and write it back encrypted, bumping updated_at (last_refresh). 
+ rotated, err := mergeTokens(blob, newTokens) + if err != nil { + // NOTE(review): doRefresh has already succeeded here, so by this package's + // own premise (the refresh_token is single-use) the refresh_token still in + // the store is now spent. Returning without persisting newTokens presumably + // leaves the next cycle to POST a dead token and latch as a permanent + // failure — consider retaining newTokens for a retry. + log.Printf("codexauth: failed to merge rotated tokens into auth json: %v (NOT writing back)", err) + return false + } + if err := r.store.Put(ctx, CodexAuthSecretKey, rotated); err != nil { + // NOTE(review): same hazard as the merge failure — the rotated blob exists + // only in this stack frame and is dropped when the write fails; the Put is + // never retried. Confirm whether a transient DB error here is allowed to + // burn the shared seed. + log.Printf("codexauth: write-back of rotated CODEX_AUTH_JSON failed: %v", err) + return false + } + r.failedSeed = "" // success clears any stale latch + log.Printf("codexauth: rotated shared CODEX_AUTH_JSON (single-owner refresh)") + return false +} + +// doRefresh POSTs the refresh_token grant (JSON body: grant_type, client_id, +// refresh_token) to oauthTokenURL exactly once and returns the token trio from +// a 200 response. A 200 whose body does not decode, or that lacks access_token, +// is an error with permanent=false. permanent=true marks an unrecoverable +// rejection — HTTP 400 whose body names invalid_grant / "refresh token already +// used", or any 401/403 — so the caller latches and backs off instead of +// retrying. Transport errors and every other status return permanent=false. +func (r *refresher) doRefresh(ctx context.Context, refreshToken string) (tokens oauthTokens, permanent bool, err error) { + // The Marshal error is discarded; the value is a plain map[string]string. + body, _ := json.Marshal(map[string]string{ + "grant_type": "refresh_token", + "client_id": codexOAuthClientID, + "refresh_token": refreshToken, + }) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, oauthTokenURL, strings.NewReader(string(body))) + if err != nil { + return oauthTokens{}, false, err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Accept", "application/json") + + resp, err := r.client.Do(req) + if err != nil { + return oauthTokens{}, false, err // transient: network + } + defer resp.Body.Close() + // Read at most 1 MiB of the body. The read error is discarded, so a short + // read surfaces below as a decode failure (200) or a truncated error + // message (non-200). + respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + + if resp.StatusCode == http.StatusOK { + var t oauthTokens + if err := json.Unmarshal(respBody, &t); err != nil { + return oauthTokens{}, false, fmt.Errorf("decode token response: %w", err) + } + if t.AccessToken == "" { + return oauthTokens{}, false, fmt.Errorf("token response missing access_token") + } + return t, false, nil + } + + // Non-200. Other 400s fall through to the transient default, but a 400 whose + // body names invalid_grant / already-used is a + // PERMANENT rejection of the refresh_token. 
401/403 likewise mean the seed + // is no good. Everything else (429/5xx/network-shaped) is transient. + lowerBody := strings.ToLower(string(respBody)) + isInvalidGrant := strings.Contains(lowerBody, "invalid_grant") || + strings.Contains(lowerBody, "refresh token already used") || + strings.Contains(lowerBody, "already been used") || + strings.Contains(lowerBody, "token has been revoked") + switch { + case resp.StatusCode == http.StatusBadRequest && isInvalidGrant: + return oauthTokens{}, true, fmt.Errorf("oauth %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))) + case resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden: + return oauthTokens{}, true, fmt.Errorf("oauth %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))) + default: + return oauthTokens{}, false, fmt.Errorf("oauth %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))) + } +} + +// parseTokens extracts the OAuth trio from an auth.json blob, accepting both +// the nested `{"tokens":{...}}` shape the codex CLI writes and a flat top-level +// shape some seeds use. +func parseTokens(blob string) (oauthTokens, error) { + var top map[string]json.RawMessage + if err := json.Unmarshal([]byte(blob), &top); err != nil { + return oauthTokens{}, err + } + if nested, ok := top["tokens"]; ok { + var t oauthTokens + if err := json.Unmarshal(nested, &t); err != nil { + return oauthTokens{}, fmt.Errorf("decode nested tokens: %w", err) + } + return t, nil + } + var t oauthTokens + if err := json.Unmarshal([]byte(blob), &t); err != nil { + return oauthTokens{}, err + } + return t, nil +} + +// mergeTokens writes the rotated trio back into the original blob in-place, +// preserving the blob's shape (nested-vs-flat) and every other field. A field +// in the OAuth response that is empty (e.g. id_token omitted) does NOT clobber +// the existing value. 
+func mergeTokens(blob string, rotated oauthTokens) (string, error) { + // Decode only one level so every non-token field is carried through as raw + // JSON. The result is re-serialized with json.Marshal, so field values are + // preserved but the original key order and whitespace are not. + var top map[string]json.RawMessage + if err := json.Unmarshal([]byte(blob), &top); err != nil { + return "", err + } + + // applyTo overwrites the three token keys in m, skipping any rotated value + // that is empty. + applyTo := func(m map[string]json.RawMessage) error { + setStr := func(key, val string) error { + if val == "" { + return nil // don't clobber an existing value with an empty one + } + b, err := json.Marshal(val) + if err != nil { + return err + } + m[key] = b + return nil + } + if err := setStr("access_token", rotated.AccessToken); err != nil { + return err + } + if err := setStr("refresh_token", rotated.RefreshToken); err != nil { + return err + } + if err := setStr("id_token", rotated.IDToken); err != nil { + return err + } + return nil + } + + if nestedRaw, ok := top["tokens"]; ok { + // Nested shape: patch the inner object and re-attach it under "tokens". + // NOTE(review): a JSON null under "tokens" unmarshals into a nil map with + // no error, and the m[key] = b write in applyTo would then panic. The + // refreshOnce path looks safe today because parseTokens yields an empty + // refresh_token for such a blob and the cycle skips first — confirm, or + // guard nested == nil here. + var nested map[string]json.RawMessage + if err := json.Unmarshal(nestedRaw, &nested); err != nil { + return "", fmt.Errorf("decode nested tokens for merge: %w", err) + } + if err := applyTo(nested); err != nil { + return "", err + } + nb, err := json.Marshal(nested) + if err != nil { + return "", err + } + top["tokens"] = nb + } else { + // Flat shape: the token keys live at the top level. + // NOTE(review): a blob that is the literal JSON null leaves top nil, so + // this branch has the same nil-map write hazard as the nested one. + if err := applyTo(top); err != nil { + return "", err + } + } + + out, err := json.Marshal(top) + if err != nil { + return "", err + } + return string(out), nil +} + +// jwtExp decodes the `exp` claim (Unix seconds) from a JWT access token WITHOUT +// verifying the signature (we only need the expiry to decide due-ness; the +// token's validity is OpenAI's to enforce). Returns ok=false when the token is +// not a 3-part JWT, its payload is not base64url JSON, or exp is absent or not +// a positive integer. The claim is read with json.Number.Int64, so a fractional +// or exponent-form exp is also reported as ok=false. +func jwtExp(token string) (time.Time, bool) { + parts := strings.Split(token, ".") + if len(parts) != 3 { + return time.Time{}, false + } + payload, err := base64.RawURLEncoding.DecodeString(parts[1]) + if err != nil { + // Some encoders pad; tolerate standard base64url with padding too. 
+ payload, err = base64.URLEncoding.DecodeString(parts[1]) + if err != nil { + return time.Time{}, false + } + } + var claims struct { + Exp json.Number `json:"exp"` + } + if err := json.Unmarshal(payload, &claims); err != nil { + return time.Time{}, false + } + secs, err := claims.Exp.Int64() + if err != nil || secs <= 0 { + return time.Time{}, false + } + return time.Unix(secs, 0), true +} + +// postgresStore is the production SecretStore backed by global_secrets, using +// the SAME crypto path the secrets handler uses (DecryptVersioned on read, +// Encrypt + CurrentEncryptionVersion on write). +type postgresStore struct { + db *sql.DB +} + +func (s *postgresStore) Get(ctx context.Context, key string) (string, bool, error) { + var enc []byte + var ver int + err := s.db.QueryRowContext(ctx, + `SELECT encrypted_value, encryption_version FROM global_secrets WHERE key = $1`, key). + Scan(&enc, &ver) + if err == sql.ErrNoRows { + return "", false, nil + } + if err != nil { + return "", false, err + } + plain, err := crypto.DecryptVersioned(enc, ver) + if err != nil { + return "", false, err + } + return string(plain), true, nil +} + +func (s *postgresStore) Put(ctx context.Context, key, value string) error { + enc, err := crypto.Encrypt([]byte(value)) + if err != nil { + return err + } + ver := crypto.CurrentEncryptionVersion() + _, err = s.db.ExecContext(ctx, ` + INSERT INTO global_secrets (key, encrypted_value, encryption_version) + VALUES ($1, $2, $3) + ON CONFLICT (key) DO UPDATE + SET encrypted_value = $2, encryption_version = $3, updated_at = now() + `, key, enc, ver) + return err +} diff --git a/workspace-server/internal/codexauth/refresher_test.go b/workspace-server/internal/codexauth/refresher_test.go new file mode 100644 index 000000000..312ceae34 --- /dev/null +++ b/workspace-server/internal/codexauth/refresher_test.go @@ -0,0 +1,425 @@ +package codexauth + +import ( + "context" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + 
"sync" + "sync/atomic" + "testing" + "time" +) + +// --- test doubles ----------------------------------------------------------- + +// fakeStore is an in-memory SecretStore. nil entry = absent key. +type fakeStore struct { + mu sync.Mutex + values map[string]string + getErr error + putErr error + puts int32 // count of successful Put calls +} + +func newFakeStore() *fakeStore { return &fakeStore{values: map[string]string{}} } + +func (f *fakeStore) Get(_ context.Context, key string) (string, bool, error) { + f.mu.Lock() + defer f.mu.Unlock() + if f.getErr != nil { + return "", false, f.getErr + } + v, ok := f.values[key] + return v, ok, nil +} + +func (f *fakeStore) Put(_ context.Context, key, value string) error { + f.mu.Lock() + defer f.mu.Unlock() + if f.putErr != nil { + return f.putErr + } + f.values[key] = value + atomic.AddInt32(&f.puts, 1) + return nil +} + +func (f *fakeStore) get(key string) string { + f.mu.Lock() + defer f.mu.Unlock() + return f.values[key] +} + +// fakeTransport records every request and returns a scripted response. It is +// the network seam — tests NEVER make a real request. 
+type fakeTransport struct { + mu sync.Mutex + calls int32 + urls []string + methods []string + bodies []string + status int + respBody string + transport func(*http.Request) (*http.Response, error) // optional override +} + +func (t *fakeTransport) Do(req *http.Request) (*http.Response, error) { + atomic.AddInt32(&t.calls, 1) + t.mu.Lock() + t.urls = append(t.urls, req.URL.String()) + t.methods = append(t.methods, req.Method) + if req.Body != nil { + b, _ := io.ReadAll(req.Body) + t.bodies = append(t.bodies, string(b)) + } else { + t.bodies = append(t.bodies, "") + } + t.mu.Unlock() + + if t.transport != nil { + return t.transport(req) + } + status := t.status + if status == 0 { + status = http.StatusOK + } + return &http.Response{ + StatusCode: status, + Body: io.NopCloser(strings.NewReader(t.respBody)), + Header: make(http.Header), + }, nil +} + +func (t *fakeTransport) callCount() int { return int(atomic.LoadInt32(&t.calls)) } + +// --- helpers ---------------------------------------------------------------- + +// makeJWT builds an unsigned-but-parseable JWT whose payload carries exp. +func makeJWT(exp time.Time) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`)) + payload := base64.RawURLEncoding.EncodeToString([]byte( + fmt.Sprintf(`{"exp":%d,"sub":"codex"}`, exp.Unix()))) + sig := base64.RawURLEncoding.EncodeToString([]byte("sig")) + return header + "." + payload + "." + sig +} + +// authBlob builds a nested codex auth.json blob with the given tokens. 
+func authBlob(access, refresh string) string { + b, _ := json.Marshal(map[string]any{ + "tokens": map[string]any{ + "access_token": access, + "refresh_token": refresh, + "id_token": "id-original", + }, + "OPENAI_API_KEY": nil, + "last_refresh": "2026-01-01T00:00:00Z", + }) + return string(b) +} + +func newTestRefresher(store SecretStore, client httpDoer, now time.Time) *refresher { + return &refresher{ + store: store, + client: client, + now: func() time.Time { return now }, + } +} + +func okRefreshResponse(access, refresh string) string { + b, _ := json.Marshal(oauthTokens{AccessToken: access, RefreshToken: refresh, IDToken: "id-new"}) + return string(b) +} + +// --- tests ------------------------------------------------------------------ + +// TestJWTExpParse covers the exp decode (valid, malformed, missing). +func TestJWTExpParse(t *testing.T) { + want := time.Now().Add(2 * time.Hour).Truncate(time.Second) + got, ok := jwtExp(makeJWT(want)) + if !ok { + t.Fatalf("jwtExp(valid) ok=false, want true") + } + if !got.Equal(want) { + t.Errorf("jwtExp = %v, want %v", got, want) + } + + if _, ok := jwtExp("not-a-jwt"); ok { + t.Errorf("jwtExp(non-jwt) ok=true, want false") + } + if _, ok := jwtExp("a.b.c"); ok { + t.Errorf("jwtExp(garbage parts) ok=true, want false") + } + // 3 parts but payload has no exp. + noExp := base64.RawURLEncoding.EncodeToString([]byte("{}")) + if _, ok := jwtExp("h." + noExp + ".s"); ok { + t.Errorf("jwtExp(no exp claim) ok=true, want false") + } +} + +// TestRefreshOnce_SkipWhenFresh: a token well outside the safety margin is NOT +// refreshed — no POST, no write-back. 
+func TestRefreshOnce_SkipWhenFresh(t *testing.T) { + now := time.Now() + store := newFakeStore() + store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(2*time.Hour)), "rt-1") + tr := &fakeTransport{status: http.StatusOK, respBody: okRefreshResponse("new-at", "rt-2")} + r := newTestRefresher(store, tr, now) + + if perm := r.refreshOnce(context.Background()); perm { + t.Fatalf("fresh token: permanentFailure=true, want false") + } + if tr.callCount() != 0 { + t.Errorf("fresh token: %d OAuth POSTs, want 0", tr.callCount()) + } + if atomic.LoadInt32(&store.puts) != 0 { + t.Errorf("fresh token: %d write-backs, want 0", store.puts) + } +} + +// TestRefreshOnce_RotateThenReskip: a token inside the margin is refreshed once +// (POST + write-back of the rotated blob); a subsequent call on the now-fresh +// rotated token skips (no second POST). Proves rotate→write-back→re-skip. +func TestRefreshOnce_RotateThenReskip(t *testing.T) { + now := time.Now() + store := newFakeStore() + // Expires in 5m — inside the 15m safety margin → DUE. + store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(5*time.Minute)), "rt-1") + // Rotated access token is fresh (2h out); rotated refresh is rt-2. + tr := &fakeTransport{status: http.StatusOK, respBody: okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-2")} + r := newTestRefresher(store, tr, now) + + if perm := r.refreshOnce(context.Background()); perm { + t.Fatalf("due token: permanentFailure=true, want false") + } + if tr.callCount() != 1 { + t.Fatalf("due token: %d OAuth POSTs, want exactly 1", tr.callCount()) + } + if atomic.LoadInt32(&store.puts) != 1 { + t.Fatalf("due token: %d write-backs, want exactly 1", store.puts) + } + + // The written blob must carry the rotated refresh_token and preserve the + // non-token field. 
+ rotated := store.get(CodexAuthSecretKey) + tokens, err := parseTokens(rotated) + if err != nil { + t.Fatalf("parse rotated blob: %v", err) + } + if tokens.RefreshToken != "rt-2" { + t.Errorf("rotated refresh_token = %q, want rt-2", tokens.RefreshToken) + } + if !strings.Contains(rotated, "last_refresh") { + t.Errorf("rotated blob dropped the preserved last_refresh field: %s", rotated) + } + + // Second call: the rotated access token is fresh → skip, no new POST. + if perm := r.refreshOnce(context.Background()); perm { + t.Fatalf("re-skip: permanentFailure=true, want false") + } + if tr.callCount() != 1 { + t.Errorf("re-skip: %d total OAuth POSTs, want still 1", tr.callCount()) + } + if atomic.LoadInt32(&store.puts) != 1 { + t.Errorf("re-skip: %d total write-backs, want still 1", store.puts) + } +} + +// TestRefreshOnce_NoSecretInert: absent CODEX_AUTH_JSON → inert (no POST, no +// write-back, no error/permanent). +func TestRefreshOnce_NoSecretInert(t *testing.T) { + store := newFakeStore() // empty + tr := &fakeTransport{} + r := newTestRefresher(store, tr, time.Now()) + + if perm := r.refreshOnce(context.Background()); perm { + t.Fatalf("no secret: permanentFailure=true, want false") + } + if tr.callCount() != 0 { + t.Errorf("no secret: %d POSTs, want 0", tr.callCount()) + } + if atomic.LoadInt32(&store.puts) != 0 { + t.Errorf("no secret: %d write-backs, want 0", store.puts) + } +} + +// TestRefreshOnce_PermanentFailNoWriteNoStorm: a 400 invalid_grant must (a) not +// write back, (b) return permanentFailure=true, and (c) NOT re-POST on the next +// cycle for the same (burned) seed — the anti-storm latch. 
+func TestRefreshOnce_PermanentFailNoWriteNoStorm(t *testing.T) { + now := time.Now() + store := newFakeStore() + store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-burned") + tr := &fakeTransport{ + status: http.StatusBadRequest, + respBody: `{"error":"invalid_grant","error_description":"refresh token already used"}`, + } + r := newTestRefresher(store, tr, now) + + perm := r.refreshOnce(context.Background()) + if !perm { + t.Fatalf("invalid_grant: permanentFailure=false, want true") + } + if tr.callCount() != 1 { + t.Fatalf("invalid_grant: %d POSTs, want exactly 1", tr.callCount()) + } + if atomic.LoadInt32(&store.puts) != 0 { + t.Fatalf("invalid_grant: %d write-backs, want 0 (must NOT persist a failed refresh)", store.puts) + } + + // Next cycle, SAME burned seed: must NOT re-POST (anti-storm latch). + perm2 := r.refreshOnce(context.Background()) + if tr.callCount() != 1 { + t.Errorf("anti-storm: re-POSTed a burned refresh_token (%d total POSTs, want still 1)", tr.callCount()) + } + _ = perm2 // latched cycle returns false (already-known failure, nothing new) + + // A RE-SEED (blob changes) clears the latch and allows a fresh attempt. + store.mu.Lock() + store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-freshly-seeded") + store.mu.Unlock() + tr.status = http.StatusOK + tr.respBody = okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-rotated") + if perm := r.refreshOnce(context.Background()); perm { + t.Fatalf("post-reseed: permanentFailure=true, want false") + } + if tr.callCount() != 2 { + t.Errorf("post-reseed: %d total POSTs, want 2 (latch should clear on re-seed)", tr.callCount()) + } +} + +// TestRefreshOnce_TransientNoWriteNoLatch: a 5xx is transient — no write-back, +// returns false (no hard backoff latch), and a later cycle retries. 
+func TestRefreshOnce_TransientNoWriteNoLatch(t *testing.T) {
+	ctx := context.Background()
+	now := time.Now()
+
+	// Seed a token expiring one minute after now; the fake transport first
+	// answers with 503.
+	store := newFakeStore()
+	store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-1")
+	tr := &fakeTransport{status: http.StatusServiceUnavailable, respBody: "upstream down"}
+	r := newTestRefresher(store, tr, now)
+
+	if perm := r.refreshOnce(ctx); perm {
+		t.Fatalf("503: permanentFailure=true, want false (transient)")
+	}
+	// Read the counter once via sync/atomic and print that same value rather
+	// than re-reading the field non-atomically in the message.
+	if puts := atomic.LoadInt32(&store.puts); puts != 0 {
+		t.Errorf("503: %d write-backs, want 0", puts)
+	}
+
+	// Retry next cycle succeeds (no latch on transient).
+	tr.status = http.StatusOK
+	tr.respBody = okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-2")
+	if perm := r.refreshOnce(ctx); perm {
+		t.Fatalf("retry after 503: permanentFailure=true, want false")
+	}
+	if posts := tr.callCount(); posts != 2 {
+		t.Errorf("transient retry: %d total POSTs, want 2", posts)
+	}
+	if puts := atomic.LoadInt32(&store.puts); puts != 1 {
+		t.Errorf("transient retry: %d write-backs, want 1", puts)
+	}
+}
+
+// TestRefreshOnce_SingleFlight: concurrent refreshOnce calls on a DUE token must
+// POST exactly once total — the package mutex serializes them and the second
+// sees the freshly-rotated (now-fresh) token and skips. Structural single-flight.
+func TestRefreshOnce_SingleFlight(t *testing.T) {
+	ctx := context.Background()
+	now := time.Now()
+	store := newFakeStore()
+	store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-1")
+	// Every successful rotation yields a FRESH (2h) access token, so once one
+	// caller rotates, the other sees fresh and skips.
+	tr := &fakeTransport{status: http.StatusOK, respBody: okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-2")}
+	r := newTestRefresher(store, tr, now)
+
+	// Fan out n concurrent refresh attempts and wait for all of them before
+	// inspecting the counters.
+	const n = 16
+	var wg sync.WaitGroup
+	wg.Add(n)
+	for i := 0; i < n; i++ {
+		go func() {
+			defer wg.Done()
+			r.refreshOnce(ctx)
+		}()
+	}
+	wg.Wait()
+
+	if posts := tr.callCount(); posts != 1 {
+		t.Errorf("single-flight: %d OAuth POSTs across %d concurrent calls, want exactly 1", posts, n)
+	}
+	// Same atomically loaded value is both tested and reported.
+	if puts := atomic.LoadInt32(&store.puts); puts != 1 {
+		t.Errorf("single-flight: %d write-backs, want exactly 1", puts)
+	}
+}
+
+// TestRefreshOnce_PostsExactlyOnceToOAuthEndpoint: when it DOES refresh, the
+// single POST goes to the OAuth token URL with the refresh_token grant body.
+func TestRefreshOnce_PostsExactlyOnceToOAuthEndpoint(t *testing.T) {
+	now := time.Now()
+	store := newFakeStore()
+	store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-secret")
+	tr := &fakeTransport{status: http.StatusOK, respBody: okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-2")}
+	r := newTestRefresher(store, tr, now)
+
+	r.refreshOnce(context.Background())
+
+	// Fatal on a wrong call count: the index-0 reads below rely on exactly one
+	// recorded request.
+	if posts := tr.callCount(); posts != 1 {
+		t.Fatalf("%d POSTs, want exactly 1", posts)
+	}
+	if tr.urls[0] != oauthTokenURL {
+		t.Errorf("POST URL = %q, want %q", tr.urls[0], oauthTokenURL)
+	}
+	if tr.methods[0] != http.MethodPost {
+		t.Errorf("method = %q, want POST", tr.methods[0])
+	}
+
+	// The recorded request body must be a JSON object carrying the grant
+	// type, the seeded refresh_token and the codex OAuth client id.
+	var body map[string]string
+	if err := json.Unmarshal([]byte(tr.bodies[0]), &body); err != nil {
+		t.Fatalf("request body not json: %v (%s)", err, tr.bodies[0])
+	}
+	if body["grant_type"] != "refresh_token" {
+		t.Errorf("grant_type = %q, want refresh_token", body["grant_type"])
+	}
+	if body["refresh_token"] != "rt-secret" {
+		t.Errorf("refresh_token = %q, want rt-secret", body["refresh_token"])
+	}
+	if body["client_id"] != codexOAuthClientID {
+		t.Errorf("client_id = %q, want %q", body["client_id"], codexOAuthClientID)
+	}
+}
+
+// TestRefreshOnce_ReadErrorSkips: a store read error is a transient skip (no
+// POST, no permanent latch).
+func TestRefreshOnce_ReadErrorSkips(t *testing.T) {
+	store := newFakeStore()
+	store.getErr = fmt.Errorf("db down")
+	tr := &fakeTransport{}
+	r := newTestRefresher(store, tr, time.Now())
+
+	if perm := r.refreshOnce(context.Background()); perm {
+		t.Errorf("read error: permanentFailure=true, want false")
+	}
+	if posts := tr.callCount(); posts != 0 {
+		t.Errorf("read error: %d POSTs, want 0", posts)
+	}
+}
+
+// TestMergeTokens_PreservesOtherFields proves the rotated write-back keeps every
+// non-token field and does not clobber id_token with an empty rotated value.
+func TestMergeTokens_PreservesOtherFields(t *testing.T) {
+	blob := authBlob("old-at", "old-rt")
+	// The rotated token set intentionally carries no id_token.
+	out, err := mergeTokens(blob, oauthTokens{AccessToken: "new-at", RefreshToken: "new-rt"})
+	if err != nil {
+		t.Fatalf("mergeTokens: %v", err)
+	}
+	tokens, err := parseTokens(out)
+	if err != nil {
+		t.Fatalf("parse merged: %v", err)
+	}
+	if tokens.AccessToken != "new-at" || tokens.RefreshToken != "new-rt" {
+		t.Errorf("merged tokens = %+v, want new-at/new-rt", tokens)
+	}
+	// Presumably authBlob seeds id_token "id-original" and a last_refresh
+	// field; both must survive the merge. (authBlob is defined elsewhere.)
+	if tokens.IDToken != "id-original" {
+		t.Errorf("empty rotated id_token clobbered the original: got %q, want id-original", tokens.IDToken)
+	}
+	if !strings.Contains(out, "last_refresh") {
+		t.Errorf("merge dropped preserved field: %s", out)
+	}
+}
-- 
2.52.0