fix(workspace-server): central codex OAuth refresher (single-owner, anti-burn) #2023
@@ -36,6 +36,7 @@ import (
|
||||
"time"
|
||||
|
||||
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/channels"
|
||||
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/codexauth"
|
||||
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/crypto"
|
||||
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/db"
|
||||
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/events"
|
||||
@@ -334,6 +335,20 @@ func main() {
|
||||
pendinguploads.StartSweeper(c, pendinguploads.NewPostgres(db.DB), 0)
|
||||
})
|
||||
|
||||
// Codex shared-OAuth central refresher — the SINGLE owner of the rotating
|
||||
// refresh_token for the global codex (ChatGPT/Codex subscription) credential
|
||||
// (global_secrets key CODEX_AUTH_JSON). Multiple codex workspaces share ONE
|
||||
// ChatGPT-Pro OAuth token; OpenAI's refresh_token is single-use, so letting
|
||||
// each per-agent app-server refresh on its own 401 burned the seed within
|
||||
// seconds (a refresh storm). This goroutine is structurally single-flight
|
||||
// (one goroutine + a package mutex), refreshes only within a safety margin
|
||||
// of expiry, POSTs the refresh_token at most once per due cycle, and writes
|
||||
// the rotated blob back — workspaces now only GET the current token (see the
|
||||
// codex template's codex_auth_sync.sh). INERT when no CODEX_AUTH_JSON exists.
|
||||
go supervised.RunWithRecover(ctx, "codex-auth-refresher", func(c context.Context) {
|
||||
codexauth.StartCodexAuthRefresher(c, db.DB)
|
||||
})
|
||||
|
||||
// Provision-timeout sweep — flips workspaces that have been stuck in
|
||||
// status='provisioning' past the timeout window to 'failed' and emits
|
||||
// WORKSPACE_PROVISION_TIMEOUT. Without this the UI banner is cosmetic
|
||||
|
||||
@@ -0,0 +1,463 @@
|
||||
// Package codexauth owns the SINGLE, platform-side refresh of the global
|
||||
// codex (ChatGPT/Codex subscription) OAuth credential stored in the
|
||||
// global_secrets table under key CODEX_AUTH_JSON.
|
||||
//
|
||||
// THE PROBLEM IT FIXES (agents-team prod, 2026-05-31)
|
||||
//
|
||||
// Multiple codex workspaces share ONE ChatGPT-Pro OAuth token (the global
|
||||
// secret CODEX_AUTH_JSON). OpenAI's refresh_token is SINGLE-USE: every refresh
|
||||
// rotates it and invalidates the prior one. When each per-agent codex
|
||||
// app-server refreshed independently on a 401, the siblings' in-flight tokens
|
||||
// were invalidated within seconds — a refresh storm that burned the seed and
|
||||
// wedged every codex agent.
|
||||
//
|
||||
// THE FIX (two halves; this is the core half)
|
||||
//
|
||||
// 1. The per-workspace codex app-server NO LONGER refreshes (the template's
|
||||
// OAuth POST is gated off by default — see the codex template's
|
||||
// codex_auth_sync.sh / CODEX_AUTH_REFRESH_OWNER gate). Workspaces only ever
|
||||
// GET the current token and write it to auth.json.
|
||||
// 2. ONE owner refreshes the rotating refresh_token: this background goroutine
|
||||
// in the platform. It is structurally single-flight (one goroutine + a
|
||||
// package mutex), refreshes ONLY when the access_token is within a safety
|
||||
// margin of expiry, POSTs the refresh_token at most ONCE per due cycle, and
|
||||
// writes the rotated blob back to global_secrets. On a permanent failure
|
||||
// (the seed was already burned by an out-of-band login) it logs ONCE and
|
||||
// backs off — it never hot-loops a dead refresh_token.
|
||||
//
|
||||
// Billing-mode resolution and the byok strip are UNTOUCHED by this package.
|
||||
package codexauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.moleculesai.app/molecule-ai/molecule-core/workspace-server/internal/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
// CodexAuthSecretKey is the global_secrets key holding the shared codex
|
||||
// ChatGPT/Codex subscription OAuth blob (auth.json contents).
|
||||
CodexAuthSecretKey = "CODEX_AUTH_JSON"
|
||||
|
||||
// oauthTokenURL is OpenAI's OAuth token endpoint. The ONLY endpoint this
|
||||
// package ever POSTs to, and only for a due refresh.
|
||||
oauthTokenURL = "https://auth.openai.com/oauth/token"
|
||||
|
||||
// codexOAuthClientID is the public Codex CLI OAuth client id (the same id
|
||||
// the codex CLI sends). Not a secret.
|
||||
codexOAuthClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
|
||||
|
||||
// refreshSafetyMargin is how far ahead of access_token expiry a refresh is
|
||||
// considered DUE. A token expiring within this window is refreshed now; one
|
||||
// expiring later is left untouched (skip-when-fresh). Generous so a slow
|
||||
// tick can never let the shared token lapse for the fleet.
|
||||
refreshSafetyMargin = 15 * time.Minute
|
||||
|
||||
// defaultInterval is how often the loop wakes to check due-ness. The check
|
||||
// is cheap (decrypt + JWT exp parse) and only POSTs when actually due.
|
||||
defaultInterval = 5 * time.Minute
|
||||
|
||||
// permanentFailureBackoff is how long the loop waits after a PERMANENT
|
||||
// refresh failure (invalid_grant / "refresh token already used"). The seed
|
||||
// is burned until a human re-seeds a fresh login; there is nothing to retry,
|
||||
// so we back off hard rather than hammer the dead token.
|
||||
permanentFailureBackoff = 1 * time.Hour
|
||||
)
|
||||
|
||||
// SecretStore is the minimal global_secrets surface the refresher needs. The
|
||||
// production implementation (postgresStore) is backed by *sql.DB; tests inject
|
||||
// a fake. It is deliberately tiny — read one key, write one key — so the test
|
||||
// double is trivial and the refresher never reaches for the package-global DB.
|
||||
type SecretStore interface {
|
||||
// Get returns the decrypted secret value and true, or ("", false) when the
|
||||
// key is absent. A non-nil error is a real read failure (not absence).
|
||||
Get(ctx context.Context, key string) (value string, found bool, err error)
|
||||
// Put encrypts and upserts value under key, bumping the row's updated_at
|
||||
// (the "last_refresh" timestamp). It is the rotated-blob write-back.
|
||||
Put(ctx context.Context, key, value string) error
|
||||
}
|
||||
|
||||
// httpDoer is the http client seam (real *http.Client in prod, fake transport
|
||||
// in tests). Tests NEVER hit the network.
|
||||
type httpDoer interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// refresher is the single-owner refresh engine. The package-level mutex makes
|
||||
// the refresh structurally single-flight: even if two refreshOnce calls raced
|
||||
// (they cannot in prod — one goroutine drives it — but a test or a future
|
||||
// caller might), only one POSTs at a time, and the access-token freshness
|
||||
// re-check inside the lock means the second sees a freshly-rotated token and
|
||||
// skips. One goroutine + this mutex = single-flight by construction.
|
||||
type refresher struct {
|
||||
store SecretStore
|
||||
client httpDoer
|
||||
now func() time.Time
|
||||
|
||||
// permanentlyFailed records that the current seed's refresh_token was
|
||||
// rejected as already-used/invalid. While set, refreshOnce is INERT (it
|
||||
// will not re-POST the dead token) until the secret value CHANGES (a human
|
||||
// re-seed), detected by comparing the stored blob. This is the anti-storm
|
||||
// latch — it lives on the struct, not globally, so it resets if the seed is
|
||||
// replaced out of band.
|
||||
failedSeed string // the auth-json blob that failed; "" = no known failure
|
||||
}
|
||||
|
||||
// mu serializes refreshOnce across the process. Package-level so the
|
||||
// single-flight guarantee holds regardless of how many refresher values exist
|
||||
// (in prod there is exactly one).
|
||||
var mu sync.Mutex
|
||||
|
||||
// oauthTokens is the token trio inside auth.json (and the OAuth response).
|
||||
type oauthTokens struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
}
|
||||
|
||||
// StartCodexAuthRefresher launches the single background refresher goroutine.
|
||||
// It returns immediately; the loop runs until ctx is cancelled. Wire it under
|
||||
// supervised.RunWithRecover in main.go like the other Start* sweeps.
|
||||
//
|
||||
// db may be nil only in tests that drive refreshOnce directly; in prod it is
|
||||
// the server's *sql.DB. The loop is INERT (logs once, keeps ticking) whenever
|
||||
// CODEX_AUTH_JSON is absent — a deployment with no shared codex seed pays only
|
||||
// a cheap periodic read.
|
||||
func StartCodexAuthRefresher(ctx context.Context, db *sql.DB) {
|
||||
r := &refresher{
|
||||
store: &postgresStore{db: db},
|
||||
client: &http.Client{Timeout: 30 * time.Second},
|
||||
now: time.Now,
|
||||
}
|
||||
r.run(ctx, defaultInterval)
|
||||
}
|
||||
|
||||
// run is the tick loop. It checks due-ness every interval and on a permanent
|
||||
// failure waits permanentFailureBackoff before the next check (never a tight
|
||||
// retry of a burned token).
|
||||
func (r *refresher) run(ctx context.Context, interval time.Duration) {
|
||||
// Check once promptly on boot, then on the interval.
|
||||
for {
|
||||
wait := interval
|
||||
if perm := r.refreshOnce(ctx); perm {
|
||||
// Permanent failure this cycle — the seed is burned. Back off hard;
|
||||
// a human must re-seed. We keep ticking (a re-seed CHANGES the blob,
|
||||
// which clears the latch) but slowly.
|
||||
wait = permanentFailureBackoff
|
||||
}
|
||||
|
||||
timer := time.NewTimer(wait)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
timer.Stop()
|
||||
log.Printf("codexauth: context done; stopping refresher")
|
||||
return
|
||||
case <-timer.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// refreshOnce performs ONE due-check + at most one refresh POST. It returns
|
||||
// permanentFailure=true iff the refresh_token was permanently rejected this
|
||||
// cycle (the caller backs off). All other outcomes (inert/skip/rotated/transient
|
||||
// error) return false.
|
||||
//
|
||||
// It is single-flight: the package mutex is held for the whole read→decide→
|
||||
// POST→write-back so two callers cannot both POST the (single-use) refresh_token.
|
||||
func (r *refresher) refreshOnce(ctx context.Context) (permanentFailure bool) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
blob, found, err := r.store.Get(ctx, CodexAuthSecretKey)
|
||||
if err != nil {
|
||||
log.Printf("codexauth: read CODEX_AUTH_JSON failed: %v (skipping this cycle)", err)
|
||||
return false
|
||||
}
|
||||
if !found || strings.TrimSpace(blob) == "" {
|
||||
// INERT: no shared codex seed in this deployment. Cheap no-op.
|
||||
log.Printf("codexauth: no CODEX_AUTH_JSON in global_secrets — refresher inert")
|
||||
// A previously-failed seed that has since been DELETED clears the latch.
|
||||
r.failedSeed = ""
|
||||
return false
|
||||
}
|
||||
|
||||
// Anti-storm latch: if THIS exact blob already failed permanently, do not
|
||||
// re-POST its dead refresh_token. A re-seed changes the blob and clears it.
|
||||
if r.failedSeed != "" && r.failedSeed == blob {
|
||||
return false
|
||||
}
|
||||
if r.failedSeed != "" && r.failedSeed != blob {
|
||||
// The seed changed out of band (human re-login) — give it a fresh chance.
|
||||
r.failedSeed = ""
|
||||
}
|
||||
|
||||
tokens, err := parseTokens(blob)
|
||||
if err != nil {
|
||||
log.Printf("codexauth: CODEX_AUTH_JSON is not parseable codex auth json: %v (skipping)", err)
|
||||
return false
|
||||
}
|
||||
if tokens.RefreshToken == "" {
|
||||
log.Printf("codexauth: CODEX_AUTH_JSON carries no refresh_token (skipping)")
|
||||
return false
|
||||
}
|
||||
|
||||
// Skip-when-fresh: only refresh within the safety margin of expiry. A blob
|
||||
// with an unparseable/absent access_token exp is treated as DUE (better to
|
||||
// refresh a token we cannot date than let the fleet lapse).
|
||||
exp, haveExp := jwtExp(tokens.AccessToken)
|
||||
if haveExp {
|
||||
remaining := exp.Sub(r.now())
|
||||
if remaining > refreshSafetyMargin {
|
||||
// Fresh — nothing to do. No POST.
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// DUE: POST the refresh_token ONCE.
|
||||
newTokens, perm, err := r.doRefresh(ctx, tokens.RefreshToken)
|
||||
if err != nil {
|
||||
if perm {
|
||||
// Permanent: the seed is burned. Latch it so we don't re-POST, log
|
||||
// ONCE, and DO NOT write anything back.
|
||||
log.Printf("codexauth: PERMANENT refresh failure (refresh_token rejected): %v — "+
|
||||
"NOT writing back; the shared CODEX_AUTH_JSON seed is burned and must be re-seeded "+
|
||||
"via a fresh codex login. Backing off.", err)
|
||||
r.failedSeed = blob
|
||||
return true
|
||||
}
|
||||
// Transient (network/5xx): no write-back, retry next cycle (no backoff).
|
||||
log.Printf("codexauth: transient refresh error: %v (will retry next cycle)", err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Success: merge the rotated trio into the blob (preserving every other
|
||||
// field) and write it back encrypted, bumping updated_at (last_refresh).
|
||||
rotated, err := mergeTokens(blob, newTokens)
|
||||
if err != nil {
|
||||
log.Printf("codexauth: failed to merge rotated tokens into auth json: %v (NOT writing back)", err)
|
||||
return false
|
||||
}
|
||||
if err := r.store.Put(ctx, CodexAuthSecretKey, rotated); err != nil {
|
||||
log.Printf("codexauth: write-back of rotated CODEX_AUTH_JSON failed: %v", err)
|
||||
return false
|
||||
}
|
||||
r.failedSeed = "" // success clears any stale latch
|
||||
log.Printf("codexauth: rotated shared CODEX_AUTH_JSON (single-owner refresh)")
|
||||
return false
|
||||
}
|
||||
|
||||
// doRefresh POSTs the refresh_token to OpenAI's OAuth endpoint exactly once and
|
||||
// returns the rotated trio. permanent=true marks an unrecoverable rejection
|
||||
// (HTTP 400 invalid_grant / "refresh token already used") so the caller latches
|
||||
// and backs off instead of retrying.
|
||||
func (r *refresher) doRefresh(ctx context.Context, refreshToken string) (tokens oauthTokens, permanent bool, err error) {
|
||||
body, _ := json.Marshal(map[string]string{
|
||||
"grant_type": "refresh_token",
|
||||
"client_id": codexOAuthClientID,
|
||||
"refresh_token": refreshToken,
|
||||
})
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, oauthTokenURL, strings.NewReader(string(body)))
|
||||
if err != nil {
|
||||
return oauthTokens{}, false, err
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := r.client.Do(req)
|
||||
if err != nil {
|
||||
return oauthTokens{}, false, err // transient: network
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
respBody, _ := io.ReadAll(io.LimitReader(resp.Body, 1<<20))
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
var t oauthTokens
|
||||
if err := json.Unmarshal(respBody, &t); err != nil {
|
||||
return oauthTokens{}, false, fmt.Errorf("decode token response: %w", err)
|
||||
}
|
||||
if t.AccessToken == "" {
|
||||
return oauthTokens{}, false, fmt.Errorf("token response missing access_token")
|
||||
}
|
||||
return t, false, nil
|
||||
}
|
||||
|
||||
// Non-200. A 400 (and any body naming invalid_grant / already-used) is a
|
||||
// PERMANENT rejection of the refresh_token. 401/403 likewise mean the seed
|
||||
// is no good. Everything else (429/5xx/network-shaped) is transient.
|
||||
lowerBody := strings.ToLower(string(respBody))
|
||||
isInvalidGrant := strings.Contains(lowerBody, "invalid_grant") ||
|
||||
strings.Contains(lowerBody, "refresh token already used") ||
|
||||
strings.Contains(lowerBody, "already been used") ||
|
||||
strings.Contains(lowerBody, "token has been revoked")
|
||||
switch {
|
||||
case resp.StatusCode == http.StatusBadRequest && isInvalidGrant:
|
||||
return oauthTokens{}, true, fmt.Errorf("oauth %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
|
||||
case resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden:
|
||||
return oauthTokens{}, true, fmt.Errorf("oauth %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
|
||||
default:
|
||||
return oauthTokens{}, false, fmt.Errorf("oauth %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody)))
|
||||
}
|
||||
}
|
||||
|
||||
// parseTokens extracts the OAuth trio from an auth.json blob, accepting both
|
||||
// the nested `{"tokens":{...}}` shape the codex CLI writes and a flat top-level
|
||||
// shape some seeds use.
|
||||
func parseTokens(blob string) (oauthTokens, error) {
|
||||
var top map[string]json.RawMessage
|
||||
if err := json.Unmarshal([]byte(blob), &top); err != nil {
|
||||
return oauthTokens{}, err
|
||||
}
|
||||
if nested, ok := top["tokens"]; ok {
|
||||
var t oauthTokens
|
||||
if err := json.Unmarshal(nested, &t); err != nil {
|
||||
return oauthTokens{}, fmt.Errorf("decode nested tokens: %w", err)
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
var t oauthTokens
|
||||
if err := json.Unmarshal([]byte(blob), &t); err != nil {
|
||||
return oauthTokens{}, err
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// mergeTokens writes the rotated trio back into the original blob in-place,
|
||||
// preserving the blob's shape (nested-vs-flat) and every other field. A field
|
||||
// in the OAuth response that is empty (e.g. id_token omitted) does NOT clobber
|
||||
// the existing value.
|
||||
func mergeTokens(blob string, rotated oauthTokens) (string, error) {
|
||||
var top map[string]json.RawMessage
|
||||
if err := json.Unmarshal([]byte(blob), &top); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
applyTo := func(m map[string]json.RawMessage) error {
|
||||
setStr := func(key, val string) error {
|
||||
if val == "" {
|
||||
return nil // don't clobber an existing value with an empty one
|
||||
}
|
||||
b, err := json.Marshal(val)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m[key] = b
|
||||
return nil
|
||||
}
|
||||
if err := setStr("access_token", rotated.AccessToken); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := setStr("refresh_token", rotated.RefreshToken); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := setStr("id_token", rotated.IDToken); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
if nestedRaw, ok := top["tokens"]; ok {
|
||||
var nested map[string]json.RawMessage
|
||||
if err := json.Unmarshal(nestedRaw, &nested); err != nil {
|
||||
return "", fmt.Errorf("decode nested tokens for merge: %w", err)
|
||||
}
|
||||
if err := applyTo(nested); err != nil {
|
||||
return "", err
|
||||
}
|
||||
nb, err := json.Marshal(nested)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
top["tokens"] = nb
|
||||
} else {
|
||||
if err := applyTo(top); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
|
||||
out, err := json.Marshal(top)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
// jwtExp decodes the `exp` claim (Unix seconds) from a JWT access token WITHOUT
|
||||
// verifying the signature (we only need the expiry to decide due-ness; the
|
||||
// token's validity is OpenAI's to enforce). Returns ok=false when the token is
|
||||
// not a parseable 3-part JWT or carries no numeric exp.
|
||||
func jwtExp(token string) (time.Time, bool) {
|
||||
parts := strings.Split(token, ".")
|
||||
if len(parts) != 3 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
// Some encoders pad; tolerate standard base64url with padding too.
|
||||
payload, err = base64.URLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
}
|
||||
var claims struct {
|
||||
Exp json.Number `json:"exp"`
|
||||
}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
secs, err := claims.Exp.Int64()
|
||||
if err != nil || secs <= 0 {
|
||||
return time.Time{}, false
|
||||
}
|
||||
return time.Unix(secs, 0), true
|
||||
}
|
||||
|
||||
// postgresStore is the production SecretStore backed by global_secrets, using
|
||||
// the SAME crypto path the secrets handler uses (DecryptVersioned on read,
|
||||
// Encrypt + CurrentEncryptionVersion on write).
|
||||
type postgresStore struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func (s *postgresStore) Get(ctx context.Context, key string) (string, bool, error) {
|
||||
var enc []byte
|
||||
var ver int
|
||||
err := s.db.QueryRowContext(ctx,
|
||||
`SELECT encrypted_value, encryption_version FROM global_secrets WHERE key = $1`, key).
|
||||
Scan(&enc, &ver)
|
||||
if err == sql.ErrNoRows {
|
||||
return "", false, nil
|
||||
}
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
plain, err := crypto.DecryptVersioned(enc, ver)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
return string(plain), true, nil
|
||||
}
|
||||
|
||||
func (s *postgresStore) Put(ctx context.Context, key, value string) error {
|
||||
enc, err := crypto.Encrypt([]byte(value))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ver := crypto.CurrentEncryptionVersion()
|
||||
_, err = s.db.ExecContext(ctx, `
|
||||
INSERT INTO global_secrets (key, encrypted_value, encryption_version)
|
||||
VALUES ($1, $2, $3)
|
||||
ON CONFLICT (key) DO UPDATE
|
||||
SET encrypted_value = $2, encryption_version = $3, updated_at = now()
|
||||
`, key, enc, ver)
|
||||
return err
|
||||
}
|
||||
@@ -0,0 +1,425 @@
|
||||
package codexauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// --- test doubles -----------------------------------------------------------
|
||||
|
||||
// fakeStore is an in-memory SecretStore. nil entry = absent key.
|
||||
type fakeStore struct {
|
||||
mu sync.Mutex
|
||||
values map[string]string
|
||||
getErr error
|
||||
putErr error
|
||||
puts int32 // count of successful Put calls
|
||||
}
|
||||
|
||||
func newFakeStore() *fakeStore { return &fakeStore{values: map[string]string{}} }
|
||||
|
||||
func (f *fakeStore) Get(_ context.Context, key string) (string, bool, error) {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if f.getErr != nil {
|
||||
return "", false, f.getErr
|
||||
}
|
||||
v, ok := f.values[key]
|
||||
return v, ok, nil
|
||||
}
|
||||
|
||||
func (f *fakeStore) Put(_ context.Context, key, value string) error {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
if f.putErr != nil {
|
||||
return f.putErr
|
||||
}
|
||||
f.values[key] = value
|
||||
atomic.AddInt32(&f.puts, 1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (f *fakeStore) get(key string) string {
|
||||
f.mu.Lock()
|
||||
defer f.mu.Unlock()
|
||||
return f.values[key]
|
||||
}
|
||||
|
||||
// fakeTransport records every request and returns a scripted response. It is
|
||||
// the network seam — tests NEVER make a real request.
|
||||
type fakeTransport struct {
|
||||
mu sync.Mutex
|
||||
calls int32
|
||||
urls []string
|
||||
methods []string
|
||||
bodies []string
|
||||
status int
|
||||
respBody string
|
||||
transport func(*http.Request) (*http.Response, error) // optional override
|
||||
}
|
||||
|
||||
func (t *fakeTransport) Do(req *http.Request) (*http.Response, error) {
|
||||
atomic.AddInt32(&t.calls, 1)
|
||||
t.mu.Lock()
|
||||
t.urls = append(t.urls, req.URL.String())
|
||||
t.methods = append(t.methods, req.Method)
|
||||
if req.Body != nil {
|
||||
b, _ := io.ReadAll(req.Body)
|
||||
t.bodies = append(t.bodies, string(b))
|
||||
} else {
|
||||
t.bodies = append(t.bodies, "")
|
||||
}
|
||||
t.mu.Unlock()
|
||||
|
||||
if t.transport != nil {
|
||||
return t.transport(req)
|
||||
}
|
||||
status := t.status
|
||||
if status == 0 {
|
||||
status = http.StatusOK
|
||||
}
|
||||
return &http.Response{
|
||||
StatusCode: status,
|
||||
Body: io.NopCloser(strings.NewReader(t.respBody)),
|
||||
Header: make(http.Header),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (t *fakeTransport) callCount() int { return int(atomic.LoadInt32(&t.calls)) }
|
||||
|
||||
// --- helpers ----------------------------------------------------------------
|
||||
|
||||
// makeJWT builds an unsigned-but-parseable JWT whose payload carries exp.
|
||||
func makeJWT(exp time.Time) string {
|
||||
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"none","typ":"JWT"}`))
|
||||
payload := base64.RawURLEncoding.EncodeToString([]byte(
|
||||
fmt.Sprintf(`{"exp":%d,"sub":"codex"}`, exp.Unix())))
|
||||
sig := base64.RawURLEncoding.EncodeToString([]byte("sig"))
|
||||
return header + "." + payload + "." + sig
|
||||
}
|
||||
|
||||
// authBlob builds a nested codex auth.json blob with the given tokens.
|
||||
func authBlob(access, refresh string) string {
|
||||
b, _ := json.Marshal(map[string]any{
|
||||
"tokens": map[string]any{
|
||||
"access_token": access,
|
||||
"refresh_token": refresh,
|
||||
"id_token": "id-original",
|
||||
},
|
||||
"OPENAI_API_KEY": nil,
|
||||
"last_refresh": "2026-01-01T00:00:00Z",
|
||||
})
|
||||
return string(b)
|
||||
}
|
||||
|
||||
func newTestRefresher(store SecretStore, client httpDoer, now time.Time) *refresher {
|
||||
return &refresher{
|
||||
store: store,
|
||||
client: client,
|
||||
now: func() time.Time { return now },
|
||||
}
|
||||
}
|
||||
|
||||
func okRefreshResponse(access, refresh string) string {
|
||||
b, _ := json.Marshal(oauthTokens{AccessToken: access, RefreshToken: refresh, IDToken: "id-new"})
|
||||
return string(b)
|
||||
}
|
||||
|
||||
// --- tests ------------------------------------------------------------------
|
||||
|
||||
// TestJWTExpParse covers the exp decode (valid, malformed, missing).
|
||||
func TestJWTExpParse(t *testing.T) {
|
||||
want := time.Now().Add(2 * time.Hour).Truncate(time.Second)
|
||||
got, ok := jwtExp(makeJWT(want))
|
||||
if !ok {
|
||||
t.Fatalf("jwtExp(valid) ok=false, want true")
|
||||
}
|
||||
if !got.Equal(want) {
|
||||
t.Errorf("jwtExp = %v, want %v", got, want)
|
||||
}
|
||||
|
||||
if _, ok := jwtExp("not-a-jwt"); ok {
|
||||
t.Errorf("jwtExp(non-jwt) ok=true, want false")
|
||||
}
|
||||
if _, ok := jwtExp("a.b.c"); ok {
|
||||
t.Errorf("jwtExp(garbage parts) ok=true, want false")
|
||||
}
|
||||
// 3 parts but payload has no exp.
|
||||
noExp := base64.RawURLEncoding.EncodeToString([]byte("{}"))
|
||||
if _, ok := jwtExp("h." + noExp + ".s"); ok {
|
||||
t.Errorf("jwtExp(no exp claim) ok=true, want false")
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshOnce_SkipWhenFresh: a token well outside the safety margin is NOT
|
||||
// refreshed — no POST, no write-back.
|
||||
func TestRefreshOnce_SkipWhenFresh(t *testing.T) {
|
||||
now := time.Now()
|
||||
store := newFakeStore()
|
||||
store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(2*time.Hour)), "rt-1")
|
||||
tr := &fakeTransport{status: http.StatusOK, respBody: okRefreshResponse("new-at", "rt-2")}
|
||||
r := newTestRefresher(store, tr, now)
|
||||
|
||||
if perm := r.refreshOnce(context.Background()); perm {
|
||||
t.Fatalf("fresh token: permanentFailure=true, want false")
|
||||
}
|
||||
if tr.callCount() != 0 {
|
||||
t.Errorf("fresh token: %d OAuth POSTs, want 0", tr.callCount())
|
||||
}
|
||||
if atomic.LoadInt32(&store.puts) != 0 {
|
||||
t.Errorf("fresh token: %d write-backs, want 0", store.puts)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshOnce_RotateThenReskip: a token inside the margin is refreshed once
|
||||
// (POST + write-back of the rotated blob); a subsequent call on the now-fresh
|
||||
// rotated token skips (no second POST). Proves rotate→write-back→re-skip.
|
||||
func TestRefreshOnce_RotateThenReskip(t *testing.T) {
|
||||
now := time.Now()
|
||||
store := newFakeStore()
|
||||
// Expires in 5m — inside the 15m safety margin → DUE.
|
||||
store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(5*time.Minute)), "rt-1")
|
||||
// Rotated access token is fresh (2h out); rotated refresh is rt-2.
|
||||
tr := &fakeTransport{status: http.StatusOK, respBody: okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-2")}
|
||||
r := newTestRefresher(store, tr, now)
|
||||
|
||||
if perm := r.refreshOnce(context.Background()); perm {
|
||||
t.Fatalf("due token: permanentFailure=true, want false")
|
||||
}
|
||||
if tr.callCount() != 1 {
|
||||
t.Fatalf("due token: %d OAuth POSTs, want exactly 1", tr.callCount())
|
||||
}
|
||||
if atomic.LoadInt32(&store.puts) != 1 {
|
||||
t.Fatalf("due token: %d write-backs, want exactly 1", store.puts)
|
||||
}
|
||||
|
||||
// The written blob must carry the rotated refresh_token and preserve the
|
||||
// non-token field.
|
||||
rotated := store.get(CodexAuthSecretKey)
|
||||
tokens, err := parseTokens(rotated)
|
||||
if err != nil {
|
||||
t.Fatalf("parse rotated blob: %v", err)
|
||||
}
|
||||
if tokens.RefreshToken != "rt-2" {
|
||||
t.Errorf("rotated refresh_token = %q, want rt-2", tokens.RefreshToken)
|
||||
}
|
||||
if !strings.Contains(rotated, "last_refresh") {
|
||||
t.Errorf("rotated blob dropped the preserved last_refresh field: %s", rotated)
|
||||
}
|
||||
|
||||
// Second call: the rotated access token is fresh → skip, no new POST.
|
||||
if perm := r.refreshOnce(context.Background()); perm {
|
||||
t.Fatalf("re-skip: permanentFailure=true, want false")
|
||||
}
|
||||
if tr.callCount() != 1 {
|
||||
t.Errorf("re-skip: %d total OAuth POSTs, want still 1", tr.callCount())
|
||||
}
|
||||
if atomic.LoadInt32(&store.puts) != 1 {
|
||||
t.Errorf("re-skip: %d total write-backs, want still 1", store.puts)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshOnce_NoSecretInert: absent CODEX_AUTH_JSON → inert (no POST, no
|
||||
// write-back, no error/permanent).
|
||||
func TestRefreshOnce_NoSecretInert(t *testing.T) {
|
||||
store := newFakeStore() // empty
|
||||
tr := &fakeTransport{}
|
||||
r := newTestRefresher(store, tr, time.Now())
|
||||
|
||||
if perm := r.refreshOnce(context.Background()); perm {
|
||||
t.Fatalf("no secret: permanentFailure=true, want false")
|
||||
}
|
||||
if tr.callCount() != 0 {
|
||||
t.Errorf("no secret: %d POSTs, want 0", tr.callCount())
|
||||
}
|
||||
if atomic.LoadInt32(&store.puts) != 0 {
|
||||
t.Errorf("no secret: %d write-backs, want 0", store.puts)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshOnce_PermanentFailNoWriteNoStorm: a 400 invalid_grant must (a) not
|
||||
// write back, (b) return permanentFailure=true, and (c) NOT re-POST on the next
|
||||
// cycle for the same (burned) seed — the anti-storm latch.
|
||||
func TestRefreshOnce_PermanentFailNoWriteNoStorm(t *testing.T) {
|
||||
now := time.Now()
|
||||
store := newFakeStore()
|
||||
store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-burned")
|
||||
tr := &fakeTransport{
|
||||
status: http.StatusBadRequest,
|
||||
respBody: `{"error":"invalid_grant","error_description":"refresh token already used"}`,
|
||||
}
|
||||
r := newTestRefresher(store, tr, now)
|
||||
|
||||
perm := r.refreshOnce(context.Background())
|
||||
if !perm {
|
||||
t.Fatalf("invalid_grant: permanentFailure=false, want true")
|
||||
}
|
||||
if tr.callCount() != 1 {
|
||||
t.Fatalf("invalid_grant: %d POSTs, want exactly 1", tr.callCount())
|
||||
}
|
||||
if atomic.LoadInt32(&store.puts) != 0 {
|
||||
t.Fatalf("invalid_grant: %d write-backs, want 0 (must NOT persist a failed refresh)", store.puts)
|
||||
}
|
||||
|
||||
// Next cycle, SAME burned seed: must NOT re-POST (anti-storm latch).
|
||||
perm2 := r.refreshOnce(context.Background())
|
||||
if tr.callCount() != 1 {
|
||||
t.Errorf("anti-storm: re-POSTed a burned refresh_token (%d total POSTs, want still 1)", tr.callCount())
|
||||
}
|
||||
_ = perm2 // latched cycle returns false (already-known failure, nothing new)
|
||||
|
||||
// A RE-SEED (blob changes) clears the latch and allows a fresh attempt.
|
||||
store.mu.Lock()
|
||||
store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-freshly-seeded")
|
||||
store.mu.Unlock()
|
||||
tr.status = http.StatusOK
|
||||
tr.respBody = okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-rotated")
|
||||
if perm := r.refreshOnce(context.Background()); perm {
|
||||
t.Fatalf("post-reseed: permanentFailure=true, want false")
|
||||
}
|
||||
if tr.callCount() != 2 {
|
||||
t.Errorf("post-reseed: %d total POSTs, want 2 (latch should clear on re-seed)", tr.callCount())
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshOnce_TransientNoWriteNoLatch: a 5xx is transient — no write-back,
|
||||
// returns false (no hard backoff latch), and a later cycle retries.
|
||||
func TestRefreshOnce_TransientNoWriteNoLatch(t *testing.T) {
|
||||
now := time.Now()
|
||||
store := newFakeStore()
|
||||
store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-1")
|
||||
tr := &fakeTransport{status: http.StatusServiceUnavailable, respBody: "upstream down"}
|
||||
r := newTestRefresher(store, tr, now)
|
||||
|
||||
if perm := r.refreshOnce(context.Background()); perm {
|
||||
t.Fatalf("503: permanentFailure=true, want false (transient)")
|
||||
}
|
||||
if atomic.LoadInt32(&store.puts) != 0 {
|
||||
t.Errorf("503: %d write-backs, want 0", store.puts)
|
||||
}
|
||||
// Retry next cycle succeeds (no latch on transient).
|
||||
tr.status = http.StatusOK
|
||||
tr.respBody = okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-2")
|
||||
if perm := r.refreshOnce(context.Background()); perm {
|
||||
t.Fatalf("retry after 503: permanentFailure=true, want false")
|
||||
}
|
||||
if tr.callCount() != 2 {
|
||||
t.Errorf("transient retry: %d total POSTs, want 2", tr.callCount())
|
||||
}
|
||||
if atomic.LoadInt32(&store.puts) != 1 {
|
||||
t.Errorf("transient retry: %d write-backs, want 1", store.puts)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshOnce_SingleFlight: concurrent refreshOnce calls on a DUE token must
|
||||
// POST exactly once total — the package mutex serializes them and the second
|
||||
// sees the freshly-rotated (now-fresh) token and skips. Structural single-flight.
|
||||
func TestRefreshOnce_SingleFlight(t *testing.T) {
|
||||
now := time.Now()
|
||||
store := newFakeStore()
|
||||
store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-1")
|
||||
// Every successful rotation yields a FRESH (2h) access token, so once one
|
||||
// caller rotates, the other sees fresh and skips.
|
||||
tr := &fakeTransport{status: http.StatusOK, respBody: okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-2")}
|
||||
r := newTestRefresher(store, tr, now)
|
||||
|
||||
const n = 16
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(n)
|
||||
for i := 0; i < n; i++ {
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
r.refreshOnce(context.Background())
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
if tr.callCount() != 1 {
|
||||
t.Errorf("single-flight: %d OAuth POSTs across %d concurrent calls, want exactly 1", tr.callCount(), n)
|
||||
}
|
||||
if atomic.LoadInt32(&store.puts) != 1 {
|
||||
t.Errorf("single-flight: %d write-backs, want exactly 1", store.puts)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshOnce_PostsExactlyOnceToOAuthEndpoint: when it DOES refresh, the
|
||||
// single POST goes to the OAuth token URL with the refresh_token grant body.
|
||||
func TestRefreshOnce_PostsExactlyOnceToOAuthEndpoint(t *testing.T) {
|
||||
now := time.Now()
|
||||
store := newFakeStore()
|
||||
store.values[CodexAuthSecretKey] = authBlob(makeJWT(now.Add(1*time.Minute)), "rt-secret")
|
||||
tr := &fakeTransport{status: http.StatusOK, respBody: okRefreshResponse(makeJWT(now.Add(2*time.Hour)), "rt-2")}
|
||||
r := newTestRefresher(store, tr, now)
|
||||
|
||||
r.refreshOnce(context.Background())
|
||||
|
||||
if tr.callCount() != 1 {
|
||||
t.Fatalf("%d POSTs, want exactly 1", tr.callCount())
|
||||
}
|
||||
if tr.urls[0] != oauthTokenURL {
|
||||
t.Errorf("POST URL = %q, want %q", tr.urls[0], oauthTokenURL)
|
||||
}
|
||||
if tr.methods[0] != http.MethodPost {
|
||||
t.Errorf("method = %q, want POST", tr.methods[0])
|
||||
}
|
||||
var body map[string]string
|
||||
if err := json.Unmarshal([]byte(tr.bodies[0]), &body); err != nil {
|
||||
t.Fatalf("request body not json: %v (%s)", err, tr.bodies[0])
|
||||
}
|
||||
if body["grant_type"] != "refresh_token" {
|
||||
t.Errorf("grant_type = %q, want refresh_token", body["grant_type"])
|
||||
}
|
||||
if body["refresh_token"] != "rt-secret" {
|
||||
t.Errorf("refresh_token = %q, want rt-secret", body["refresh_token"])
|
||||
}
|
||||
if body["client_id"] != codexOAuthClientID {
|
||||
t.Errorf("client_id = %q, want %q", body["client_id"], codexOAuthClientID)
|
||||
}
|
||||
}
|
||||
|
||||
// TestRefreshOnce_ReadErrorSkips: a store read error is a transient skip (no
|
||||
// POST, no permanent latch).
|
||||
func TestRefreshOnce_ReadErrorSkips(t *testing.T) {
|
||||
store := newFakeStore()
|
||||
store.getErr = fmt.Errorf("db down")
|
||||
tr := &fakeTransport{}
|
||||
r := newTestRefresher(store, tr, time.Now())
|
||||
if perm := r.refreshOnce(context.Background()); perm {
|
||||
t.Errorf("read error: permanentFailure=true, want false")
|
||||
}
|
||||
if tr.callCount() != 0 {
|
||||
t.Errorf("read error: %d POSTs, want 0", tr.callCount())
|
||||
}
|
||||
}
|
||||
|
||||
// TestMergeTokens_PreservesOtherFields proves the rotated write-back keeps every
|
||||
// non-token field and does not clobber id_token with an empty rotated value.
|
||||
func TestMergeTokens_PreservesOtherFields(t *testing.T) {
|
||||
blob := authBlob("old-at", "old-rt")
|
||||
out, err := mergeTokens(blob, oauthTokens{AccessToken: "new-at", RefreshToken: "new-rt"}) // no id_token
|
||||
if err != nil {
|
||||
t.Fatalf("mergeTokens: %v", err)
|
||||
}
|
||||
tokens, err := parseTokens(out)
|
||||
if err != nil {
|
||||
t.Fatalf("parse merged: %v", err)
|
||||
}
|
||||
if tokens.AccessToken != "new-at" || tokens.RefreshToken != "new-rt" {
|
||||
t.Errorf("merged tokens = %+v, want new-at/new-rt", tokens)
|
||||
}
|
||||
if tokens.IDToken != "id-original" {
|
||||
t.Errorf("empty rotated id_token clobbered the original: got %q, want id-original", tokens.IDToken)
|
||||
}
|
||||
if !strings.Contains(out, "last_refresh") {
|
||||
t.Errorf("merge dropped preserved field: %s", out)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user