feat(auth): add skip persistence context key for file watcher events

Introduce `WithSkipPersist` to disable persistence during Manager Update/Register calls, preventing write-back loops caused by redundant file writes. Add corresponding tests and integrate with existing file store and conductor logic.
This commit is contained in:
Luis Pater
2026-01-26 18:20:19 +08:00
parent 2af4a8dc12
commit 9c341f5aa5
5 changed files with 92 additions and 40 deletions

View File

@@ -73,9 +73,7 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
return "", fmt.Errorf("auth filestore: marshal metadata failed: %w", errMarshal) return "", fmt.Errorf("auth filestore: marshal metadata failed: %w", errMarshal)
} }
if existing, errRead := os.ReadFile(path); errRead == nil { if existing, errRead := os.ReadFile(path); errRead == nil {
// Use metadataEqualIgnoringTimestamps to skip writes when only timestamp fields change. if jsonEqual(existing, raw) {
// This prevents the token refresh loop caused by timestamp/expired/expires_in changes.
if metadataEqualIgnoringTimestamps(existing, raw, auth.Provider) {
return path, nil return path, nil
} }
file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600) file, errOpen := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC, 0o600)
@@ -299,8 +297,7 @@ func (s *FileTokenStore) baseDirSnapshot() string {
return s.baseDir return s.baseDir
} }
// DEPRECATED: Use metadataEqualIgnoringTimestamps for comparing auth metadata. // jsonEqual compares two JSON blobs by parsing them into Go objects and deep comparing.
// This function is kept for backward compatibility but can cause refresh loops.
func jsonEqual(a, b []byte) bool { func jsonEqual(a, b []byte) bool {
var objA any var objA any
var objB any var objB any
@@ -313,41 +310,6 @@ func jsonEqual(a, b []byte) bool {
return deepEqualJSON(objA, objB) return deepEqualJSON(objA, objB)
} }
// metadataEqualIgnoringTimestamps compares two metadata JSON blobs,
// ignoring fields that change on every refresh but don't affect functionality.
// This prevents unnecessary file writes that would trigger watcher events and
// create refresh loops.
// The provider parameter controls whether access_token is ignored: providers like
// Google OAuth (gemini, gemini-cli) can re-fetch tokens when needed, while others
// like iFlow require the refreshed token to be persisted.
func metadataEqualIgnoringTimestamps(a, b []byte, provider string) bool {
var objA, objB map[string]any
if err := json.Unmarshal(a, &objA); err != nil {
return false
}
if err := json.Unmarshal(b, &objB); err != nil {
return false
}
// Fields to ignore: these change on every refresh but don't affect authentication logic.
// - timestamp, expired, expires_in, last_refresh: time-related fields that change on refresh
ignoredFields := []string{"timestamp", "expired", "expires_in", "last_refresh"}
// For providers that can re-fetch tokens when needed (e.g., Google OAuth),
// we ignore access_token to avoid unnecessary file writes.
switch provider {
case "gemini", "gemini-cli", "antigravity":
ignoredFields = append(ignoredFields, "access_token")
}
for _, field := range ignoredFields {
delete(objA, field)
delete(objB, field)
}
return deepEqualJSON(objA, objB)
}
func deepEqualJSON(a, b any) bool { func deepEqualJSON(a, b any) bool {
switch valA := a.(type) { switch valA := a.(type) {
case map[string]any: case map[string]any:

View File

@@ -1642,6 +1642,9 @@ func (m *Manager) persist(ctx context.Context, auth *Auth) error {
if m.store == nil || auth == nil { if m.store == nil || auth == nil {
return nil return nil
} }
if shouldSkipPersist(ctx) {
return nil
}
if auth.Attributes != nil { if auth.Attributes != nil {
if v := strings.ToLower(strings.TrimSpace(auth.Attributes["runtime_only"])); v == "true" { if v := strings.ToLower(strings.TrimSpace(auth.Attributes["runtime_only"])); v == "true" {
return nil return nil

View File

@@ -0,0 +1,24 @@
package auth
import "context"
type skipPersistContextKey struct{}
// WithSkipPersist returns a derived context that disables persistence for Manager Update/Register calls.
// It is intended for code paths that are reacting to file watcher events, where the file on disk is
// already the source of truth and persisting again would create a write-back loop.
func WithSkipPersist(ctx context.Context) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, skipPersistContextKey{}, true)
}
func shouldSkipPersist(ctx context.Context) bool {
if ctx == nil {
return false
}
v := ctx.Value(skipPersistContextKey{})
enabled, ok := v.(bool)
return ok && enabled
}

View File

@@ -0,0 +1,62 @@
package auth
import (
"context"
"sync/atomic"
"testing"
)
type countingStore struct {
saveCount atomic.Int32
}
func (s *countingStore) List(context.Context) ([]*Auth, error) { return nil, nil }
func (s *countingStore) Save(context.Context, *Auth) (string, error) {
s.saveCount.Add(1)
return "", nil
}
func (s *countingStore) Delete(context.Context, string) error { return nil }
func TestWithSkipPersist_DisablesUpdatePersistence(t *testing.T) {
store := &countingStore{}
mgr := NewManager(store, nil, nil)
auth := &Auth{
ID: "auth-1",
Provider: "antigravity",
Metadata: map[string]any{"type": "antigravity"},
}
if _, err := mgr.Update(context.Background(), auth); err != nil {
t.Fatalf("Update returned error: %v", err)
}
if got := store.saveCount.Load(); got != 1 {
t.Fatalf("expected 1 Save call, got %d", got)
}
ctxSkip := WithSkipPersist(context.Background())
if _, err := mgr.Update(ctxSkip, auth); err != nil {
t.Fatalf("Update(skipPersist) returned error: %v", err)
}
if got := store.saveCount.Load(); got != 1 {
t.Fatalf("expected Save call count to remain 1, got %d", got)
}
}
func TestWithSkipPersist_DisablesRegisterPersistence(t *testing.T) {
store := &countingStore{}
mgr := NewManager(store, nil, nil)
auth := &Auth{
ID: "auth-1",
Provider: "antigravity",
Metadata: map[string]any{"type": "antigravity"},
}
if _, err := mgr.Register(WithSkipPersist(context.Background()), auth); err != nil {
t.Fatalf("Register(skipPersist) returned error: %v", err)
}
if got := store.saveCount.Load(); got != 0 {
t.Fatalf("expected 0 Save calls, got %d", got)
}
}

View File

@@ -124,6 +124,7 @@ func (s *Service) ensureAuthUpdateQueue(ctx context.Context) {
} }
func (s *Service) consumeAuthUpdates(ctx context.Context) { func (s *Service) consumeAuthUpdates(ctx context.Context) {
ctx = coreauth.WithSkipPersist(ctx)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():