diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index bf5a5b9c..41a4fde4 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -197,6 +197,19 @@ func stopCallbackForwarder(port int) { stopForwarderInstance(port, forwarder) } +func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) { + if forwarder == nil { + return + } + callbackForwardersMu.Lock() + if current := callbackForwarders[port]; current == forwarder { + delete(callbackForwarders, port) + } + callbackForwardersMu.Unlock() + + stopForwarderInstance(port, forwarder) +} + func stopForwarderInstance(port int, forwarder *callbackForwarder) { if forwarder == nil || forwarder.server == nil { return @@ -785,6 +798,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { RegisterOAuthSession(state, "anthropic") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/anthropic/callback") if errTarget != nil { @@ -792,7 +806,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start anthropic callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -801,7 +816,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(anthropicCallbackPort) + defer stopCallbackForwarderInstance(anthropicCallbackPort, forwarder) } // Helper: wait for callback file @@ -809,6 +824,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { waitForFile := func(path string, timeout time.Duration) (map[string]string, error) { deadline := time.Now().Add(timeout) for { + if !IsOAuthSessionPending(state, "anthropic") { + return nil, errOAuthSessionNotPending + } if time.Now().After(deadline) { SetOAuthSessionError(state, "Timeout waiting for OAuth callback") return nil, fmt.Errorf("timeout waiting for OAuth callback") @@ -828,6 +846,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { // Wait up to 5 minutes resultMap, errWait := waitForFile(waitFile, 5*time.Minute) if errWait != nil { + if errors.Is(errWait, errOAuthSessionNotPending) { + return + } authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait) log.Error(claude.GetUserFriendlyMessage(authErr)) return @@ -933,6 +954,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } fmt.Println("You can now use Claude services through this CLI") CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("anthropic") }() c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) @@ -968,6 +990,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { RegisterOAuthSession(state, "gemini") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/google/callback") if errTarget != nil { @@ -975,7 +998,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start gemini callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -984,7 +1008,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(geminiCallbackPort) + defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder) } // Wait for callback file written by server route @@ -993,6 +1017,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { deadline := time.Now().Add(5 * time.Minute) var authCode string for { + if !IsOAuthSessionPending(state, "gemini") { + return + } if time.Now().After(deadline) { log.Error("oauth flow timed out") SetOAuthSessionError(state, "OAuth flow timed out") @@ -1093,7 +1120,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { // Initialize authenticated HTTP client via GeminiAuth to honor proxy settings gemAuth := geminiAuth.NewGeminiAuth() - gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) + gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, &geminiAuth.WebLoginOptions{ + NoBrowser: true, + }) if errGetClient != nil { log.Errorf("failed to get authenticated client: %v", errGetClient) SetOAuthSessionError(state, "Failed to get authenticated client") @@ -1166,6 +1195,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("gemini") fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath) }() @@ -1207,6 +1237,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { RegisterOAuthSession(state, "codex") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/codex/callback") if errTarget != nil { @@ -1214,7 +1245,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start codex callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -1223,7 +1255,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(codexCallbackPort) + defer stopCallbackForwarderInstance(codexCallbackPort, forwarder) } // Wait for callback file @@ -1231,6 +1263,9 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { deadline := time.Now().Add(5 * time.Minute) var code string for { + if !IsOAuthSessionPending(state, "codex") { + return + } if time.Now().After(deadline) { authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback")) log.Error(codex.GetUserFriendlyMessage(authErr)) @@ -1346,6 +1381,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } fmt.Println("You can now use Codex services through this CLI") CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("codex") }() c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) @@ -1391,6 +1427,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { RegisterOAuthSession(state, "antigravity") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/antigravity/callback") if errTarget != nil { @@ -1398,7 +1435,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start antigravity callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -1407,13 +1445,16 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(antigravityCallbackPort) + defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder) } waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state)) deadline := time.Now().Add(5 * time.Minute) var authCode string for { + if !IsOAuthSessionPending(state, "antigravity") { + return + } if time.Now().After(deadline) { log.Error("oauth flow timed out") SetOAuthSessionError(state, "OAuth flow timed out") @@ -1576,6 +1617,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("antigravity") fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) if projectID != "" { fmt.Printf("Using GCP project: %s\n", projectID) @@ -1653,6 +1695,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { RegisterOAuthSession(state, "iflow") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/iflow/callback") if errTarget != nil { @@ -1660,7 +1703,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start iflow callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"}) return @@ -1669,7 +1713,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(iflowauth.CallbackPort) + defer stopCallbackForwarderInstance(iflowauth.CallbackPort, forwarder) } fmt.Println("Waiting for authentication...") @@ -1677,6 +1721,9 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { deadline := time.Now().Add(5 * time.Minute) var resultMap map[string]string for { + if !IsOAuthSessionPending(state, "iflow") { + return + } if time.Now().After(deadline) { SetOAuthSessionError(state, "Authentication failed") fmt.Println("Authentication failed: timeout waiting for callback") @@ -1743,6 +1790,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { } fmt.Println("You can now use iFlow services through this CLI") CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("iflow") }() c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index a0d0b169..7e42b64b 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -145,71 +145,74 @@ func (h *Handler) PutGeminiKeys(c *gin.Context) { h.persist(c) } func (h *Handler) PatchGeminiKey(c *gin.Context) { + type geminiKeyPatch struct { + APIKey *string `json:"api-key"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + ProxyURL *string `json:"proxy-url"` + Headers *map[string]string `json:"headers"` + ExcludedModels *[]string `json:"excluded-models"` + } var body struct { - Index *int `json:"index"` - Match *string `json:"match"` - Value *config.GeminiKey `json:"value"` + Index *int `json:"index"` + Match *string `json:"match"` + Value *geminiKeyPatch `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { c.JSON(400, gin.H{"error": "invalid body"}) return } - value := *body.Value - value.APIKey = strings.TrimSpace(value.APIKey) - value.BaseURL = strings.TrimSpace(value.BaseURL) - value.ProxyURL = strings.TrimSpace(value.ProxyURL) - value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels) - if value.APIKey == "" { - // Treat empty API key as delete. - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { - h.cfg.GeminiKey = append(h.cfg.GeminiKey[:*body.Index], h.cfg.GeminiKey[*body.Index+1:]...) - h.cfg.SanitizeGeminiKeys() - h.persist(c) - return - } - if body.Match != nil { - match := strings.TrimSpace(*body.Match) - if match != "" { - out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey)) - removed := false - for i := range h.cfg.GeminiKey { - if !removed && h.cfg.GeminiKey[i].APIKey == match { - removed = true - continue - } - out = append(out, h.cfg.GeminiKey[i]) - } - if removed { - h.cfg.GeminiKey = out - h.cfg.SanitizeGeminiKeys() - h.persist(c) - return + targetIndex := -1 + if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { + targetIndex = *body.Index + } + if targetIndex == -1 && body.Match != nil { + match := strings.TrimSpace(*body.Match) + if match != "" { + for i := range h.cfg.GeminiKey { + if h.cfg.GeminiKey[i].APIKey == match { + targetIndex = i + break } } } + } + if targetIndex == -1 { c.JSON(404, gin.H{"error": "item not found"}) return } - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { - h.cfg.GeminiKey[*body.Index] = value - h.cfg.SanitizeGeminiKeys() - h.persist(c) - return - } - if body.Match != nil { - match := strings.TrimSpace(*body.Match) - for i := range h.cfg.GeminiKey { - if h.cfg.GeminiKey[i].APIKey == match { - h.cfg.GeminiKey[i] = value - h.cfg.SanitizeGeminiKeys() - h.persist(c) - return - } + entry := h.cfg.GeminiKey[targetIndex] + if body.Value.APIKey != nil { + trimmed := strings.TrimSpace(*body.Value.APIKey) + if trimmed == "" { + h.cfg.GeminiKey = append(h.cfg.GeminiKey[:targetIndex], h.cfg.GeminiKey[targetIndex+1:]...) + h.cfg.SanitizeGeminiKeys() + h.persist(c) + return } + entry.APIKey = trimmed } - c.JSON(404, gin.H{"error": "item not found"}) + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL) + } + if body.Value.ProxyURL != nil { + entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + if body.Value.ExcludedModels != nil { + entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) + } + h.cfg.GeminiKey[targetIndex] = entry + h.cfg.SanitizeGeminiKeys() + h.persist(c) } + func (h *Handler) DeleteGeminiKey(c *gin.Context) { if val := strings.TrimSpace(c.Query("api-key")); val != "" { out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey)) @@ -268,35 +271,70 @@ func (h *Handler) PutClaudeKeys(c *gin.Context) { h.persist(c) } func (h *Handler) PatchClaudeKey(c *gin.Context) { + type claudeKeyPatch struct { + APIKey *string `json:"api-key"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + ProxyURL *string `json:"proxy-url"` + Models *[]config.ClaudeModel `json:"models"` + Headers *map[string]string `json:"headers"` + ExcludedModels *[]string `json:"excluded-models"` + } var body struct { - Index *int `json:"index"` - Match *string `json:"match"` - Value *config.ClaudeKey `json:"value"` + Index *int `json:"index"` + Match *string `json:"match"` + Value *claudeKeyPatch `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { c.JSON(400, gin.H{"error": "invalid body"}) return } - value := *body.Value - normalizeClaudeKey(&value) + targetIndex := -1 if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) { - h.cfg.ClaudeKey[*body.Index] = value - h.cfg.SanitizeClaudeKeys() - h.persist(c) - return + targetIndex = *body.Index } - if body.Match != nil { + if targetIndex == -1 && body.Match != nil { + match := strings.TrimSpace(*body.Match) for i := range h.cfg.ClaudeKey { - if h.cfg.ClaudeKey[i].APIKey == *body.Match { - h.cfg.ClaudeKey[i] = value - h.cfg.SanitizeClaudeKeys() - h.persist(c) - return + if h.cfg.ClaudeKey[i].APIKey == match { + targetIndex = i + break } } } - c.JSON(404, gin.H{"error": "item not found"}) + if targetIndex == -1 { + c.JSON(404, gin.H{"error": "item not found"}) + return + } + + entry := h.cfg.ClaudeKey[targetIndex] + if body.Value.APIKey != nil { + entry.APIKey = strings.TrimSpace(*body.Value.APIKey) + } + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL) + } + if body.Value.ProxyURL != nil { + entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) + } + if body.Value.Models != nil { + entry.Models = append([]config.ClaudeModel(nil), (*body.Value.Models)...) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + if body.Value.ExcludedModels != nil { + entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) + } + normalizeClaudeKey(&entry) + h.cfg.ClaudeKey[targetIndex] = entry + h.cfg.SanitizeClaudeKeys() + h.persist(c) } + func (h *Handler) DeleteClaudeKey(c *gin.Context) { if val := c.Query("api-key"); val != "" { out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey)) @@ -356,62 +394,73 @@ func (h *Handler) PutOpenAICompat(c *gin.Context) { h.persist(c) } func (h *Handler) PatchOpenAICompat(c *gin.Context) { + type openAICompatPatch struct { + Name *string `json:"name"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + APIKeyEntries *[]config.OpenAICompatibilityAPIKey `json:"api-key-entries"` + Models *[]config.OpenAICompatibilityModel `json:"models"` + Headers *map[string]string `json:"headers"` + } var body struct { - Name *string `json:"name"` - Index *int `json:"index"` - Value *config.OpenAICompatibility `json:"value"` + Name *string `json:"name"` + Index *int `json:"index"` + Value *openAICompatPatch `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { c.JSON(400, gin.H{"error": "invalid body"}) return } - normalizeOpenAICompatibilityEntry(body.Value) - // If base-url becomes empty, delete the provider instead of updating - if strings.TrimSpace(body.Value.BaseURL) == "" { - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { - h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:*body.Index], h.cfg.OpenAICompatibility[*body.Index+1:]...) + targetIndex := -1 + if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { + targetIndex = *body.Index + } + if targetIndex == -1 && body.Name != nil { + match := strings.TrimSpace(*body.Name) + for i := range h.cfg.OpenAICompatibility { + if h.cfg.OpenAICompatibility[i].Name == match { + targetIndex = i + break + } + } + } + if targetIndex == -1 { + c.JSON(404, gin.H{"error": "item not found"}) + return + } + + entry := h.cfg.OpenAICompatibility[targetIndex] + if body.Value.Name != nil { + entry.Name = strings.TrimSpace(*body.Value.Name) + } + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + trimmed := strings.TrimSpace(*body.Value.BaseURL) + if trimmed == "" { + h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:targetIndex], h.cfg.OpenAICompatibility[targetIndex+1:]...) h.cfg.SanitizeOpenAICompatibility() h.persist(c) return } - if body.Name != nil { - out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility)) - removed := false - for i := range h.cfg.OpenAICompatibility { - if !removed && h.cfg.OpenAICompatibility[i].Name == *body.Name { - removed = true - continue - } - out = append(out, h.cfg.OpenAICompatibility[i]) - } - if removed { - h.cfg.OpenAICompatibility = out - h.cfg.SanitizeOpenAICompatibility() - h.persist(c) - return - } - } - c.JSON(404, gin.H{"error": "item not found"}) - return + entry.BaseURL = trimmed } - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { - h.cfg.OpenAICompatibility[*body.Index] = *body.Value - h.cfg.SanitizeOpenAICompatibility() - h.persist(c) - return + if body.Value.APIKeyEntries != nil { + entry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), (*body.Value.APIKeyEntries)...) } - if body.Name != nil { - for i := range h.cfg.OpenAICompatibility { - if h.cfg.OpenAICompatibility[i].Name == *body.Name { - h.cfg.OpenAICompatibility[i] = *body.Value - h.cfg.SanitizeOpenAICompatibility() - h.persist(c) - return - } - } + if body.Value.Models != nil { + entry.Models = append([]config.OpenAICompatibilityModel(nil), (*body.Value.Models)...) } - c.JSON(404, gin.H{"error": "item not found"}) + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + normalizeOpenAICompatibilityEntry(&entry) + h.cfg.OpenAICompatibility[targetIndex] = entry + h.cfg.SanitizeOpenAICompatibility() + h.persist(c) } + func (h *Handler) DeleteOpenAICompat(c *gin.Context) { if name := c.Query("name"); name != "" { out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility)) @@ -563,66 +612,72 @@ func (h *Handler) PutCodexKeys(c *gin.Context) { h.persist(c) } func (h *Handler) PatchCodexKey(c *gin.Context) { + type codexKeyPatch struct { + APIKey *string `json:"api-key"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + ProxyURL *string `json:"proxy-url"` + Headers *map[string]string `json:"headers"` + ExcludedModels *[]string `json:"excluded-models"` + } var body struct { - Index *int `json:"index"` - Match *string `json:"match"` - Value *config.CodexKey `json:"value"` + Index *int `json:"index"` + Match *string `json:"match"` + Value *codexKeyPatch `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { c.JSON(400, gin.H{"error": "invalid body"}) return } - value := *body.Value - value.APIKey = strings.TrimSpace(value.APIKey) - value.BaseURL = strings.TrimSpace(value.BaseURL) - value.ProxyURL = strings.TrimSpace(value.ProxyURL) - value.Headers = config.NormalizeHeaders(value.Headers) - value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels) - // If base-url becomes empty, delete instead of update - if value.BaseURL == "" { - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { - h.cfg.CodexKey = append(h.cfg.CodexKey[:*body.Index], h.cfg.CodexKey[*body.Index+1:]...) - h.cfg.SanitizeCodexKeys() - h.persist(c) - return - } - if body.Match != nil { - out := make([]config.CodexKey, 0, len(h.cfg.CodexKey)) - removed := false - for i := range h.cfg.CodexKey { - if !removed && h.cfg.CodexKey[i].APIKey == *body.Match { - removed = true - continue - } - out = append(out, h.cfg.CodexKey[i]) - } - if removed { - h.cfg.CodexKey = out - h.cfg.SanitizeCodexKeys() - h.persist(c) - return - } - } - } else { - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { - h.cfg.CodexKey[*body.Index] = value - h.cfg.SanitizeCodexKeys() - h.persist(c) - return - } - if body.Match != nil { - for i := range h.cfg.CodexKey { - if h.cfg.CodexKey[i].APIKey == *body.Match { - h.cfg.CodexKey[i] = value - h.cfg.SanitizeCodexKeys() - h.persist(c) - return - } + targetIndex := -1 + if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { + targetIndex = *body.Index + } + if targetIndex == -1 && body.Match != nil { + match := strings.TrimSpace(*body.Match) + for i := range h.cfg.CodexKey { + if h.cfg.CodexKey[i].APIKey == match { + targetIndex = i + break } } } - c.JSON(404, gin.H{"error": "item not found"}) + if targetIndex == -1 { + c.JSON(404, gin.H{"error": "item not found"}) + return + } + + entry := h.cfg.CodexKey[targetIndex] + if body.Value.APIKey != nil { + entry.APIKey = strings.TrimSpace(*body.Value.APIKey) + } + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + trimmed := strings.TrimSpace(*body.Value.BaseURL) + if trimmed == "" { + h.cfg.CodexKey = append(h.cfg.CodexKey[:targetIndex], h.cfg.CodexKey[targetIndex+1:]...) + h.cfg.SanitizeCodexKeys() + h.persist(c) + return + } + entry.BaseURL = trimmed + } + if body.Value.ProxyURL != nil { + entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + if body.Value.ExcludedModels != nil { + entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) + } + h.cfg.CodexKey[targetIndex] = entry + h.cfg.SanitizeCodexKeys() + h.persist(c) } + func (h *Handler) DeleteCodexKey(c *gin.Context) { if val := c.Query("api-key"); val != "" { out := make([]config.CodexKey, 0, len(h.cfg.CodexKey)) diff --git a/internal/api/handlers/management/oauth_sessions.go b/internal/api/handlers/management/oauth_sessions.go index f23b608c..05ff8d1f 100644 --- a/internal/api/handlers/management/oauth_sessions.go +++ b/internal/api/handlers/management/oauth_sessions.go @@ -111,6 +111,27 @@ func (s *oauthSessionStore) Complete(state string) { delete(s.sessions, state) } +func (s *oauthSessionStore) CompleteProvider(provider string) int { + provider = strings.ToLower(strings.TrimSpace(provider)) + if provider == "" { + return 0 + } + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + removed := 0 + for state, session := range s.sessions { + if strings.EqualFold(session.Provider, provider) { + delete(s.sessions, state) + removed++ + } + } + return removed +} + func (s *oauthSessionStore) Get(state string) (oauthSession, bool) { state = strings.TrimSpace(state) now := time.Now() @@ -153,6 +174,10 @@ func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, func CompleteOAuthSession(state string) { oauthSessions.Complete(state) } +func CompleteOAuthSessionsByProvider(provider string) int { + return oauthSessions.CompleteProvider(provider) +} + func GetOAuthSession(state string) (provider string, status string, ok bool) { session, ok := oauthSessions.Get(state) if !ok { diff --git a/internal/auth/gemini/gemini_auth.go b/internal/auth/gemini/gemini_auth.go index f173c95f..7b18e738 100644 --- a/internal/auth/gemini/gemini_auth.go +++ b/internal/auth/gemini/gemini_auth.go @@ -18,6 +18,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -46,6 +47,12 @@ var ( type GeminiAuth struct { } +// WebLoginOptions customizes the interactive OAuth flow. +type WebLoginOptions struct { + NoBrowser bool + Prompt func(string) (string, error) +} + // NewGeminiAuth creates a new instance of GeminiAuth. func NewGeminiAuth() *GeminiAuth { return &GeminiAuth{} @@ -59,12 +66,12 @@ func NewGeminiAuth() *GeminiAuth { // - ctx: The context for the HTTP client // - ts: The Gemini token storage containing authentication tokens // - cfg: The configuration containing proxy settings -// - noBrowser: Optional parameter to disable browser opening +// - opts: Optional parameters to customize browser and prompt behavior // // Returns: // - *http.Client: An HTTP client configured with authentication // - error: An error if the client configuration fails, nil otherwise -func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, noBrowser ...bool) (*http.Client, error) { +func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) { // Configure proxy settings for the HTTP client if a proxy URL is provided. proxyURL, err := url.Parse(cfg.ProxyURL) if err == nil { @@ -109,7 +116,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken // If no token is found in storage, initiate the web-based OAuth flow. if ts.Token == nil { fmt.Printf("Could not load token from file, starting OAuth flow.\n") - token, err = g.getTokenFromWeb(ctx, conf, noBrowser...) + token, err = g.getTokenFromWeb(ctx, conf, opts) if err != nil { return nil, fmt.Errorf("failed to get token from web: %w", err) } @@ -205,15 +212,15 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf // Parameters: // - ctx: The context for the HTTP client // - config: The OAuth2 configuration -// - noBrowser: Optional parameter to disable browser opening +// - opts: Optional parameters to customize browser and prompt behavior // // Returns: // - *oauth2.Token: The OAuth2 token obtained from the authorization flow // - error: An error if the token acquisition fails, nil otherwise -func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, noBrowser ...bool) (*oauth2.Token, error) { +func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) { // Use a channel to pass the authorization code from the HTTP handler to the main function. - codeChan := make(chan string) - errChan := make(chan error) + codeChan := make(chan string, 1) + errChan := make(chan error, 1) // Create a new HTTP server with its own multiplexer. mux := http.NewServeMux() @@ -223,17 +230,26 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) { if err := r.URL.Query().Get("error"); err != "" { _, _ = fmt.Fprintf(w, "Authentication failed: %s", err) - errChan <- fmt.Errorf("authentication failed via callback: %s", err) + select { + case errChan <- fmt.Errorf("authentication failed via callback: %s", err): + default: + } return } code := r.URL.Query().Get("code") if code == "" { _, _ = fmt.Fprint(w, "Authentication failed: code not found.") - errChan <- fmt.Errorf("code not found in callback") + select { + case errChan <- fmt.Errorf("code not found in callback"): + default: + } return } _, _ = fmt.Fprint(w, "

Authentication successful!

You can close this window.

") - codeChan <- code + select { + case codeChan <- code: + default: + } }) // Start the server in a goroutine. @@ -250,7 +266,12 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, // Open the authorization URL in the user's browser. authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) - if len(noBrowser) == 1 && !noBrowser[0] { + noBrowser := false + if opts != nil { + noBrowser = opts.NoBrowser + } + + if !noBrowser { fmt.Println("Opening browser for authentication...") // Check if browser is available @@ -281,13 +302,60 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, // Wait for the authorization code or an error. var authCode string - select { - case code := <-codeChan: - authCode = code - case err := <-errChan: - return nil, err - case <-time.After(5 * time.Minute): // Timeout - return nil, fmt.Errorf("oauth flow timed out") + timeoutTimer := time.NewTimer(5 * time.Minute) + defer timeoutTimer.Stop() + + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts != nil && opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case code := <-codeChan: + authCode = code + break waitForCallback + case err := <-errChan: + return nil, err + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case code := <-codeChan: + authCode = code + break waitForCallback + case err := <-errChan: + return nil, err + default: + } + input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ") + if err != nil { + return nil, err + } + parsed, err := misc.ParseOAuthCallback(input) + if err != nil { + return nil, err + } + if parsed == nil { + continue + } + if parsed.Error != "" { + return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error) + } + if parsed.Code == "" { + return nil, fmt.Errorf("code not found in callback") + } + authCode = parsed.Code + break waitForCallback + case <-timeoutTimer.C: + return nil, fmt.Errorf("oauth flow timed out") + } } // Shutdown the server. diff --git a/internal/cmd/anthropic_login.go b/internal/cmd/anthropic_login.go index 8e9d01cd..6efd87a8 100644 --- a/internal/cmd/anthropic_login.go +++ b/internal/cmd/anthropic_login.go @@ -24,12 +24,17 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) { options = &LoginOptions{} } + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + manager := newAuthManager() authOpts := &sdkAuth.LoginOptions{ NoBrowser: options.NoBrowser, Metadata: map[string]string{}, - Prompt: options.Prompt, + Prompt: promptFn, } _, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts) diff --git a/internal/cmd/antigravity_login.go b/internal/cmd/antigravity_login.go index b2602638..1cd42899 100644 --- a/internal/cmd/antigravity_login.go +++ b/internal/cmd/antigravity_login.go @@ -15,11 +15,16 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) { options = &LoginOptions{} } + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + manager := newAuthManager() authOpts := &sdkAuth.LoginOptions{ NoBrowser: options.NoBrowser, Metadata: map[string]string{}, - Prompt: options.Prompt, + Prompt: promptFn, } record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts) diff --git a/internal/cmd/iflow_login.go b/internal/cmd/iflow_login.go index ba43470b..cf00b63c 100644 --- a/internal/cmd/iflow_login.go +++ b/internal/cmd/iflow_login.go @@ -20,13 +20,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) { promptFn := options.Prompt if promptFn == nil { - promptFn = func(prompt string) (string, error) { - fmt.Println() - fmt.Println(prompt) - var value string - _, err := fmt.Scanln(&value) - return value, err - } + promptFn = defaultProjectPrompt() } authOpts := &sdkAuth.LoginOptions{ diff --git a/internal/cmd/login.go b/internal/cmd/login.go index de01cec5..3bb0b9a5 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -55,11 +55,22 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { ctx := context.Background() + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + + trimmedProjectID := strings.TrimSpace(projectID) + callbackPrompt := promptFn + if trimmedProjectID == "" { + callbackPrompt = nil + } + loginOpts := &sdkAuth.LoginOptions{ NoBrowser: options.NoBrowser, - ProjectID: strings.TrimSpace(projectID), + ProjectID: trimmedProjectID, Metadata: map[string]string{}, - Prompt: options.Prompt, + Prompt: callbackPrompt, } authenticator := sdkAuth.NewGeminiAuthenticator() @@ -76,7 +87,10 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { } geminiAuth := gemini.NewGeminiAuth() - httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser) + httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{ + NoBrowser: options.NoBrowser, + Prompt: callbackPrompt, + }) if errClient != nil { log.Errorf("Gemini authentication failed: %v", errClient) return @@ -90,12 +104,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { return } - promptFn := options.Prompt - if promptFn == nil { - promptFn = defaultProjectPrompt() - } - - selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn) + selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn) projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects) if errSelection != nil { log.Errorf("Invalid project selection: %v", errSelection) diff --git a/internal/cmd/openai_login.go b/internal/cmd/openai_login.go index e402e476..d981f6ae 100644 --- a/internal/cmd/openai_login.go +++ b/internal/cmd/openai_login.go @@ -35,12 +35,17 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) { options = &LoginOptions{} } + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + manager := newAuthManager() authOpts := &sdkAuth.LoginOptions{ NoBrowser: options.NoBrowser, Metadata: map[string]string{}, - Prompt: options.Prompt, + Prompt: promptFn, } _, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts) diff --git a/internal/misc/oauth.go b/internal/misc/oauth.go index acf034b2..c14f39d2 100644 --- a/internal/misc/oauth.go +++ b/internal/misc/oauth.go @@ -4,6 +4,8 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "net/url" + "strings" ) // GenerateRandomState generates a cryptographically secure random state parameter @@ -19,3 +21,83 @@ func GenerateRandomState() (string, error) { } return hex.EncodeToString(bytes), nil } + +// OAuthCallback captures the parsed OAuth callback parameters. +type OAuthCallback struct { + Code string + State string + Error string + ErrorDescription string +} + +// ParseOAuthCallback extracts OAuth parameters from a callback URL. +// It returns nil when the input is empty. +func ParseOAuthCallback(input string) (*OAuthCallback, error) { + trimmed := strings.TrimSpace(input) + if trimmed == "" { + return nil, nil + } + + candidate := trimmed + if !strings.Contains(candidate, "://") { + if strings.HasPrefix(candidate, "?") { + candidate = "http://localhost" + candidate + } else if strings.ContainsAny(candidate, "/?#") || strings.Contains(candidate, ":") { + candidate = "http://" + candidate + } else if strings.Contains(candidate, "=") { + candidate = "http://localhost/?" + candidate + } else { + return nil, fmt.Errorf("invalid callback URL") + } + } + + parsedURL, err := url.Parse(candidate) + if err != nil { + return nil, err + } + + query := parsedURL.Query() + code := strings.TrimSpace(query.Get("code")) + state := strings.TrimSpace(query.Get("state")) + errCode := strings.TrimSpace(query.Get("error")) + errDesc := strings.TrimSpace(query.Get("error_description")) + + if parsedURL.Fragment != "" { + if fragQuery, errFrag := url.ParseQuery(parsedURL.Fragment); errFrag == nil { + if code == "" { + code = strings.TrimSpace(fragQuery.Get("code")) + } + if state == "" { + state = strings.TrimSpace(fragQuery.Get("state")) + } + if errCode == "" { + errCode = strings.TrimSpace(fragQuery.Get("error")) + } + if errDesc == "" { + errDesc = strings.TrimSpace(fragQuery.Get("error_description")) + } + } + } + + if code != "" && state == "" && strings.Contains(code, "#") { + parts := strings.SplitN(code, "#", 2) + code = parts[0] + state = parts[1] + } + + if errCode == "" && errDesc != "" { + errCode = errDesc + errDesc = "" + } + + if code == "" && errCode == "" { + return nil, fmt.Errorf("callback URL missing code") + } + + return &OAuthCallback{ + Code: code, + State: state, + Error: errCode, + ErrorDescription: errDesc, + }, nil +} diff --git a/sdk/auth/antigravity.go b/sdk/auth/antigravity.go index b3d7f6c5..ae22f772 100644 --- a/sdk/auth/antigravity.go +++ b/sdk/auth/antigravity.go @@ -99,11 +99,54 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o fmt.Println("Waiting for antigravity authentication callback...") var cbRes callbackResult - select { - case res := <-cbChan: - cbRes = res - case <-time.After(5 * time.Minute): - return nil, fmt.Errorf("antigravity: authentication timed out") + timeoutTimer := time.NewTimer(5 * time.Minute) + defer timeoutTimer.Stop() + + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case res := <-cbChan: + cbRes = res + break waitForCallback + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case res := <-cbChan: + cbRes = res + break waitForCallback + default: + } + input, errPrompt := opts.Prompt("Paste the antigravity callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + cbRes = callbackResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback + case <-timeoutTimer.C: + return nil, fmt.Errorf("antigravity: authentication timed out") + } } if cbRes.Error != "" { diff --git a/sdk/auth/claude.go b/sdk/auth/claude.go index da9e5065..c43b78cd 100644 --- a/sdk/auth/claude.go +++ b/sdk/auth/claude.go @@ -98,16 +98,76 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt fmt.Println("Waiting for Claude authentication callback...") - result, err := oauthServer.WaitForCallback(5 * time.Minute) - if err != nil { - if strings.Contains(err.Error(), "timeout") { - return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) + callbackCh := make(chan *claude.OAuthResult, 1) + callbackErrCh := make(chan error, 1) + manualDescription := "" + + go func() { + result, errWait := oauthServer.WaitForCallback(5 * time.Minute) + if errWait != nil { + callbackErrCh <- errWait + return + } + callbackCh <- result + }() + + var result *claude.OAuthResult + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) + } + return nil, err + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) + } + return nil, err + default: + } + input, errPrompt := opts.Prompt("Paste the Claude callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + manualDescription = parsed.ErrorDescription + result = &claude.OAuthResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback } - return nil, err } if result.Error != "" { - return nil, claude.NewOAuthError(result.Error, "", http.StatusBadRequest) + return nil, claude.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest) } if result.State != state { diff --git a/sdk/auth/codex.go b/sdk/auth/codex.go index 138c2141..99992525 100644 --- a/sdk/auth/codex.go +++ b/sdk/auth/codex.go @@ -97,16 +97,76 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts fmt.Println("Waiting for Codex authentication callback...") - result, err := oauthServer.WaitForCallback(5 * time.Minute) - if err != nil { - if strings.Contains(err.Error(), "timeout") { - return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) + callbackCh := make(chan *codex.OAuthResult, 1) + callbackErrCh := make(chan error, 1) + manualDescription := "" + + go func() { + result, errWait := oauthServer.WaitForCallback(5 * time.Minute) + if errWait != nil { + callbackErrCh <- errWait + return + } + callbackCh <- result + }() + + var result *codex.OAuthResult + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) + } + return nil, err + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) + } + return nil, err + default: + } + input, errPrompt := opts.Prompt("Paste the Codex callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + manualDescription = parsed.ErrorDescription + result = &codex.OAuthResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback } - return nil, err } if result.Error != "" { - return nil, codex.NewOAuthError(result.Error, "", http.StatusBadRequest) + return nil, codex.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest) } if result.State != state { diff --git a/sdk/auth/gemini.go b/sdk/auth/gemini.go index 7110101f..75ef4579 100644 --- a/sdk/auth/gemini.go +++ b/sdk/auth/gemini.go @@ -44,7 +44,10 @@ func (a *GeminiAuthenticator) Login(ctx context.Context, cfg *config.Config, opt } geminiAuth := gemini.NewGeminiAuth() - _, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, opts.NoBrowser) + _, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, &gemini.WebLoginOptions{ + NoBrowser: opts.NoBrowser, + Prompt: opts.Prompt, + }) if err != nil { return nil, fmt.Errorf("gemini authentication failed: %w", err) } diff --git a/sdk/auth/iflow.go b/sdk/auth/iflow.go index ee96bdaa..3fd82f1d 100644 --- a/sdk/auth/iflow.go +++ b/sdk/auth/iflow.go @@ -84,9 +84,64 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts fmt.Println("Waiting for iFlow authentication callback...") - result, err := oauthServer.WaitForCallback(5 * time.Minute) - if err != nil { - return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) + callbackCh := make(chan *iflow.OAuthResult, 1) + callbackErrCh := make(chan error, 1) + + go func() { + result, errWait := oauthServer.WaitForCallback(5 * time.Minute) + if errWait != nil { + callbackErrCh <- errWait + return + } + callbackCh <- result + }() + + var result *iflow.OAuthResult + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) + default: + } + input, errPrompt := opts.Prompt("Paste the iFlow callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + result = &iflow.OAuthResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback + } } if result.Error != "" { return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error)