fix(workspace-server): central codex OAuth refresher (single-owner, anti-burn) #2023

Merged
devops-engineer merged 1 commits from fix/codex-central-refresher into main 2026-05-31 19:52:22 +00:00
3 changed files with 903 additions and 0 deletions
+15
View File
@@ -36,6 +36,7 @@ import (
"time"
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/channels"
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/codexauth"
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/crypto"
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/db"
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/events"
@@ -334,6 +335,20 @@ func main() {
pendinguploads.StartSweeper(c, pendinguploads.NewPostgres(db.DB), 0)
})
// Codex shared-OAuth central refresher — the SINGLE owner of the rotating
// refresh_token for the global codex (ChatGPT/Codex subscription) credential
// (global_secrets key CODEX_AUTH_JSON). Multiple codex workspaces share ONE
// ChatGPT-Pro OAuth token; OpenAI's refresh_token is single-use, so letting
// each per-agent app-server refresh on its own 401 burned the seed within
// seconds (a refresh storm). This goroutine is structurally single-flight
// (one goroutine + a package mutex), refreshes only within a safety margin
// of expiry, POSTs the refresh_token at most once per due cycle, and writes
// the rotated blob back — workspaces now only GET the current token (see the
// codex template's codex_auth_sync.sh). INERT when no CODEX_AUTH_JSON exists.
go supervised.RunWithRecover(ctx, "codex-auth-refresher", func(c context.Context) {
codexauth.StartCodexAuthRefresher(c, db.DB)
})
// Provision-timeout sweep — flips workspaces that have been stuck in
// status='provisioning' past the timeout window to 'failed' and emits
// WORKSPACE_PROVISION_TIMEOUT. Without this the UI banner is cosmetic
@@ -0,0 +1,463 @@
// Package codexauth owns the SINGLE, platform-side refresh of the global
// codex (ChatGPT/Codex subscription) OAuth credential stored in the
// global_secrets table under key CODEX_AUTH_JSON.
//
// THE PROBLEM IT FIXES (agents-team prod, 2026-05-31)
//
// Multiple codex workspaces share ONE ChatGPT-Pro OAuth token (the global
// secret CODEX_AUTH_JSON). OpenAI's refresh_token is SINGLE-USE: every refresh
// rotates it and invalidates the prior one. When each per-agent codex
// app-server refreshed independently on a 401, the siblings' in-flight tokens
// were invalidated within seconds — a refresh storm that burned the seed and
// wedged every codex agent.
//
// THE FIX (two halves; this is the core half)
//
// 1. The per-workspace codex app-server NO LONGER refreshes (the template's
// OAuth POST is gated off by default — see the codex template's
// codex_auth_sync.sh / CODEX_AUTH_REFRESH_OWNER gate). Workspaces only ever
// GET the current token and write it to auth.json.
// 2. ONE owner refreshes the rotating refresh_token: this background goroutine
// in the platform. It is structurally single-flight (one goroutine + a
// package mutex), refreshes ONLY when the access_token is within a safety
// margin of expiry, POSTs the refresh_token at most ONCE per due cycle, and
// writes the rotated blob back to global_secrets. On a permanent failure
// (the seed was already burned by an out-of-band login) it logs ONCE and
// backs off — it never hot-loops a dead refresh_token.
//
// Billing-mode resolution and the byok strip are UNTOUCHED by this package.
package codexauth
import (
"context"
"database/sql"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"strings"
"sync"
"time"
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/crypto"
)
const (
// CodexAuthSecretKey is the global_secrets key holding the shared codex
// ChatGPT/Codex subscription OAuth blob (auth.json contents).
CodexAuthSecretKey = "CODEX_AUTH_JSON"
// oauthTokenURL is OpenAI's OAuth token endpoint. The ONLY endpoint this
// package ever POSTs to, and only for a due refresh.
oauthTokenURL = "https://auth.openai.com/oauth/token"
// codexOAuthClientID is the public Codex CLI OAuth client id (the same id
// the codex CLI sends). Not a secret.
codexOAuthClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
// refreshSafetyMargin is how far ahead of access_token expiry a refresh is
// considered DUE. A token expiring within this window is refreshed now; one
// expiring later is left untouched (skip-when-fresh). Generous so a slow
// tick can never let the shared token lapse for the fleet.
refreshSafetyMargin = 15 * time.Minute
// defaultInterval is how often the loop wakes to check due-ness. The check
// is cheap (decrypt + JWT exp parse) and only POSTs when actually due.
defaultInterval = 5 * time.Minute
// permanentFailureBackoff is how long the loop waits after a PERMANENT
// refresh failure (invalid_grant / "refresh token already used"). The seed
// is burned until a human re-seeds a fresh login; there is nothing to retry,
// so we back off hard rather than hammer the dead token.
permanentFailureBackoff = 1 * time.Hour
)
// SecretStore is the minimal global_secrets surface the refresher needs. The
// production implementation (postgresStore) is backed by *sql.DB; tests inject
// a fake. It is deliberately tiny — read one key, write one key — so the test
// double is trivial and the refresher never reaches for the package-global DB.
type SecretStore interface {
// Get returns the decrypted secret value and true, or ("", false) when the
// key is absent. A non-nil error is a real read failure (not absence).
Get(ctx context.Context, key string) (value string, found bool, err error)
// Put encrypts and upserts value under key, bumping the row's updated_at
// (the "last_refresh" timestamp). It is the rotated-blob write-back.
Put(ctx context.Context, key, value string) error
}
// httpDoer is the http client seam (real *http.Client in prod, fake transport
// in tests). Tests NEVER hit the network.
type httpDoer interface {
Do(req *http.Request) (*http.Response, error)
}
// refresher is the single-owner refresh engine. The package-level mutex makes
// the refresh structurally single-flight: even if two refreshOnce calls raced
// (they cannot in prod — one goroutine drives it — but a test or a future
// caller might), only one POSTs at a time, and the access-token freshness
// re-check inside the lock means the second sees a freshly-rotated token and
// skips. One goroutine + this mutex = single-flight by construction.
type refresher struct {
store SecretStore
client httpDoer
now func() time.Time
// permanentlyFailed records that the current seed's refresh_token was
// rejected as already-used/invalid. While set, refreshOnce is INERT (it
// will not re-POST the dead token) until the secret value CHANGES (a human
// re-seed), detected by comparing the stored blob. This is the anti-storm
// latch — it lives on the struct, not globally, so it resets if the seed is
// replaced out of band.
failedSeed string // the auth-json blob that failed; "" = no known failure
}
// mu serializes refreshOnce across the process. Package-level so the
// single-flight guarantee holds regardless of how many refresher values exist
// (in prod there is exactly one).
var mu sync.Mutex
// oauthTokens is the token trio inside auth.json (and the OAuth response).
type oauthTokens struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token,omitempty"`
}
// StartCodexAuthRefresher launches the single background refresher goroutine.
// It returns immediately; the loop runs until ctx is cancelled. Wire it under
// supervised.RunWithRecover in main.go like the other Start* sweeps.
//
// db may be nil only in tests that drive refreshOnce directly; in prod it is
// the server's *sql.DB. The loop is INERT (logs once, keeps ticking) whenever
// CODEX_AUTH_JSON is absent — a deployment with no shared codex seed pays only
// a cheap periodic read.
func StartCodexAuthRefresher(ctx context.Context, db *sql.DB) {
r := &refresher{
store: &postgresStore{db: db},
client: &http.Client{Timeout: 30 * time.Second},
now: time.Now,
}
r.run(ctx, defaultInterval)
}
// run is the tick loop. It checks due-ness every interval and on a permanent
// failure waits permanentFailureBackoff before the next check (never a tight
// retry of a burned token).
func (r *refresher) run(ctx context.Context, interval time.Duration) {
// Check once promptly on boot, then on the interval.
for {
wait := interval
if perm := r.refreshOnce(ctx); perm {
// Permanent failure this cycle — the seed is burned. Back off hard;
// a human must re-seed. We keep ticking (a re-seed CHANGES the blob,
// which clears the latch) but slowly.
wait = permanentFailureBackoff
}
timer := time.NewTimer(wait)
select {
case <-ctx.Done():
timer.Stop()
log.Printf("codexauth: context done; stopping refresher")
return
case <-timer.C:
}
}
}
// refreshOnce performs ONE due-check + at most one refresh POST. It returns
// permanentFailure=true iff the refresh_token was permanently rejected this
// cycle (the caller backs off). All other outcomes (inert/skip/rotated/transient
// error) return false.
//
// It is single-flight: the package mutex is held for the whole read→decide→
// POST→write-back so two callers cannot both POST the (single-use) refresh_token.
func (r *refresher) refreshOnce(ctx context.Context) (permanentFailure bool) {
mu.Lock()
defer mu.Unlock()
blob, found, err := r.store.Get(ctx, CodexAuthSecretKey)
if err != nil {
log.Printf("codexauth: read CODEX_AUTH_JSON failed: %v (skipping this cycle)", err)
return false
}
if !found || strings.TrimSpace(blob) == "" {
// INERT: no shared codex seed in this deployment. Cheap no-op.
log.Printf("codexauth: no CODEX_AUTH_JSON in global_secrets — refresher inert")
// A previously-failed seed that has since been DELETED clears the latch.
r.failedSeed = ""
return false
}
// Anti-storm latch: if THIS exact blob already failed permanently, do not
// re-POST its dead refresh_token. A re-seed changes the blob and clears it.
if r.failedSeed != "" && r.failedSeed == blob {
return false
}
if r.failedSeed != "" && r.failedSeed != blob {
// The seed changed out of band (human re-login) — give it a fresh chance.
r.failedSeed = ""
}
tokens, err := parseTokens(blob)
if err != nil {
log.Printf("codexauth: CODEX_AUTH_JSON is not parseable codex auth json: %v (skipping)", err)
return false
}
if tokens.RefreshToken == "" {
log.Printf("codexauth: CODEX_AUTH_JSON carries no refresh_token (skipping)")
return false
}
// Skip-when-fresh: only refresh within the safety margin of expiry. A blob
// with an unparseable/absent access_token exp is treated as DUE (better to
// refresh a token we cannot date than let the fleet lapse).
exp, haveExp := jwtExp(tokens.AccessToken)
if haveExp {
remaining := exp.Sub(r.now())
if remaining > refreshSafetyMargin {
// Fresh — nothing to do. No POST.
return false
}
}
// DUE: POST the refresh_token ONCE.
newTokens, perm, err := r.doRefresh(ctx, tokens.RefreshToken)
if err != nil {
if perm {
// Permanent: the seed is burned. Latch it so we don't re-POST, log
// ONCE, and DO NOT write anything back.
log.Printf("codexauth: PERMANENT refresh failure (refresh_token rejected): %v — "+
"NOT writing back; the shared CODEX_AUTH_JSON seed is burned and must be re-seeded "+
"via a fresh codex login. Backing off.", err)
r.failedSeed = blob
return true
}
// Transient (network/5xx): no write-back, retry next cycle (no backoff).
log.Printf("codexauth: transient refresh error: %v (will retry next cycle)", err)
return false
}
// Success: merge the rotated trio into the blob (preserving every other
// field) and write it back encrypted, bumping updated_at (last_refresh).
rotated, err := mergeTokens(blob, newTokens)
if err != nil {
log.Printf("codexauth: failed to merge rotated tokens into auth json: %v (NOT writing back)", err)
return false
}
if err := r.store.Put(ctx, CodexAuthSecretKey, rotated); err != nil {
log.Printf("codexauth: write-back of rotated CODEX_AUTH_JSON failed: %v", err)
return false
}
r.failedSeed = "" // success clears any stale latch
log.Printf("codexauth: rotated shared CODEX_AUTH_JSON (single-owner refresh)")
return false
}
// doRefresh POSTs the refresh_token to OpenAI's OAuth endpoint exactly once and
// returns the rotated trio. permanent=true marks an unrecoverable rejection
// (HTTP 400 invalid_grant / "refresh token already used") so the caller latches
// and backs off instead of retrying.
func (r *refresher) doRefresh(ctx context.Context, refreshToken string) (tokens oauthTokens, permanent bool, err error) {
body, _ := json.Marshal(map[string]string{
"grant_type": "refresh_token",
"client_id": codexOAuthClientID,
"refresh_token": refreshToken,
})
req, err := http.NewRequestWithContext(ctx, http.MethodPost, oauthTokenURL, strings.NewReader(string(body)))
if err != nil {
return oauthTokens{}, false, err
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, err := r.client.Do(req)
if err != nil {
return oauthTokens{}, false, err // transient: network
}
defer resp.Body.Close()
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
if resp.StatusCode == http.StatusOK {
var t oauthTokens
if err := json.Unmarshal(respBody, &t); err != nil {
return oauthTokens{}, false, fmt.Errorf("decode token response: %w", err)
}
if t.AccessToken == "" {
return oauthTokens{}, false, fmt.Errorf("token response missing access_token")
}
return t, false, nil
}
// Non-200. A 400 (and any body naming invalid_grant / already-used) is a
// PERMANENT rejection of the refresh_token. 401/403 likewise mean the seed
// is no good. Everything else (429/5xx/network-shaped) is transient.
lowerBody := strings.ToLower(string(respBody))
isInvalidGrant := strings.Contains(lowerBody, "invalid_grant") ||
strings.Contains(lowerBody, "refresh token already used") ||
strings.Contains(lowerBody, "already been used") ||
strings.Contains(lowerBody, "token has been revoked")
switch {
case resp.StatusCode == http.StatusBadRequest && isInvalidGrant:
return oauthTokens{}, true, fmt.Errorf("oauth %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
case resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden:
return oauthTokens{}, true, fmt.Errorf("oauth %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
default:
return oauthTokens{}, false, fmt.Errorf("oauth %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
}
}
// parseTokens extracts the OAuth trio from an auth.json blob, accepting both
// the nested `{"tokens":{...}}` shape the codex CLI writes and a flat top-level
// shape some seeds use.
func parseTokens(blob string) (oauthTokens, error) {
var top map[string]json.RawMessage
if err := json.Unmarshal([]byte(blob), &top); err != nil {
return oauthTokens{}, err
}
if nested, ok := top["tokens"]; ok {
var t oauthTokens
if err := json.Unmarshal(nested, &t); err != nil {
return oauthTokens{}, fmt.Errorf("decode nested tokens: %w", err)
}
return t, nil
}
var t oauthTokens
if err := json.Unmarshal([]byte(blob), &t); err != nil {
return oauthTokens{}, err
}
return t, nil
}
// mergeTokens writes the rotated trio back into the original blob in-place,
// preserving the blob's shape (nested-vs-flat) and every other field. A field
// in the OAuth response that is empty (e.g. id_token omitted) does NOT clobber
// the existing value.
func mergeTokens(blob string, rotated oauthTokens) (string, error) {
var top map[string]json.RawMessage
if err := json.Unmarshal([]byte(blob), &top); err != nil {
return "", err
}
applyTo := func(m map[string]json.RawMessage) error {
setStr := func(key, val string) error {
if val == "" {
return nil // don't clobber an existing value with an empty one
}
b, err := json.Marshal(val)
if err != nil {
return err
}
m[key] = b
return nil
}
if err := setStr("access_token", rotated.AccessToken); err != nil {
return err
}
if err := setStr("refresh_token", rotated.RefreshToken); err != nil {
return err
}
if err := setStr("id_token", rotated.IDToken); err != nil {
return err
}
return nil
}
if nestedRaw, ok := top["tokens"]; ok {
var nested map[string]json.RawMessage
if err := json.Unmarshal(nestedRaw, &nested); err != nil {
return "", fmt.Errorf("decode nested tokens for merge: %w", err)
}
if err := applyTo(nested); err != nil {
return "", err
}
nb, err := json.Marshal(nested)
if err != nil {
return "", err
}
top["tokens"] = nb
} else {
if err := applyTo(top); err != nil {
return "", err
}
}
out, err := json.Marshal(top)
if err != nil {
return "", err
}
return string(out), nil
}
// jwtExp decodes the `exp` claim (Unix seconds) from a JWT access token WITHOUT
// verifying the signature (we only need the expiry to decide due-ness; the
// token's validity is OpenAI's to enforce). Returns ok=false when the token is
// not a parseable 3-part JWT or carries no numeric exp.
func jwtExp(token string) (time.Time, bool) {
parts := strings.Split(token, ".")
if len(parts) != 3 {
return time.Time{}, false
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
// Some encoders pad; tolerate standard base64url with padding too.
payload, err = base64.URLEncoding.DecodeString(parts[1])
if err != nil {
return time.Time{}, false
}
}
var claims struct {
Exp json.Number `json:"exp"`
}
if err := json.Unmarshal(payload, &claims); err != nil {
return time.Time{}, false
}
secs, err := claims.Exp.Int64()
if err != nil || secs <= 0 {
return time.Time{}, false
}
return time.Unix(secs, 0), true
}
// postgresStore is the production SecretStore backed by global_secrets, using
// the SAME crypto path the secrets handler uses (DecryptVersioned on read,
// Encrypt + CurrentEncryptionVersion on write).
type postgresStore struct {
db *sql.DB
}
func (s *postgresStore) Get(ctx context.Context, key string) (string, bool, error) {
var enc []byte
var ver int
err := s.db.QueryRowContext(ctx,
`SELECT encrypted_value, encryption_version FROM global_secrets WHERE key = $1`, key).
Scan(&enc, &ver)
if err == sql.ErrNoRows {
return "", false, nil
}
if err != nil {
return "", false, err
}
plain, err := crypto.DecryptVersioned(enc, ver)
if err != nil {
return "", false, err
}
return string(plain), true, nil
}
func (s *postgresStore) Put(ctx context.Context, key, value string) error {
enc, err := crypto.Encrypt([]byte(value))
if err != nil {
return err
}
ver := crypto.CurrentEncryptionVersion()
_, err = s.db.ExecContext(ctx, `
INSERT INTO global_secrets (key, encrypted_value, encryption_version)
VALUES ($1, $2, $3)
ON CONFLICT (key) DO UPDATE
SET encrypted_value = $2, encryption_version = $3, updated_at = now()
`, key, enc, ver)
return err
}
@@ -0,0 +1,425 @@
package codexauth
import (
"context"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
// --- test doubles -----------------------------------------------------------
// fakeStore is an in-memory SecretStore. nil entry = absent key.
type fakeStore struct {
mu sync.Mutex
values map[string]string
getErr error
putErr error
puts int32 // count of successful Put calls
}
func newFakeStore() *fakeStore { return &fakeStore{values: map[string]string{}} }
func (f *fakeStore) Get(_ context.Context, key string) (string, bool, error) {
f.mu.Lock()
defer f.mu.Unlock()
if f.getErr != nil {
return "", false, f.getErr
}
v, ok := f.values[key]
return v, ok, nil
}
func (f *fakeStore) Put(_ context.Context, key, value string) error {
f.mu.Lock()
defer f.mu.Unlock()
if f.putErr != nil {
return f.putErr
}
f.values[key] = value
atomic.AddInt32(&f.puts, 1)
return nil
}
func (f *fakeStore) get(key string) string {
f.mu.Lock()
defer f.mu.Unlock()
return f.values[key]
}
// fakeTransport records every request and returns a scripted response. It is
// the network seam — tests NEVER make a real request.
type fakeTransport struct {
mu sync.Mutex
calls int32
urls []string
methods []string
bodies []string
status int
respBody string
transport func(*http.Request) (*http.Response, error) // optional override
}
func (t *fakeTransport) Do(req *http.Request) (*http.Response, error) {
atomic.AddInt32(&t.calls, 1)
t.mu.Lock()
t.urls = append(t.urls, req.URL.String())
t.methods = append(t.methods, req.Method)
if req.Body != nil {
b, _ := io.ReadAll(req.Body)
t.bodies = append(t.bodies, string(b))
} else {
t.bodies = append(t.bodies, "")
}
t.mu.Unlock()
if t.transport != nil {
return t.transport(req)
}
status := t.status
if status == 0 {
status = http.StatusOK
}
return &http.Response{
StatusCode: status,
Body: io.NopCloser(strings.NewReader(t.respBody)),
Header: make(http.Header),
}, nil
}
func (t *fakeTransport) callCount() int { return int(atomic.LoadInt32(&t.calls)) }
// --- helpers ----------------------------------------------------------------
// makeJWT builds an unsigned-but-parseable JWT whose payload carries exp.
func makeJWT(exp time.Time) string {
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
payload := base64.RawURLEncoding.EncodeToString([]byte(
fmt.Sprintf(`{"exp":%d,"sub":"codex"}`, exp.Unix())))
sig := base64.RawURLEncoding.EncodeToString([]byte("sig"))
return header + "." + payload + "." + sig
}
// authBlob builds a nested codex auth.json blob with the given tokens.
func authBlob(access, refresh string) string {
b, _ := json.Marshal(map[string]any{
"tokens": map[string]any{
"access_token": access,
"refresh_token": refresh,
"id_token": "id-original",
},
"OPENAI_API_KEY": nil,
"last_refresh": "2026-01-01T00:00:00Z",
})
return string(b)
}
func newTestRefresher(store SecretStore, client httpDoer, now time.Time) *refresher {
return &refresher{
store: store,
client: client,
now: func() time.Time { return now },
}
}
func okRefreshResponse(access, refresh string) string {
b, _ := json.Marshal(oauthTokens{AccessToken: access, RefreshToken: refresh, IDToken: "id-new"})
return string(b)
}
// --- tests ------------------------------------------------------------------
// TestJWTExpParse covers the exp decode (valid, malformed, missing).
func TestJWTExpParse(t *testing.T) {
want := time.Now().Add(2 * time.Hour).Truncate(time.Second)
got, ok := jwtExp(makeJWT(want))
if !ok {
t.Fatalf("jwtExp(valid) ok=false, want true")
}
if !got.Equal(want) {
t.Errorf("jwtExp = %v, want %v", got, want)
}
if _, ok := jwtExp("not-a-jwt"); ok {
t.Errorf("jwtExp(non-jwt) ok=true, want false")
}
if _, ok := jwtExp("a.b.c"); ok {
t.Errorf("jwtExp(garbage parts) ok=true, want false")
}
// 3 parts but payload has no exp.
noExp := base64.RawURLEncoding.EncodeToString([]byte("{}"))
if _, ok := jwtExp("h." + noExp + ".s"); ok {
t.Errorf("jwtExp(no exp claim) ok=true, want false")
}
}
// TestRefreshOnce_SkipWhenFresh: a token well outside the safety margin is NOT
// refreshed — no POST, no write-back.
func TestRefreshOnce_SkipWhenFresh(t *testing.T) {
now := time.Now()
store := newFakeStore()
store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(2*time.Hour)), "rt-1")
tr := &fakeTransport{status: http.StatusOK, respBody: okRefreshResponse("new-at", "rt-2")}
r := newTestRefresher(store, tr, now)
if perm := r.refreshOnce(context.Background()); perm {
t.Fatalf("fresh token: permanentFailure=true, want false")
}
if tr.callCount() != 0 {
t.Errorf("fresh token: %d OAuth POSTs, want 0", tr.callCount())
}
if atomic.LoadInt32(&store.puts) != 0 {
t.Errorf("fresh token: %d write-backs, want 0", store.puts)
}
}
// TestRefreshOnce_RotateThenReskip: a token inside the margin is refreshed once
// (POST + write-back of the rotated blob); a subsequent call on the now-fresh
// rotated token skips (no second POST). Proves rotate→write-back→re-skip.
func TestRefreshOnce_RotateThenReskip(t *testing.T) {
now := time.Now()
store := newFakeStore()
// Expires in 5m — inside the 15m safety margin → DUE.
store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(5*time.Minute)), "rt-1")
// Rotated access token is fresh (2h out); rotated refresh is rt-2.
tr := &fakeTransport{status: http.StatusOK, respBody: okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-2")}
r := newTestRefresher(store, tr, now)
if perm := r.refreshOnce(context.Background()); perm {
t.Fatalf("due token: permanentFailure=true, want false")
}
if tr.callCount() != 1 {
t.Fatalf("due token: %d OAuth POSTs, want exactly 1", tr.callCount())
}
if atomic.LoadInt32(&store.puts) != 1 {
t.Fatalf("due token: %d write-backs, want exactly 1", store.puts)
}
// The written blob must carry the rotated refresh_token and preserve the
// non-token field.
rotated := store.get(CodexAuthSecretKey)
tokens, err := parseTokens(rotated)
if err != nil {
t.Fatalf("parse rotated blob: %v", err)
}
if tokens.RefreshToken != "rt-2" {
t.Errorf("rotated refresh_token = %q, want rt-2", tokens.RefreshToken)
}
if !strings.Contains(rotated, "last_refresh") {
t.Errorf("rotated blob dropped the preserved last_refresh field: %s", rotated)
}
// Second call: the rotated access token is fresh → skip, no new POST.
if perm := r.refreshOnce(context.Background()); perm {
t.Fatalf("re-skip: permanentFailure=true, want false")
}
if tr.callCount() != 1 {
t.Errorf("re-skip: %d total OAuth POSTs, want still 1", tr.callCount())
}
if atomic.LoadInt32(&store.puts) != 1 {
t.Errorf("re-skip: %d total write-backs, want still 1", store.puts)
}
}
// TestRefreshOnce_NoSecretInert: absent CODEX_AUTH_JSON → inert (no POST, no
// write-back, no error/permanent).
func TestRefreshOnce_NoSecretInert(t *testing.T) {
store := newFakeStore() // empty
tr := &fakeTransport{}
r := newTestRefresher(store, tr, time.Now())
if perm := r.refreshOnce(context.Background()); perm {
t.Fatalf("no secret: permanentFailure=true, want false")
}
if tr.callCount() != 0 {
t.Errorf("no secret: %d POSTs, want 0", tr.callCount())
}
if atomic.LoadInt32(&store.puts) != 0 {
t.Errorf("no secret: %d write-backs, want 0", store.puts)
}
}
// TestRefreshOnce_PermanentFailNoWriteNoStorm: a 400 invalid_grant must (a) not
// write back, (b) return permanentFailure=true, and (c) NOT re-POST on the next
// cycle for the same (burned) seed — the anti-storm latch.
func TestRefreshOnce_PermanentFailNoWriteNoStorm(t *testing.T) {
now := time.Now()
store := newFakeStore()
store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-burned")
tr := &fakeTransport{
status: http.StatusBadRequest,
respBody: `{"error":"invalid_grant","error_description":"refresh token already used"}`,
}
r := newTestRefresher(store, tr, now)
perm := r.refreshOnce(context.Background())
if !perm {
t.Fatalf("invalid_grant: permanentFailure=false, want true")
}
if tr.callCount() != 1 {
t.Fatalf("invalid_grant: %d POSTs, want exactly 1", tr.callCount())
}
if atomic.LoadInt32(&store.puts) != 0 {
t.Fatalf("invalid_grant: %d write-backs, want 0 (must NOT persist a failed refresh)", store.puts)
}
// Next cycle, SAME burned seed: must NOT re-POST (anti-storm latch).
perm2 := r.refreshOnce(context.Background())
if tr.callCount() != 1 {
t.Errorf("anti-storm: re-POSTed a burned refresh_token (%d total POSTs, want still 1)", tr.callCount())
}
_ = perm2 // latched cycle returns false (already-known failure, nothing new)
// A RE-SEED (blob changes) clears the latch and allows a fresh attempt.
store.mu.Lock()
store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-freshly-seeded")
store.mu.Unlock()
tr.status = http.StatusOK
tr.respBody = okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-rotated")
if perm := r.refreshOnce(context.Background()); perm {
t.Fatalf("post-reseed: permanentFailure=true, want false")
}
if tr.callCount() != 2 {
t.Errorf("post-reseed: %d total POSTs, want 2 (latch should clear on re-seed)", tr.callCount())
}
}
// TestRefreshOnce_TransientNoWriteNoLatch: a 5xx is transient — no write-back,
// returns false (no hard backoff latch), and a later cycle retries.
func TestRefreshOnce_TransientNoWriteNoLatch(t *testing.T) {
now := time.Now()
store := newFakeStore()
store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-1")
tr := &fakeTransport{status: http.StatusServiceUnavailable, respBody: "upstream down"}
r := newTestRefresher(store, tr, now)
if perm := r.refreshOnce(context.Background()); perm {
t.Fatalf("503: permanentFailure=true, want false (transient)")
}
if atomic.LoadInt32(&store.puts) != 0 {
t.Errorf("503: %d write-backs, want 0", store.puts)
}
// Retry next cycle succeeds (no latch on transient).
tr.status = http.StatusOK
tr.respBody = okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-2")
if perm := r.refreshOnce(context.Background()); perm {
t.Fatalf("retry after 503: permanentFailure=true, want false")
}
if tr.callCount() != 2 {
t.Errorf("transient retry: %d total POSTs, want 2", tr.callCount())
}
if atomic.LoadInt32(&store.puts) != 1 {
t.Errorf("transient retry: %d write-backs, want 1", store.puts)
}
}
// TestRefreshOnce_SingleFlight: concurrent refreshOnce calls on a DUE token must
// POST exactly once total — the package mutex serializes them and the second
// sees the freshly-rotated (now-fresh) token and skips. Structural single-flight.
func TestRefreshOnce_SingleFlight(t *testing.T) {
now := time.Now()
store := newFakeStore()
store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-1")
// Every successful rotation yields a FRESH (2h) access token, so once one
// caller rotates, the other sees fresh and skips.
tr := &fakeTransport{status: http.StatusOK, respBody: okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-2")}
r := newTestRefresher(store, tr, now)
const n = 16
var wg sync.WaitGroup
wg.Add(n)
for i := 0; i < n; i++ {
go func() {
defer wg.Done()
r.refreshOnce(context.Background())
}()
}
wg.Wait()
if tr.callCount() != 1 {
t.Errorf("single-flight: %d OAuth POSTs across %d concurrent calls, want exactly 1", tr.callCount(), n)
}
if atomic.LoadInt32(&store.puts) != 1 {
t.Errorf("single-flight: %d write-backs, want exactly 1", store.puts)
}
}
// TestRefreshOnce_PostsExactlyOnceToOAuthEndpoint: when it DOES refresh, the
// single POST goes to the OAuth token URL with the refresh_token grant body.
func TestRefreshOnce_PostsExactlyOnceToOAuthEndpoint(t *testing.T) {
now := time.Now()
store := newFakeStore()
store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-secret")
tr := &fakeTransport{status: http.StatusOK, respBody: okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-2")}
r := newTestRefresher(store, tr, now)
r.refreshOnce(context.Background())
if tr.callCount() != 1 {
t.Fatalf("%d POSTs, want exactly 1", tr.callCount())
}
if tr.urls[0] != oauthTokenURL {
t.Errorf("POST URL = %q, want %q", tr.urls[0], oauthTokenURL)
}
if tr.methods[0] != http.MethodPost {
t.Errorf("method = %q, want POST", tr.methods[0])
}
var body map[string]string
if err := json.Unmarshal([]byte(tr.bodies[0]), &body); err != nil {
t.Fatalf("request body not json: %v (%s)", err, tr.bodies[0])
}
if body["grant_type"] != "refresh_token" {
t.Errorf("grant_type = %q, want refresh_token", body["grant_type"])
}
if body["refresh_token"] != "rt-secret" {
t.Errorf("refresh_token = %q, want rt-secret", body["refresh_token"])
}
if body["client_id"] != codexOAuthClientID {
t.Errorf("client_id = %q, want %q", body["client_id"], codexOAuthClientID)
}
}
// TestRefreshOnce_ReadErrorSkips: a store read error is a transient skip (no
// POST, no permanent latch).
func TestRefreshOnce_ReadErrorSkips(t *testing.T) {
store := newFakeStore()
store.getErr = fmt.Errorf("db down")
tr := &fakeTransport{}
r := newTestRefresher(store, tr, time.Now())
if perm := r.refreshOnce(context.Background()); perm {
t.Errorf("read error: permanentFailure=true, want false")
}
if tr.callCount() != 0 {
t.Errorf("read error: %d POSTs, want 0", tr.callCount())
}
}
// TestMergeTokens_PreservesOtherFields proves the rotated write-back keeps every
// non-token field and does not clobber id_token with an empty rotated value.
func TestMergeTokens_PreservesOtherFields(t *testing.T) {
blob := authBlob("old-at", "old-rt")
out, err := mergeTokens(blob, oauthTokens{AccessToken: "new-at", RefreshToken: "new-rt"}) // no id_token
if err != nil {
t.Fatalf("mergeTokens: %v", err)
}
tokens, err := parseTokens(out)
if err != nil {
t.Fatalf("parse merged: %v", err)
}
if tokens.AccessToken != "new-at" || tokens.RefreshToken != "new-rt" {
t.Errorf("merged tokens = %+v, want new-at/new-rt", tokens)
}
if tokens.IDToken != "id-original" {
t.Errorf("empty rotated id_token clobbered the original: got %q, want id-original", tokens.IDToken)
}
if !strings.Contains(out, "last_refresh") {
t.Errorf("merge dropped preserved field: %s", out)
}
}