rebuild branch

This commit is contained in:
Luis Pater
2025-09-25 10:32:48 +08:00
parent 3f69254f43
commit f5dc380b63
214 changed files with 39377 additions and 0 deletions

View File

@@ -0,0 +1,237 @@
// Package claude provides HTTP handlers for Claude API code-related functionality.
// This package implements Claude-compatible streaming chat completions with sophisticated
// client rotation and quota management systems to ensure high availability and optimal
// resource utilization across multiple backend clients. It handles request translation
// between Claude API format and the underlying Gemini backend, providing seamless
// API compatibility while maintaining robust error handling and connection management.
package claude
import (
"bytes"
"context"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/tidwall/gjson"
)
// ClaudeCodeAPIHandler contains the handlers for Claude API endpoints.
// It holds a pool of clients to interact with the backend service.
type ClaudeCodeAPIHandler struct {
*handlers.BaseAPIHandler
}
// NewClaudeCodeAPIHandler creates a new Claude API handlers instance.
// It takes an BaseAPIHandler instance as input and returns a ClaudeCodeAPIHandler.
//
// Parameters:
// - apiHandlers: The base API handler instance.
//
// Returns:
// - *ClaudeCodeAPIHandler: A new Claude code API handler instance.
func NewClaudeCodeAPIHandler(apiHandlers *handlers.BaseAPIHandler) *ClaudeCodeAPIHandler {
return &ClaudeCodeAPIHandler{
BaseAPIHandler: apiHandlers,
}
}
// HandlerType returns the identifier for this handler implementation.
func (h *ClaudeCodeAPIHandler) HandlerType() string {
return Claude
}
// Models returns a list of models supported by this handler.
func (h *ClaudeCodeAPIHandler) Models() []map[string]any {
// Get dynamic models from the global registry
modelRegistry := registry.GetGlobalRegistry()
return modelRegistry.GetAvailableModels("claude")
}
// ClaudeMessages handles Claude-compatible streaming chat completions.
// This function implements a sophisticated client rotation and quota management system
// to ensure high availability and optimal resource utilization across multiple backend clients.
//
// Parameters:
// - c: The Gin context for the request.
func (h *ClaudeCodeAPIHandler) ClaudeMessages(c *gin.Context) {
// Extract raw JSON data from the incoming request
rawJSON, err := c.GetRawData()
// If data retrieval fails, return a 400 Bad Request error.
if err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
// Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJSON, "stream")
if !streamResult.Exists() || streamResult.Type == gjson.False {
h.handleNonStreamingResponse(c, rawJSON)
} else {
h.handleStreamingResponse(c, rawJSON)
}
}
// ClaudeMessages handles Claude-compatible streaming chat completions.
// This function implements a sophisticated client rotation and quota management system
// to ensure high availability and optimal resource utilization across multiple backend clients.
//
// Parameters:
// - c: The Gin context for the request.
func (h *ClaudeCodeAPIHandler) ClaudeCountTokens(c *gin.Context) {
// Extract raw JSON data from the incoming request
rawJSON, err := c.GetRawData()
// If data retrieval fails, return a 400 Bad Request error.
if err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
c.Header("Content-Type", "application/json")
alt := h.GetAlt(c)
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
modelName := gjson.GetBytes(rawJSON, "model").String()
resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
_, _ = c.Writer.Write(resp)
cliCancel()
}
// ClaudeModels handles the Claude models listing endpoint.
// It returns a JSON response containing available Claude models and their specifications.
//
// Parameters:
// - c: The Gin context for the request.
func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"data": h.Models(),
})
}
// handleNonStreamingResponse handles non-streaming content generation requests for Claude models.
// This function processes the request synchronously and returns the complete generated
// response in a single API call. It supports various generation parameters and
// response formats.
//
// Parameters:
// - c: The Gin context for the request
// - modelName: The name of the Gemini model to use for content generation
// - rawJSON: The raw JSON request body containing generation parameters and content
func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json")
alt := h.GetAlt(c)
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
modelName := gjson.GetBytes(rawJSON, "model").String()
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
_, _ = c.Writer.Write(resp)
cliCancel()
}
// handleStreamingResponse streams Claude-compatible responses backed by Gemini.
// It sets up SSE, selects a backend client with rotation/quota logic,
// forwards chunks, and translates them to Claude CLI format.
//
// Parameters:
// - c: The Gin context for the request.
// - rawJSON: The raw JSON request body.
func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
// Set up Server-Sent Events (SSE) headers for streaming response
// These headers are essential for maintaining a persistent connection
// and enabling real-time streaming of chat completions
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response.
// This is crucial for streaming as it allows immediate sending of data chunks
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelName := gjson.GetBytes(rawJSON, "model").String()
// Create a cancellable context for the backend client request
// This allows proper cleanup and cancellation of ongoing requests
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
return
}
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return
case chunk, ok := <-data:
if !ok {
flusher.Flush()
cancel(nil)
return
}
if bytes.HasPrefix(chunk, []byte("event:")) {
_, _ = c.Writer.Write([]byte("\n"))
}
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
case errMsg, ok := <-errs:
if !ok {
continue
}
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
flusher.Flush()
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cancel(execErr)
return
case <-time.After(500 * time.Millisecond):
}
}
}

View File

@@ -0,0 +1,227 @@
// Package gemini provides HTTP handlers for Gemini CLI API functionality.
// This package implements handlers that process CLI-specific requests for Gemini API operations,
// including content generation and streaming content generation endpoints.
// The handlers restrict access to localhost only and manage communication with the backend service.
package gemini
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// GeminiCLIAPIHandler contains the handlers for Gemini CLI API endpoints.
// It holds a pool of clients to interact with the backend service.
type GeminiCLIAPIHandler struct {
*handlers.BaseAPIHandler
}
// NewGeminiCLIAPIHandler creates a new Gemini CLI API handlers instance.
// It takes an BaseAPIHandler instance as input and returns a GeminiCLIAPIHandler.
func NewGeminiCLIAPIHandler(apiHandlers *handlers.BaseAPIHandler) *GeminiCLIAPIHandler {
return &GeminiCLIAPIHandler{
BaseAPIHandler: apiHandlers,
}
}
// HandlerType returns the type of this handler.
func (h *GeminiCLIAPIHandler) HandlerType() string {
return GeminiCLI
}
// Models returns a list of models supported by this handler.
func (h *GeminiCLIAPIHandler) Models() []map[string]any {
return make([]map[string]any, 0)
}
// CLIHandler handles CLI-specific requests for Gemini API operations.
// It restricts access to localhost only and routes requests to appropriate internal handlers.
func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) {
if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") {
c.JSON(http.StatusForbidden, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "CLI reply only allow local access",
Type: "forbidden",
},
})
return
}
rawJSON, _ := c.GetRawData()
requestRawURI := c.Request.URL.Path
if requestRawURI == "/v1internal:generateContent" {
h.handleInternalGenerateContent(c, rawJSON)
} else if requestRawURI == "/v1internal:streamGenerateContent" {
h.handleInternalStreamGenerateContent(c, rawJSON)
} else {
reqBody := bytes.NewBuffer(rawJSON)
req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody)
if err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
for key, value := range c.Request.Header {
req.Header[key] = value
}
httpClient := util.SetProxy(h.Cfg, &http.Client{})
resp, err := httpClient.Do(req)
if err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
defer func() {
if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}()
bodyBytes, _ := io.ReadAll(resp.Body)
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: string(bodyBytes),
Type: "invalid_request_error",
},
})
return
}
defer func() {
_ = resp.Body.Close()
}()
for key, value := range resp.Header {
c.Header(key, value[0])
}
output, err := io.ReadAll(resp.Body)
if err != nil {
log.Errorf("Failed to read response body: %v", err)
return
}
_, _ = c.Writer.Write(output)
c.Set("API_RESPONSE", output)
}
}
// handleInternalStreamGenerateContent handles streaming content generation requests.
// It sets up a server-sent event stream and forwards the request to the backend client.
// The function continuously proxies response chunks from the backend to the client.
func (h *GeminiCLIAPIHandler) handleInternalStreamGenerateContent(c *gin.Context, rawJSON []byte) {
alt := h.GetAlt(c)
if alt == "" {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelResult := gjson.GetBytes(rawJSON, "model")
modelName := modelResult.String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
h.forwardCLIStream(c, flusher, "", func(err error) { cliCancel(err) }, dataChan, errChan)
return
}
// handleInternalGenerateContent handles non-streaming content generation requests.
// It sends a request to the backend client and proxies the entire response back to the client at once.
func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json")
modelResult := gjson.GetBytes(rawJSON, "model")
modelName := modelResult.String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
_, _ = c.Writer.Write(resp)
cliCancel()
}
func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return
case chunk, ok := <-data:
if !ok {
cancel(nil)
return
}
if alt == "" {
if bytes.Equal(chunk, []byte("data: [DONE]")) || bytes.Equal(chunk, []byte("[DONE]")) {
continue
}
if !bytes.HasPrefix(chunk, []byte("data:")) {
_, _ = c.Writer.Write([]byte("data: "))
}
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
} else {
_, _ = c.Writer.Write(chunk)
}
flusher.Flush()
case errMsg, ok := <-errs:
if !ok {
continue
}
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
flusher.Flush()
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cancel(execErr)
return
case <-time.After(500 * time.Millisecond):
}
}
}

View File

@@ -0,0 +1,297 @@
// Package gemini provides HTTP handlers for Gemini API endpoints.
// This package implements handlers for managing Gemini model operations including
// model listing, content generation, streaming content generation, and token counting.
// It serves as a proxy layer between clients and the Gemini backend service,
// handling request translation, client management, and response processing.
package gemini
import (
"context"
"fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
// GeminiAPIHandler contains the handlers for Gemini API endpoints.
// It holds a pool of clients to interact with the backend service.
type GeminiAPIHandler struct {
*handlers.BaseAPIHandler
}
// NewGeminiAPIHandler creates a new Gemini API handlers instance.
// It takes an BaseAPIHandler instance as input and returns a GeminiAPIHandler.
func NewGeminiAPIHandler(apiHandlers *handlers.BaseAPIHandler) *GeminiAPIHandler {
return &GeminiAPIHandler{
BaseAPIHandler: apiHandlers,
}
}
// HandlerType returns the identifier for this handler implementation.
func (h *GeminiAPIHandler) HandlerType() string {
return Gemini
}
// Models returns the Gemini-compatible model metadata supported by this handler.
func (h *GeminiAPIHandler) Models() []map[string]any {
// Get dynamic models from the global registry
modelRegistry := registry.GetGlobalRegistry()
return modelRegistry.GetAvailableModels("gemini")
}
// GeminiModels handles the Gemini models listing endpoint.
// It returns a JSON response containing available Gemini models and their specifications.
func (h *GeminiAPIHandler) GeminiModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"models": h.Models(),
})
}
// GeminiGetHandler handles GET requests for specific Gemini model information.
// It returns detailed information about a specific Gemini model based on the action parameter.
func (h *GeminiAPIHandler) GeminiGetHandler(c *gin.Context) {
var request struct {
Action string `uri:"action" binding:"required"`
}
if err := c.ShouldBindUri(&request); err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
switch request.Action {
case "gemini-2.5-pro":
c.JSON(http.StatusOK, gin.H{
"name": "models/gemini-2.5-pro",
"version": "2.5",
"displayName": "Gemini 2.5 Pro",
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": []string{
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
)
case "gemini-2.5-flash":
c.JSON(http.StatusOK, gin.H{
"name": "models/gemini-2.5-flash",
"version": "001",
"displayName": "Gemini 2.5 Flash",
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": []string{
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
})
case "gpt-5":
c.JSON(http.StatusOK, gin.H{
"name": "gpt-5",
"version": "001",
"displayName": "GPT 5",
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
"inputTokenLimit": 400000,
"outputTokenLimit": 128000,
"supportedGenerationMethods": []string{
"generateContent",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
})
default:
c.JSON(http.StatusNotFound, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Not Found",
Type: "not_found",
},
})
}
}
// GeminiHandler handles POST requests for Gemini API operations.
// It routes requests to appropriate handlers based on the action parameter (model:method format).
func (h *GeminiAPIHandler) GeminiHandler(c *gin.Context) {
var request struct {
Action string `uri:"action" binding:"required"`
}
if err := c.ShouldBindUri(&request); err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
action := strings.Split(request.Action, ":")
if len(action) != 2 {
c.JSON(http.StatusNotFound, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("%s not found.", c.Request.URL.Path),
Type: "invalid_request_error",
},
})
return
}
method := action[1]
rawJSON, _ := c.GetRawData()
switch method {
case "generateContent":
h.handleGenerateContent(c, action[0], rawJSON)
case "streamGenerateContent":
h.handleStreamGenerateContent(c, action[0], rawJSON)
case "countTokens":
h.handleCountTokens(c, action[0], rawJSON)
}
}
// handleStreamGenerateContent handles streaming content generation requests for Gemini models.
// This function establishes a Server-Sent Events connection and streams the generated content
// back to the client in real-time. It supports both SSE format and direct streaming based
// on the 'alt' query parameter.
//
// Parameters:
// - c: The Gin context for the request
// - modelName: The name of the Gemini model to use for content generation
// - rawJSON: The raw JSON request body containing generation parameters
func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName string, rawJSON []byte) {
alt := h.GetAlt(c)
if alt == "" {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan)
return
}
// handleCountTokens handles token counting requests for Gemini models.
// This function counts the number of tokens in the provided content without
// generating a response. It's useful for quota management and content validation.
//
// Parameters:
// - c: The Gin context for the request
// - modelName: The name of the Gemini model to use for token counting
// - rawJSON: The raw JSON request body containing the content to count
func (h *GeminiAPIHandler) handleCountTokens(c *gin.Context, modelName string, rawJSON []byte) {
c.Header("Content-Type", "application/json")
alt := h.GetAlt(c)
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
_, _ = c.Writer.Write(resp)
cliCancel()
}
// handleGenerateContent handles non-streaming content generation requests for Gemini models.
// This function processes the request synchronously and returns the complete generated
// response in a single API call. It supports various generation parameters and
// response formats.
//
// Parameters:
// - c: The Gin context for the request
// - modelName: The name of the Gemini model to use for content generation
// - rawJSON: The raw JSON request body containing generation parameters and content
func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName string, rawJSON []byte) {
c.Header("Content-Type", "application/json")
alt := h.GetAlt(c)
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
_, _ = c.Writer.Write(resp)
cliCancel()
}
func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return
case chunk, ok := <-data:
if !ok {
cancel(nil)
return
}
if alt == "" {
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
} else {
_, _ = c.Writer.Write(chunk)
}
flusher.Flush()
case errMsg, ok := <-errs:
if !ok {
continue
}
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
flusher.Flush()
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cancel(execErr)
return
case <-time.After(500 * time.Millisecond):
}
}
}

View File

@@ -0,0 +1,267 @@
// Package handlers provides core API handler functionality for the CLI Proxy API server.
// It includes common types, client management, load balancing, and error handling
// shared across all API endpoint handlers (OpenAI, Claude, Gemini).
package handlers
import (
"fmt"
"net/http"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
"golang.org/x/net/context"
)
// ErrorResponse represents a standard error response format for the API.
// It contains a single ErrorDetail field.
type ErrorResponse struct {
// Error contains detailed information about the error that occurred.
Error ErrorDetail `json:"error"`
}
// ErrorDetail provides specific information about an error that occurred.
// It includes a human-readable message, an error type, and an optional error code.
type ErrorDetail struct {
// Message is a human-readable message providing more details about the error.
Message string `json:"message"`
// Type is the category of error that occurred (e.g., "invalid_request_error").
Type string `json:"type"`
// Code is a short code identifying the error, if applicable.
Code string `json:"code,omitempty"`
}
// BaseAPIHandler contains the handlers for API endpoints.
// It holds a pool of clients to interact with the backend service and manages
// load balancing, client selection, and configuration.
type BaseAPIHandler struct {
// AuthManager manages auth lifecycle and execution in the new architecture.
AuthManager *coreauth.Manager
// Cfg holds the current application configuration.
Cfg *config.Config
}
// NewBaseAPIHandlers creates a new API handlers instance.
// It takes a slice of clients and configuration as input.
//
// Parameters:
// - cliClients: A slice of AI service clients
// - cfg: The application configuration
//
// Returns:
// - *BaseAPIHandler: A new API handlers instance
func NewBaseAPIHandlers(cfg *config.Config, authManager *coreauth.Manager) *BaseAPIHandler {
return &BaseAPIHandler{
Cfg: cfg,
AuthManager: authManager,
}
}
// UpdateClients updates the handlers' client list and configuration.
// This method is called when the configuration or authentication tokens change.
//
// Parameters:
// - clients: The new slice of AI service clients
// - cfg: The new application configuration
func (h *BaseAPIHandler) UpdateClients(cfg *config.Config) { h.Cfg = cfg }
// GetAlt extracts the 'alt' parameter from the request query string.
// It checks both 'alt' and '$alt' parameters and returns the appropriate value.
//
// Parameters:
// - c: The Gin context containing the HTTP request
//
// Returns:
// - string: The alt parameter value, or empty string if it's "sse"
func (h *BaseAPIHandler) GetAlt(c *gin.Context) string {
var alt string
var hasAlt bool
alt, hasAlt = c.GetQuery("alt")
if !hasAlt {
alt, _ = c.GetQuery("$alt")
}
if alt == "sse" {
return ""
}
return alt
}
// GetContextWithCancel creates a new context with cancellation capabilities.
// It embeds the Gin context and the API handler into the new context for later use.
// The returned cancel function also handles logging the API response if request logging is enabled.
//
// Parameters:
// - handler: The API handler associated with the request.
// - c: The Gin context of the current request.
// - ctx: The parent context.
//
// Returns:
// - context.Context: The new context with cancellation and embedded values.
// - APIHandlerCancelFunc: A function to cancel the context and log the response.
func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *gin.Context, ctx context.Context) (context.Context, APIHandlerCancelFunc) {
newCtx, cancel := context.WithCancel(ctx)
newCtx = context.WithValue(newCtx, "gin", c)
newCtx = context.WithValue(newCtx, "handler", handler)
return newCtx, func(params ...interface{}) {
if h.Cfg.RequestLog {
if len(params) == 1 {
data := params[0]
switch data.(type) {
case []byte:
c.Set("API_RESPONSE", data.([]byte))
case error:
c.Set("API_RESPONSE", []byte(data.(error).Error()))
case string:
c.Set("API_RESPONSE", []byte(data.(string)))
case bool:
case nil:
}
}
}
cancel()
}
}
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
providers := util.GetProviderName(modelName, h.Cfg)
if len(providers) == 0 {
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
}
req := coreexecutor.Request{
Model: modelName,
Payload: cloneBytes(rawJSON),
}
opts := coreexecutor.Options{
Stream: false,
Alt: alt,
OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType),
}
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
if err != nil {
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err}
}
return cloneBytes(resp.Payload), nil
}
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
providers := util.GetProviderName(modelName, h.Cfg)
if len(providers) == 0 {
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
}
req := coreexecutor.Request{
Model: modelName,
Payload: cloneBytes(rawJSON),
}
opts := coreexecutor.Options{
Stream: false,
Alt: alt,
OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType),
}
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
if err != nil {
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err}
}
return cloneBytes(resp.Payload), nil
}
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
// This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
providers := util.GetProviderName(modelName, h.Cfg)
if len(providers) == 0 {
errChan := make(chan *interfaces.ErrorMessage, 1)
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
close(errChan)
return nil, errChan
}
req := coreexecutor.Request{
Model: modelName,
Payload: cloneBytes(rawJSON),
}
opts := coreexecutor.Options{
Stream: true,
Alt: alt,
OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType),
}
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if err != nil {
errChan := make(chan *interfaces.ErrorMessage, 1)
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err}
close(errChan)
return nil, errChan
}
dataChan := make(chan []byte)
errChan := make(chan *interfaces.ErrorMessage, 1)
go func() {
defer close(dataChan)
defer close(errChan)
for chunk := range chunks {
if chunk.Err != nil {
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: chunk.Err}
return
}
if len(chunk.Payload) > 0 {
dataChan <- cloneBytes(chunk.Payload)
}
}
}()
return dataChan, errChan
}
func cloneBytes(src []byte) []byte {
if len(src) == 0 {
return nil
}
dst := make([]byte, len(src))
copy(dst, src)
return dst
}
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
status := http.StatusInternalServerError
if msg != nil && msg.StatusCode > 0 {
status = msg.StatusCode
}
c.Status(status)
if msg != nil && msg.Error != nil {
_, _ = c.Writer.Write([]byte(msg.Error.Error()))
} else {
_, _ = c.Writer.Write([]byte(http.StatusText(status)))
}
}
func (h *BaseAPIHandler) LoggingAPIResponseError(ctx context.Context, err *interfaces.ErrorMessage) {
if h.Cfg.RequestLog {
if ginContext, ok := ctx.Value("gin").(*gin.Context); ok {
if apiResponseErrors, isExist := ginContext.Get("API_RESPONSE_ERROR"); isExist {
if slicesAPIResponseError, isOk := apiResponseErrors.([]*interfaces.ErrorMessage); isOk {
slicesAPIResponseError = append(slicesAPIResponseError, err)
ginContext.Set("API_RESPONSE_ERROR", slicesAPIResponseError)
}
} else {
// Create new response data entry
ginContext.Set("API_RESPONSE_ERROR", []*interfaces.ErrorMessage{err})
}
}
}
}
// APIHandlerCancelFunc is a function type for canceling an API handler's context.
// It can optionally accept parameters, which are used for logging the response.
type APIHandlerCancelFunc func(params ...interface{})

View File

@@ -0,0 +1,955 @@
package management
import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"os"
"path/filepath"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
// legacy client removed
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
var (
oauthStatus = make(map[string]string)
)
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
func extractLastRefreshTimestamp(meta map[string]any) (time.Time, bool) {
if len(meta) == 0 {
return time.Time{}, false
}
for _, key := range lastRefreshKeys {
if val, ok := meta[key]; ok {
if ts, ok1 := parseLastRefreshValue(val); ok1 {
return ts, true
}
}
}
return time.Time{}, false
}
func parseLastRefreshValue(v any) (time.Time, bool) {
switch val := v.(type) {
case string:
s := strings.TrimSpace(val)
if s == "" {
return time.Time{}, false
}
layouts := []string{time.RFC3339, time.RFC3339Nano, "2006-01-02 15:04:05", "2006-01-02T15:04:05Z07:00"}
for _, layout := range layouts {
if ts, err := time.Parse(layout, s); err == nil {
return ts.UTC(), true
}
}
if unix, err := strconv.ParseInt(s, 10, 64); err == nil && unix > 0 {
return time.Unix(unix, 0).UTC(), true
}
case float64:
if val <= 0 {
return time.Time{}, false
}
return time.Unix(int64(val), 0).UTC(), true
case int64:
if val <= 0 {
return time.Time{}, false
}
return time.Unix(val, 0).UTC(), true
case int:
if val <= 0 {
return time.Time{}, false
}
return time.Unix(int64(val), 0).UTC(), true
case json.Number:
if i, err := val.Int64(); err == nil && i > 0 {
return time.Unix(i, 0).UTC(), true
}
}
return time.Time{}, false
}
// List auth files
func (h *Handler) ListAuthFiles(c *gin.Context) {
entries, err := os.ReadDir(h.cfg.AuthDir)
if err != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read auth dir: %v", err)})
return
}
files := make([]gin.H, 0)
for _, e := range entries {
if e.IsDir() {
continue
}
name := e.Name()
if !strings.HasSuffix(strings.ToLower(name), ".json") {
continue
}
if info, errInfo := e.Info(); errInfo == nil {
fileData := gin.H{"name": name, "size": info.Size(), "modtime": info.ModTime()}
// Read file to get type field
full := filepath.Join(h.cfg.AuthDir, name)
if data, errRead := os.ReadFile(full); errRead == nil {
typeValue := gjson.GetBytes(data, "type").String()
fileData["type"] = typeValue
}
files = append(files, fileData)
}
}
c.JSON(200, gin.H{"files": files})
}
// Download single auth file by name
func (h *Handler) DownloadAuthFile(c *gin.Context) {
name := c.Query("name")
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
c.JSON(400, gin.H{"error": "invalid name"})
return
}
if !strings.HasSuffix(strings.ToLower(name), ".json") {
c.JSON(400, gin.H{"error": "name must end with .json"})
return
}
full := filepath.Join(h.cfg.AuthDir, name)
data, err := os.ReadFile(full)
if err != nil {
if os.IsNotExist(err) {
c.JSON(404, gin.H{"error": "file not found"})
} else {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
}
return
}
c.Header("Content-Disposition", fmt.Sprintf("attachment; filename=\"%s\"", name))
c.Data(200, "application/json", data)
}
// Upload auth file: multipart or raw JSON with ?name=
func (h *Handler) UploadAuthFile(c *gin.Context) {
if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
return
}
ctx := c.Request.Context()
if file, err := c.FormFile("file"); err == nil && file != nil {
name := filepath.Base(file.Filename)
if !strings.HasSuffix(strings.ToLower(name), ".json") {
c.JSON(400, gin.H{"error": "file must be .json"})
return
}
dst := filepath.Join(h.cfg.AuthDir, name)
if !filepath.IsAbs(dst) {
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
dst = abs
}
}
if errSave := c.SaveUploadedFile(file, dst); errSave != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to save file: %v", errSave)})
return
}
data, errRead := os.ReadFile(dst)
if errRead != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read saved file: %v", errRead)})
return
}
if errReg := h.registerAuthFromFile(ctx, dst, data); errReg != nil {
c.JSON(500, gin.H{"error": errReg.Error()})
return
}
c.JSON(200, gin.H{"status": "ok"})
return
}
name := c.Query("name")
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
c.JSON(400, gin.H{"error": "invalid name"})
return
}
if !strings.HasSuffix(strings.ToLower(name), ".json") {
c.JSON(400, gin.H{"error": "name must end with .json"})
return
}
data, err := io.ReadAll(c.Request.Body)
if err != nil {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
dst := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
if !filepath.IsAbs(dst) {
if abs, errAbs := filepath.Abs(dst); errAbs == nil {
dst = abs
}
}
if errWrite := os.WriteFile(dst, data, 0o600); errWrite != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to write file: %v", errWrite)})
return
}
if err = h.registerAuthFromFile(ctx, dst, data); err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
c.JSON(200, gin.H{"status": "ok"})
}
// Delete auth files: single by name or all
func (h *Handler) DeleteAuthFile(c *gin.Context) {
if h.authManager == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "core auth manager unavailable"})
return
}
ctx := c.Request.Context()
if all := c.Query("all"); all == "true" || all == "1" || all == "*" {
entries, err := os.ReadDir(h.cfg.AuthDir)
if err != nil {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to read auth dir: %v", err)})
return
}
deleted := 0
for _, e := range entries {
if e.IsDir() {
continue
}
name := e.Name()
if !strings.HasSuffix(strings.ToLower(name), ".json") {
continue
}
full := filepath.Join(h.cfg.AuthDir, name)
if !filepath.IsAbs(full) {
if abs, errAbs := filepath.Abs(full); errAbs == nil {
full = abs
}
}
if err = os.Remove(full); err == nil {
deleted++
h.disableAuth(ctx, full)
}
}
c.JSON(200, gin.H{"status": "ok", "deleted": deleted})
return
}
name := c.Query("name")
if name == "" || strings.Contains(name, string(os.PathSeparator)) {
c.JSON(400, gin.H{"error": "invalid name"})
return
}
full := filepath.Join(h.cfg.AuthDir, filepath.Base(name))
if !filepath.IsAbs(full) {
if abs, errAbs := filepath.Abs(full); errAbs == nil {
full = abs
}
}
if err := os.Remove(full); err != nil {
if os.IsNotExist(err) {
c.JSON(404, gin.H{"error": "file not found"})
} else {
c.JSON(500, gin.H{"error": fmt.Sprintf("failed to remove file: %v", err)})
}
return
}
h.disableAuth(ctx, full)
c.JSON(200, gin.H{"status": "ok"})
}
func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []byte) error {
if h.authManager == nil {
return nil
}
if path == "" {
return fmt.Errorf("auth path is empty")
}
if data == nil {
var err error
data, err = os.ReadFile(path)
if err != nil {
return fmt.Errorf("failed to read auth file: %w", err)
}
}
metadata := make(map[string]any)
if err := json.Unmarshal(data, &metadata); err != nil {
return fmt.Errorf("invalid auth file: %w", err)
}
provider, _ := metadata["type"].(string)
if provider == "" {
provider = "unknown"
}
label := provider
if email, ok := metadata["email"].(string); ok && email != "" {
label = email
}
lastRefresh, hasLastRefresh := extractLastRefreshTimestamp(metadata)
attr := map[string]string{
"path": path,
"source": path,
}
auth := &coreauth.Auth{
ID: path,
Provider: provider,
Label: label,
Status: coreauth.StatusActive,
Attributes: attr,
Metadata: metadata,
CreatedAt: time.Now(),
UpdatedAt: time.Now(),
}
if hasLastRefresh {
auth.LastRefreshedAt = lastRefresh
}
if existing, ok := h.authManager.GetByID(path); ok {
auth.CreatedAt = existing.CreatedAt
if !hasLastRefresh {
auth.LastRefreshedAt = existing.LastRefreshedAt
}
auth.NextRefreshAfter = existing.NextRefreshAfter
auth.Runtime = existing.Runtime
_, err := h.authManager.Update(ctx, auth)
return err
}
_, err := h.authManager.Register(ctx, auth)
return err
}
func (h *Handler) disableAuth(ctx context.Context, id string) {
if h.authManager == nil || id == "" {
return
}
if auth, ok := h.authManager.GetByID(id); ok {
auth.Disabled = true
auth.Status = coreauth.StatusDisabled
auth.StatusMessage = "removed via management API"
auth.UpdatedAt = time.Now()
_, _ = h.authManager.Update(ctx, auth)
}
}
func (h *Handler) saveTokenRecord(ctx context.Context, record *sdkAuth.TokenRecord) (string, error) {
if record == nil {
return "", fmt.Errorf("token record is nil")
}
store := h.tokenStore
if store == nil {
store = sdkAuth.GetTokenStore()
h.tokenStore = store
}
return store.Save(ctx, h.cfg, record)
}
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
ctx := context.Background()
log.Info("Initializing Claude authentication...")
// Generate PKCE codes
pkceCodes, err := claude.GeneratePKCECodes()
if err != nil {
log.Fatalf("Failed to generate PKCE codes: %v", err)
return
}
// Generate random state parameter
state, err := misc.GenerateRandomState()
if err != nil {
log.Fatalf("Failed to generate state parameter: %v", err)
return
}
// Initialize Claude auth service
anthropicAuth := claude.NewClaudeAuth(h.cfg)
// Generate authorization URL (then override redirect_uri to reuse server port)
authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes)
if err != nil {
log.Fatalf("Failed to generate authorization URL: %v", err)
return
}
// Override redirect_uri in authorization URL to current server port
go func() {
// Helper: wait for callback file
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-anthropic-%s.oauth", state))
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
deadline := time.Now().Add(timeout)
for {
if time.Now().After(deadline) {
oauthStatus[state] = "Timeout waiting for OAuth callback"
return nil, fmt.Errorf("timeout waiting for OAuth callback")
}
data, errRead := os.ReadFile(path)
if errRead == nil {
var m map[string]string
_ = json.Unmarshal(data, &m)
_ = os.Remove(path)
return m, nil
}
time.Sleep(500 * time.Millisecond)
}
}
log.Info("Waiting for authentication callback...")
// Wait up to 5 minutes
resultMap, errWait := waitForFile(waitFile, 5*time.Minute)
if errWait != nil {
authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait)
log.Error(claude.GetUserFriendlyMessage(authErr))
return
}
if errStr := resultMap["error"]; errStr != "" {
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
log.Error(claude.GetUserFriendlyMessage(oauthErr))
oauthStatus[state] = "Bad request"
return
}
if resultMap["state"] != state {
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
log.Error(claude.GetUserFriendlyMessage(authErr))
oauthStatus[state] = "State code error"
return
}
// Parse code (Claude may append state after '#')
rawCode := resultMap["code"]
code := strings.Split(rawCode, "#")[0]
// Exchange code for tokens (replicate logic using updated redirect_uri)
// Extract client_id from the modified auth URL
clientID := ""
if u2, errP := url.Parse(authURL); errP == nil {
clientID = u2.Query().Get("client_id")
}
// Build request
bodyMap := map[string]any{
"code": code,
"state": state,
"grant_type": "authorization_code",
"client_id": clientID,
"redirect_uri": "http://localhost:54545/callback",
"code_verifier": pkceCodes.CodeVerifier,
}
bodyJSON, _ := json.Marshal(bodyMap)
httpClient := util.SetProxy(h.cfg, &http.Client{})
req, _ := http.NewRequestWithContext(ctx, "POST", "https://console.anthropic.com/v1/oauth/token", strings.NewReader(string(bodyJSON)))
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Accept", "application/json")
resp, errDo := httpClient.Do(req)
if errDo != nil {
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
oauthStatus[state] = "Failed to exchange authorization code for tokens"
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Errorf("failed to close response body: %v", errClose)
}
}()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)
return
}
var tResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
ExpiresIn int `json:"expires_in"`
Account struct {
EmailAddress string `json:"email_address"`
} `json:"account"`
}
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
log.Errorf("failed to parse token response: %v", errU)
oauthStatus[state] = "Failed to parse token response"
return
}
bundle := &claude.ClaudeAuthBundle{
TokenData: claude.ClaudeTokenData{
AccessToken: tResp.AccessToken,
RefreshToken: tResp.RefreshToken,
Email: tResp.Account.EmailAddress,
Expire: time.Now().Add(time.Duration(tResp.ExpiresIn) * time.Second).Format(time.RFC3339),
},
LastRefresh: time.Now().Format(time.RFC3339),
}
// Create token storage
tokenStorage := anthropicAuth.CreateTokenStorage(bundle)
record := &sdkAuth.TokenRecord{
Provider: "claude",
FileName: fmt.Sprintf("claude-%s.json", tokenStorage.Email),
Storage: tokenStorage,
Metadata: map[string]string{"email": tokenStorage.Email},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Fatalf("Failed to save authentication tokens: %v", errSave)
oauthStatus[state] = "Failed to save authentication tokens"
return
}
log.Infof("Authentication successful! Token saved to %s", savedPath)
if bundle.APIKey != "" {
log.Info("API key obtained and saved")
}
log.Info("You can now use Claude services through this CLI")
delete(oauthStatus, state)
}()
oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
ctx := context.Background()
// Optional project ID from query
projectID := c.Query("project_id")
log.Info("Initializing Google authentication...")
// OAuth2 configuration (mirrors internal/auth/gemini)
conf := &oauth2.Config{
ClientID: "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com",
ClientSecret: "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",
RedirectURL: "http://localhost:8085/oauth2callback",
Scopes: []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
},
Endpoint: google.Endpoint,
}
// Build authorization URL and return it immediately
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
go func() {
// Wait for callback file written by server route
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-gemini-%s.oauth", state))
log.Info("Waiting for authentication callback...")
deadline := time.Now().Add(5 * time.Minute)
var authCode string
for {
if time.Now().After(deadline) {
log.Error("oauth flow timed out")
oauthStatus[state] = "OAuth flow timed out"
return
}
if data, errR := os.ReadFile(waitFile); errR == nil {
var m map[string]string
_ = json.Unmarshal(data, &m)
_ = os.Remove(waitFile)
if errStr := m["error"]; errStr != "" {
log.Errorf("Authentication failed: %s", errStr)
oauthStatus[state] = "Authentication failed"
return
}
authCode = m["code"]
if authCode == "" {
log.Errorf("Authentication failed: code not found")
oauthStatus[state] = "Authentication failed: code not found"
return
}
break
}
time.Sleep(500 * time.Millisecond)
}
// Exchange authorization code for token
token, err := conf.Exchange(ctx, authCode)
if err != nil {
log.Errorf("Failed to exchange token: %v", err)
oauthStatus[state] = "Failed to exchange token"
return
}
// Create token storage (mirrors internal/auth/gemini createTokenStorage)
httpClient := conf.Client(ctx, token)
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
if errNewRequest != nil {
log.Errorf("Could not get user info: %v", errNewRequest)
oauthStatus[state] = "Could not get user info"
return
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
resp, errDo := httpClient.Do(req)
if errDo != nil {
log.Errorf("Failed to execute request: %v", errDo)
oauthStatus[state] = "Failed to execute request"
return
}
defer func() {
if errClose := resp.Body.Close(); errClose != nil {
log.Printf("warn: failed to close response body: %v", errClose)
}
}()
bodyBytes, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)
return
}
email := gjson.GetBytes(bodyBytes, "email").String()
if email != "" {
log.Infof("Authenticated user email: %s", email)
} else {
log.Info("Failed to get user email from token")
oauthStatus[state] = "Failed to get user email from token"
}
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
var ifToken map[string]any
jsonData, _ := json.Marshal(token)
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
oauthStatus[state] = "Failed to unmarshal token"
return
}
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
ifToken["client_id"] = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
ifToken["client_secret"] = "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"
ifToken["scopes"] = []string{
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/userinfo.email",
"https://www.googleapis.com/auth/userinfo.profile",
}
ifToken["universe_domain"] = "googleapis.com"
ts := geminiAuth.GeminiTokenStorage{
Token: ifToken,
ProjectID: projectID,
Email: email,
}
// Initialize authenticated HTTP client via GeminiAuth to honor proxy settings
gemAuth := geminiAuth.NewGeminiAuth()
_, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
if errGetClient != nil {
log.Fatalf("failed to get authenticated client: %v", errGetClient)
oauthStatus[state] = "Failed to get authenticated client"
return
}
log.Info("Authentication successful.")
record := &sdkAuth.TokenRecord{
Provider: "gemini",
FileName: fmt.Sprintf("gemini-%s.json", ts.Email),
Storage: &ts,
Metadata: map[string]string{
"email": ts.Email,
"project_id": ts.ProjectID,
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Fatalf("Failed to save token to file: %v", errSave)
oauthStatus[state] = "Failed to save token to file"
return
}
delete(oauthStatus, state)
log.Infof("You can now use Gemini CLI services through this CLI; token saved to %s", savedPath)
}()
oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) CreateGeminiWebToken(c *gin.Context) {
ctx := c.Request.Context()
var payload struct {
Secure1PSID string `json:"secure_1psid"`
Secure1PSIDTS string `json:"secure_1psidts"`
}
if err := c.ShouldBindJSON(&payload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
payload.Secure1PSID = strings.TrimSpace(payload.Secure1PSID)
payload.Secure1PSIDTS = strings.TrimSpace(payload.Secure1PSIDTS)
if payload.Secure1PSID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "secure_1psid is required"})
return
}
if payload.Secure1PSIDTS == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "secure_1psidts is required"})
return
}
sha := sha256.New()
sha.Write([]byte(payload.Secure1PSID))
hash := hex.EncodeToString(sha.Sum(nil))
fileName := fmt.Sprintf("gemini-web-%s.json", hash[:16])
tokenStorage := &geminiAuth.GeminiWebTokenStorage{
Secure1PSID: payload.Secure1PSID,
Secure1PSIDTS: payload.Secure1PSIDTS,
}
record := &sdkAuth.TokenRecord{
Provider: "gemini-web",
FileName: fileName,
Storage: tokenStorage,
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Errorf("Failed to save Gemini Web token: %v", errSave)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to save token"})
return
}
log.Infof("Successfully saved Gemini Web token to: %s", savedPath)
c.JSON(http.StatusOK, gin.H{"status": "ok", "file": filepath.Base(savedPath)})
}
func (h *Handler) RequestCodexToken(c *gin.Context) {
ctx := context.Background()
log.Info("Initializing Codex authentication...")
// Generate PKCE codes
pkceCodes, err := codex.GeneratePKCECodes()
if err != nil {
log.Fatalf("Failed to generate PKCE codes: %v", err)
return
}
// Generate random state parameter
state, err := misc.GenerateRandomState()
if err != nil {
log.Fatalf("Failed to generate state parameter: %v", err)
return
}
// Initialize Codex auth service
openaiAuth := codex.NewCodexAuth(h.cfg)
// Generate authorization URL
authURL, err := openaiAuth.GenerateAuthURL(state, pkceCodes)
if err != nil {
log.Fatalf("Failed to generate authorization URL: %v", err)
return
}
go func() {
// Wait for callback file
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-codex-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute)
var code string
for {
if time.Now().After(deadline) {
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
log.Error(codex.GetUserFriendlyMessage(authErr))
oauthStatus[state] = "Timeout waiting for OAuth callback"
return
}
if data, errR := os.ReadFile(waitFile); errR == nil {
var m map[string]string
_ = json.Unmarshal(data, &m)
_ = os.Remove(waitFile)
if errStr := m["error"]; errStr != "" {
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
log.Error(codex.GetUserFriendlyMessage(oauthErr))
oauthStatus[state] = "Bad Request"
return
}
if m["state"] != state {
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
oauthStatus[state] = "State code error"
log.Error(codex.GetUserFriendlyMessage(authErr))
return
}
code = m["code"]
break
}
time.Sleep(500 * time.Millisecond)
}
log.Debug("Authorization code received, exchanging for tokens...")
// Extract client_id from authURL
clientID := ""
if u2, errP := url.Parse(authURL); errP == nil {
clientID = u2.Query().Get("client_id")
}
// Exchange code for tokens with redirect equal to mgmtRedirect
form := url.Values{
"grant_type": {"authorization_code"},
"client_id": {clientID},
"code": {code},
"redirect_uri": {"http://localhost:1455/auth/callback"},
"code_verifier": {pkceCodes.CodeVerifier},
}
httpClient := util.SetProxy(h.cfg, &http.Client{})
req, _ := http.NewRequestWithContext(ctx, "POST", "https://auth.openai.com/oauth/token", strings.NewReader(form.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Accept", "application/json")
resp, errDo := httpClient.Do(req)
if errDo != nil {
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
oauthStatus[state] = "Failed to exchange authorization code for tokens"
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
return
}
defer func() { _ = resp.Body.Close() }()
respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK {
oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
return
}
var tokenResp struct {
AccessToken string `json:"access_token"`
RefreshToken string `json:"refresh_token"`
IDToken string `json:"id_token"`
ExpiresIn int `json:"expires_in"`
}
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
oauthStatus[state] = "Failed to parse token response"
log.Errorf("failed to parse token response: %v", errU)
return
}
claims, _ := codex.ParseJWTToken(tokenResp.IDToken)
email := ""
accountID := ""
if claims != nil {
email = claims.GetUserEmail()
accountID = claims.GetAccountID()
}
// Build bundle compatible with existing storage
bundle := &codex.CodexAuthBundle{
TokenData: codex.CodexTokenData{
IDToken: tokenResp.IDToken,
AccessToken: tokenResp.AccessToken,
RefreshToken: tokenResp.RefreshToken,
AccountID: accountID,
Email: email,
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
},
LastRefresh: time.Now().Format(time.RFC3339),
}
// Create token storage and persist
tokenStorage := openaiAuth.CreateTokenStorage(bundle)
record := &sdkAuth.TokenRecord{
Provider: "codex",
FileName: fmt.Sprintf("codex-%s.json", tokenStorage.Email),
Storage: tokenStorage,
Metadata: map[string]string{
"email": tokenStorage.Email,
"account_id": tokenStorage.AccountID,
},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
oauthStatus[state] = "Failed to save authentication tokens"
log.Fatalf("Failed to save authentication tokens: %v", errSave)
return
}
log.Infof("Authentication successful! Token saved to %s", savedPath)
if bundle.APIKey != "" {
log.Info("API key obtained and saved")
}
log.Info("You can now use Codex services through this CLI")
delete(oauthStatus, state)
}()
oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) RequestQwenToken(c *gin.Context) {
ctx := context.Background()
log.Info("Initializing Qwen authentication...")
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
// Initialize Qwen auth service
qwenAuth := qwen.NewQwenAuth(h.cfg)
// Generate authorization URL
deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx)
if err != nil {
log.Fatalf("Failed to generate authorization URL: %v", err)
return
}
authURL := deviceFlow.VerificationURIComplete
go func() {
log.Info("Waiting for authentication...")
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
if errPollForToken != nil {
oauthStatus[state] = "Authentication failed"
fmt.Printf("Authentication failed: %v\n", errPollForToken)
return
}
// Create token storage
tokenStorage := qwenAuth.CreateTokenStorage(tokenData)
tokenStorage.Email = fmt.Sprintf("qwen-%d", time.Now().UnixMilli())
record := &sdkAuth.TokenRecord{
Provider: "qwen",
FileName: fmt.Sprintf("qwen-%s.json", tokenStorage.Email),
Storage: tokenStorage,
Metadata: map[string]string{"email": tokenStorage.Email},
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Fatalf("Failed to save authentication tokens: %v", errSave)
oauthStatus[state] = "Failed to save authentication tokens"
return
}
log.Infof("Authentication successful! Token saved to %s", savedPath)
log.Info("You can now use Qwen services through this CLI")
delete(oauthStatus, state)
}()
oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
}
func (h *Handler) GetAuthStatus(c *gin.Context) {
state := c.Query("state")
if err, ok := oauthStatus[state]; ok {
if err != "" {
c.JSON(200, gin.H{"status": "error", "error": err})
} else {
c.JSON(200, gin.H{"status": "wait"})
return
}
} else {
c.JSON(200, gin.H{"status": "ok"})
}
delete(oauthStatus, state)
}

View File

@@ -0,0 +1,37 @@
package management
import (
"github.com/gin-gonic/gin"
)
func (h *Handler) GetConfig(c *gin.Context) {
c.JSON(200, h.cfg)
}
// Debug
func (h *Handler) GetDebug(c *gin.Context) { c.JSON(200, gin.H{"debug": h.cfg.Debug}) }
func (h *Handler) PutDebug(c *gin.Context) { h.updateBoolField(c, func(v bool) { h.cfg.Debug = v }) }
// Request log
func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) }
func (h *Handler) PutRequestLog(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.RequestLog = v })
}
// Request retry
func (h *Handler) GetRequestRetry(c *gin.Context) {
c.JSON(200, gin.H{"request-retry": h.cfg.RequestRetry})
}
func (h *Handler) PutRequestRetry(c *gin.Context) {
h.updateIntField(c, func(v int) { h.cfg.RequestRetry = v })
}
// Proxy URL
func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
func (h *Handler) PutProxyURL(c *gin.Context) {
h.updateStringField(c, func(v string) { h.cfg.ProxyURL = v })
}
func (h *Handler) DeleteProxyURL(c *gin.Context) {
h.cfg.ProxyURL = ""
h.persist(c)
}

View File

@@ -0,0 +1,348 @@
package management
import (
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// Generic helpers for list[string]
func (h *Handler) putStringList(c *gin.Context, set func([]string), after func()) {
data, err := c.GetRawData()
if err != nil {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
var arr []string
if err = json.Unmarshal(data, &arr); err != nil {
var obj struct {
Items []string `json:"items"`
}
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
arr = obj.Items
}
set(arr)
if after != nil {
after()
}
h.persist(c)
}
func (h *Handler) patchStringList(c *gin.Context, target *[]string, after func()) {
var body struct {
Old *string `json:"old"`
New *string `json:"new"`
Index *int `json:"index"`
Value *string `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
if body.Index != nil && body.Value != nil && *body.Index >= 0 && *body.Index < len(*target) {
(*target)[*body.Index] = *body.Value
if after != nil {
after()
}
h.persist(c)
return
}
if body.Old != nil && body.New != nil {
for i := range *target {
if (*target)[i] == *body.Old {
(*target)[i] = *body.New
if after != nil {
after()
}
h.persist(c)
return
}
}
*target = append(*target, *body.New)
if after != nil {
after()
}
h.persist(c)
return
}
c.JSON(400, gin.H{"error": "missing fields"})
}
func (h *Handler) deleteFromStringList(c *gin.Context, target *[]string, after func()) {
if idxStr := c.Query("index"); idxStr != "" {
var idx int
_, err := fmt.Sscanf(idxStr, "%d", &idx)
if err == nil && idx >= 0 && idx < len(*target) {
*target = append((*target)[:idx], (*target)[idx+1:]...)
if after != nil {
after()
}
h.persist(c)
return
}
}
if val := c.Query("value"); val != "" {
out := make([]string, 0, len(*target))
for _, v := range *target {
if v != val {
out = append(out, v)
}
}
*target = out
if after != nil {
after()
}
h.persist(c)
return
}
c.JSON(400, gin.H{"error": "missing index or value"})
}
// api-keys
func (h *Handler) GetAPIKeys(c *gin.Context) { c.JSON(200, gin.H{"api-keys": h.cfg.APIKeys}) }
func (h *Handler) PutAPIKeys(c *gin.Context) {
h.putStringList(c, func(v []string) { config.SyncInlineAPIKeys(h.cfg, v) }, nil)
}
func (h *Handler) PatchAPIKeys(c *gin.Context) {
h.patchStringList(c, &h.cfg.APIKeys, func() { config.SyncInlineAPIKeys(h.cfg, h.cfg.APIKeys) })
}
func (h *Handler) DeleteAPIKeys(c *gin.Context) {
h.deleteFromStringList(c, &h.cfg.APIKeys, func() { config.SyncInlineAPIKeys(h.cfg, h.cfg.APIKeys) })
}
// generative-language-api-key
func (h *Handler) GetGlKeys(c *gin.Context) {
c.JSON(200, gin.H{"generative-language-api-key": h.cfg.GlAPIKey})
}
func (h *Handler) PutGlKeys(c *gin.Context) {
h.putStringList(c, func(v []string) { h.cfg.GlAPIKey = v }, nil)
}
func (h *Handler) PatchGlKeys(c *gin.Context) { h.patchStringList(c, &h.cfg.GlAPIKey, nil) }
func (h *Handler) DeleteGlKeys(c *gin.Context) { h.deleteFromStringList(c, &h.cfg.GlAPIKey, nil) }
// claude-api-key: []ClaudeKey
func (h *Handler) GetClaudeKeys(c *gin.Context) {
c.JSON(200, gin.H{"claude-api-key": h.cfg.ClaudeKey})
}
func (h *Handler) PutClaudeKeys(c *gin.Context) {
data, err := c.GetRawData()
if err != nil {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
var arr []config.ClaudeKey
if err = json.Unmarshal(data, &arr); err != nil {
var obj struct {
Items []config.ClaudeKey `json:"items"`
}
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
arr = obj.Items
}
h.cfg.ClaudeKey = arr
h.persist(c)
}
func (h *Handler) PatchClaudeKey(c *gin.Context) {
var body struct {
Index *int `json:"index"`
Match *string `json:"match"`
Value *config.ClaudeKey `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) {
h.cfg.ClaudeKey[*body.Index] = *body.Value
h.persist(c)
return
}
if body.Match != nil {
for i := range h.cfg.ClaudeKey {
if h.cfg.ClaudeKey[i].APIKey == *body.Match {
h.cfg.ClaudeKey[i] = *body.Value
h.persist(c)
return
}
}
}
c.JSON(404, gin.H{"error": "item not found"})
}
func (h *Handler) DeleteClaudeKey(c *gin.Context) {
if val := c.Query("api-key"); val != "" {
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
for _, v := range h.cfg.ClaudeKey {
if v.APIKey != val {
out = append(out, v)
}
}
h.cfg.ClaudeKey = out
h.persist(c)
return
}
if idxStr := c.Query("index"); idxStr != "" {
var idx int
_, err := fmt.Sscanf(idxStr, "%d", &idx)
if err == nil && idx >= 0 && idx < len(h.cfg.ClaudeKey) {
h.cfg.ClaudeKey = append(h.cfg.ClaudeKey[:idx], h.cfg.ClaudeKey[idx+1:]...)
h.persist(c)
return
}
}
c.JSON(400, gin.H{"error": "missing api-key or index"})
}
// openai-compatibility: []OpenAICompatibility
func (h *Handler) GetOpenAICompat(c *gin.Context) {
c.JSON(200, gin.H{"openai-compatibility": h.cfg.OpenAICompatibility})
}
func (h *Handler) PutOpenAICompat(c *gin.Context) {
data, err := c.GetRawData()
if err != nil {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
var arr []config.OpenAICompatibility
if err = json.Unmarshal(data, &arr); err != nil {
var obj struct {
Items []config.OpenAICompatibility `json:"items"`
}
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
arr = obj.Items
}
h.cfg.OpenAICompatibility = arr
h.persist(c)
}
func (h *Handler) PatchOpenAICompat(c *gin.Context) {
var body struct {
Name *string `json:"name"`
Index *int `json:"index"`
Value *config.OpenAICompatibility `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
h.cfg.OpenAICompatibility[*body.Index] = *body.Value
h.persist(c)
return
}
if body.Name != nil {
for i := range h.cfg.OpenAICompatibility {
if h.cfg.OpenAICompatibility[i].Name == *body.Name {
h.cfg.OpenAICompatibility[i] = *body.Value
h.persist(c)
return
}
}
}
c.JSON(404, gin.H{"error": "item not found"})
}
func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
if name := c.Query("name"); name != "" {
out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
for _, v := range h.cfg.OpenAICompatibility {
if v.Name != name {
out = append(out, v)
}
}
h.cfg.OpenAICompatibility = out
h.persist(c)
return
}
if idxStr := c.Query("index"); idxStr != "" {
var idx int
_, err := fmt.Sscanf(idxStr, "%d", &idx)
if err == nil && idx >= 0 && idx < len(h.cfg.OpenAICompatibility) {
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:idx], h.cfg.OpenAICompatibility[idx+1:]...)
h.persist(c)
return
}
}
c.JSON(400, gin.H{"error": "missing name or index"})
}
// codex-api-key: []CodexKey
func (h *Handler) GetCodexKeys(c *gin.Context) {
c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey})
}
func (h *Handler) PutCodexKeys(c *gin.Context) {
data, err := c.GetRawData()
if err != nil {
c.JSON(400, gin.H{"error": "failed to read body"})
return
}
var arr []config.CodexKey
if err = json.Unmarshal(data, &arr); err != nil {
var obj struct {
Items []config.CodexKey `json:"items"`
}
if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
arr = obj.Items
}
h.cfg.CodexKey = arr
h.persist(c)
}
func (h *Handler) PatchCodexKey(c *gin.Context) {
var body struct {
Index *int `json:"index"`
Match *string `json:"match"`
Value *config.CodexKey `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
h.cfg.CodexKey[*body.Index] = *body.Value
h.persist(c)
return
}
if body.Match != nil {
for i := range h.cfg.CodexKey {
if h.cfg.CodexKey[i].APIKey == *body.Match {
h.cfg.CodexKey[i] = *body.Value
h.persist(c)
return
}
}
}
c.JSON(404, gin.H{"error": "item not found"})
}
func (h *Handler) DeleteCodexKey(c *gin.Context) {
if val := c.Query("api-key"); val != "" {
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
for _, v := range h.cfg.CodexKey {
if v.APIKey != val {
out = append(out, v)
}
}
h.cfg.CodexKey = out
h.persist(c)
return
}
if idxStr := c.Query("index"); idxStr != "" {
var idx int
_, err := fmt.Sscanf(idxStr, "%d", &idx)
if err == nil && idx >= 0 && idx < len(h.cfg.CodexKey) {
h.cfg.CodexKey = append(h.cfg.CodexKey[:idx], h.cfg.CodexKey[idx+1:]...)
h.persist(c)
return
}
}
c.JSON(400, gin.H{"error": "missing api-key or index"})
}

View File

@@ -0,0 +1,215 @@
// Package management provides the management API handlers and middleware
// for configuring the server and managing auth files.
package management
import (
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"golang.org/x/crypto/bcrypt"
)
type attemptInfo struct {
count int
blockedUntil time.Time
}
// Handler aggregates config reference, persistence path and helpers.
type Handler struct {
cfg *config.Config
configFilePath string
mu sync.Mutex
attemptsMu sync.Mutex
failedAttempts map[string]*attemptInfo // keyed by client IP
authManager *coreauth.Manager
usageStats *usage.RequestStatistics
tokenStore sdkAuth.TokenStore
}
// NewHandler creates a new management handler instance.
func NewHandler(cfg *config.Config, configFilePath string, manager *coreauth.Manager) *Handler {
return &Handler{
cfg: cfg,
configFilePath: configFilePath,
failedAttempts: make(map[string]*attemptInfo),
authManager: manager,
usageStats: usage.GetRequestStatistics(),
tokenStore: sdkAuth.GetTokenStore(),
}
}
// SetConfig updates the in-memory config reference when the server hot-reloads.
func (h *Handler) SetConfig(cfg *config.Config) { h.cfg = cfg }
// SetAuthManager updates the auth manager reference used by management endpoints.
func (h *Handler) SetAuthManager(manager *coreauth.Manager) { h.authManager = manager }
// SetUsageStatistics allows replacing the usage statistics reference.
func (h *Handler) SetUsageStatistics(stats *usage.RequestStatistics) { h.usageStats = stats }
// Middleware enforces access control for management endpoints.
// All requests (local and remote) require a valid management key.
// Additionally, remote access requires allow-remote-management=true.
func (h *Handler) Middleware() gin.HandlerFunc {
const maxFailures = 5
const banDuration = 30 * time.Minute
return func(c *gin.Context) {
clientIP := c.ClientIP()
// For remote IPs, enforce allow-remote-management and ban checks
if !(clientIP == "127.0.0.1" || clientIP == "::1") {
// Check if IP is currently blocked
h.attemptsMu.Lock()
ai := h.failedAttempts[clientIP]
if ai != nil {
if !ai.blockedUntil.IsZero() {
if time.Now().Before(ai.blockedUntil) {
remaining := time.Until(ai.blockedUntil).Round(time.Second)
h.attemptsMu.Unlock()
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": fmt.Sprintf("IP banned due to too many failed attempts. Try again in %s", remaining)})
return
}
// Ban expired, reset state
ai.blockedUntil = time.Time{}
ai.count = 0
}
}
h.attemptsMu.Unlock()
allowRemote := h.cfg.RemoteManagement.AllowRemote
if !allowRemote {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management disabled"})
return
}
}
secret := h.cfg.RemoteManagement.SecretKey
if secret == "" {
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{"error": "remote management key not set"})
return
}
// Accept either Authorization: Bearer <key> or X-Management-Key
var provided string
if ah := c.GetHeader("Authorization"); ah != "" {
parts := strings.SplitN(ah, " ", 2)
if len(parts) == 2 && strings.ToLower(parts[0]) == "bearer" {
provided = parts[1]
} else {
provided = ah
}
}
if provided == "" {
provided = c.GetHeader("X-Management-Key")
}
if !(clientIP == "127.0.0.1" || clientIP == "::1") {
// For remote IPs, enforce key and track failures
fail := func() {
h.attemptsMu.Lock()
ai := h.failedAttempts[clientIP]
if ai == nil {
ai = &attemptInfo{}
h.failedAttempts[clientIP] = ai
}
ai.count++
if ai.count >= maxFailures {
ai.blockedUntil = time.Now().Add(banDuration)
ai.count = 0
}
h.attemptsMu.Unlock()
}
if provided == "" {
fail()
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "missing management key"})
return
}
if err := bcrypt.CompareHashAndPassword([]byte(secret), []byte(provided)); err != nil {
fail()
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "invalid management key"})
return
}
// Success: reset failed count for this IP
h.attemptsMu.Lock()
if ai := h.failedAttempts[clientIP]; ai != nil {
ai.count = 0
ai.blockedUntil = time.Time{}
}
h.attemptsMu.Unlock()
}
c.Next()
}
}
// persist saves the current in-memory config to disk.
func (h *Handler) persist(c *gin.Context) bool {
h.mu.Lock()
defer h.mu.Unlock()
// Preserve comments when writing
if err := config.SaveConfigPreserveComments(h.configFilePath, h.cfg); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to save config: %v", err)})
return false
}
c.JSON(http.StatusOK, gin.H{"status": "ok"})
return true
}
// Helper methods for simple types
func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) {
var body struct {
Value *bool `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
var m map[string]any
if err2 := c.ShouldBindJSON(&m); err2 == nil {
for _, v := range m {
if b, ok := v.(bool); ok {
set(b)
h.persist(c)
return
}
}
}
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
set(*body.Value)
h.persist(c)
}
func (h *Handler) updateIntField(c *gin.Context, set func(int)) {
var body struct {
Value *int `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
set(*body.Value)
h.persist(c)
}
func (h *Handler) updateStringField(c *gin.Context, set func(string)) {
var body struct {
Value *string `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}
set(*body.Value)
h.persist(c)
}

View File

@@ -0,0 +1,18 @@
package management
import "github.com/gin-gonic/gin"
// Quota exceeded toggles
func (h *Handler) GetSwitchProject(c *gin.Context) {
c.JSON(200, gin.H{"switch-project": h.cfg.QuotaExceeded.SwitchProject})
}
func (h *Handler) PutSwitchProject(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.QuotaExceeded.SwitchProject = v })
}
func (h *Handler) GetSwitchPreviewModel(c *gin.Context) {
c.JSON(200, gin.H{"switch-preview-model": h.cfg.QuotaExceeded.SwitchPreviewModel})
}
func (h *Handler) PutSwitchPreviewModel(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.QuotaExceeded.SwitchPreviewModel = v })
}

View File

@@ -0,0 +1,17 @@
package management
import (
"net/http"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
)
// GetUsageStatistics returns the in-memory request statistics snapshot.
func (h *Handler) GetUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
if h != nil && h.usageStats != nil {
snapshot = h.usageStats.Snapshot()
}
c.JSON(http.StatusOK, gin.H{"usage": snapshot})
}

View File

@@ -0,0 +1,568 @@
// Package openai provides HTTP handlers for OpenAI API endpoints.
// This package implements the OpenAI-compatible API interface, including model listing
// and chat completion functionality. It supports both streaming and non-streaming responses,
// and manages a pool of clients to interact with backend services.
// The handlers translate OpenAI API requests to the appropriate backend format and
// convert responses back to OpenAI-compatible format.
package openai
import (
"context"
"encoding/json"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// OpenAIAPIHandler contains the handlers for OpenAI API endpoints.
// It holds a pool of clients to interact with the backend service.
type OpenAIAPIHandler struct {
*handlers.BaseAPIHandler
}
// NewOpenAIAPIHandler creates a new OpenAI API handlers instance.
// It takes an BaseAPIHandler instance as input and returns an OpenAIAPIHandler.
//
// Parameters:
// - apiHandlers: The base API handlers instance
//
// Returns:
// - *OpenAIAPIHandler: A new OpenAI API handlers instance
func NewOpenAIAPIHandler(apiHandlers *handlers.BaseAPIHandler) *OpenAIAPIHandler {
return &OpenAIAPIHandler{
BaseAPIHandler: apiHandlers,
}
}
// HandlerType returns the identifier for this handler implementation.
func (h *OpenAIAPIHandler) HandlerType() string {
return OpenAI
}
// Models returns the OpenAI-compatible model metadata supported by this handler.
func (h *OpenAIAPIHandler) Models() []map[string]any {
// Get dynamic models from the global registry
modelRegistry := registry.GetGlobalRegistry()
return modelRegistry.GetAvailableModels("openai")
}
// OpenAIModels handles the /v1/models endpoint.
// It returns a list of available AI models with their capabilities
// and specifications in OpenAI-compatible format.
func (h *OpenAIAPIHandler) OpenAIModels(c *gin.Context) {
// Get all available models
allModels := h.Models()
// Filter to only include the 4 required fields: id, object, created, owned_by
filteredModels := make([]map[string]any, len(allModels))
for i, model := range allModels {
filteredModel := map[string]any{
"id": model["id"],
"object": model["object"],
}
// Add created field if it exists
if created, exists := model["created"]; exists {
filteredModel["created"] = created
}
// Add owned_by field if it exists
if ownedBy, exists := model["owned_by"]; exists {
filteredModel["owned_by"] = ownedBy
}
filteredModels[i] = filteredModel
}
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": filteredModels,
})
}
// ChatCompletions handles the /v1/chat/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler based on the model provider.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
func (h *OpenAIAPIHandler) ChatCompletions(c *gin.Context) {
rawJSON, err := c.GetRawData()
// If data retrieval fails, return a 400 Bad Request error.
if err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
// Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJSON, "stream")
if streamResult.Type == gjson.True {
h.handleStreamingResponse(c, rawJSON)
} else {
h.handleNonStreamingResponse(c, rawJSON)
}
}
// Completions handles the /v1/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler based on the model provider.
// This endpoint follows the OpenAI completions API specification.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
func (h *OpenAIAPIHandler) Completions(c *gin.Context) {
rawJSON, err := c.GetRawData()
// If data retrieval fails, return a 400 Bad Request error.
if err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
// Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJSON, "stream")
if streamResult.Type == gjson.True {
h.handleCompletionsStreamingResponse(c, rawJSON)
} else {
h.handleCompletionsNonStreamingResponse(c, rawJSON)
}
}
// convertCompletionsRequestToChatCompletions converts OpenAI completions API request to chat completions format.
// This allows the completions endpoint to use the existing chat completions infrastructure.
//
// Parameters:
// - rawJSON: The raw JSON bytes of the completions request
//
// Returns:
// - []byte: The converted chat completions request
func convertCompletionsRequestToChatCompletions(rawJSON []byte) []byte {
root := gjson.ParseBytes(rawJSON)
// Extract prompt from completions request
prompt := root.Get("prompt").String()
if prompt == "" {
prompt = "Complete this:"
}
// Create chat completions structure
out := `{"model":"","messages":[{"role":"user","content":""}]}`
// Set model
if model := root.Get("model"); model.Exists() {
out, _ = sjson.Set(out, "model", model.String())
}
// Set the prompt as user message content
out, _ = sjson.Set(out, "messages.0.content", prompt)
// Copy other parameters from completions to chat completions
if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
}
if temperature := root.Get("temperature"); temperature.Exists() {
out, _ = sjson.Set(out, "temperature", temperature.Float())
}
if topP := root.Get("top_p"); topP.Exists() {
out, _ = sjson.Set(out, "top_p", topP.Float())
}
if frequencyPenalty := root.Get("frequency_penalty"); frequencyPenalty.Exists() {
out, _ = sjson.Set(out, "frequency_penalty", frequencyPenalty.Float())
}
if presencePenalty := root.Get("presence_penalty"); presencePenalty.Exists() {
out, _ = sjson.Set(out, "presence_penalty", presencePenalty.Float())
}
if stop := root.Get("stop"); stop.Exists() {
out, _ = sjson.SetRaw(out, "stop", stop.Raw)
}
if stream := root.Get("stream"); stream.Exists() {
out, _ = sjson.Set(out, "stream", stream.Bool())
}
if logprobs := root.Get("logprobs"); logprobs.Exists() {
out, _ = sjson.Set(out, "logprobs", logprobs.Bool())
}
if topLogprobs := root.Get("top_logprobs"); topLogprobs.Exists() {
out, _ = sjson.Set(out, "top_logprobs", topLogprobs.Int())
}
if echo := root.Get("echo"); echo.Exists() {
out, _ = sjson.Set(out, "echo", echo.Bool())
}
return []byte(out)
}
// convertChatCompletionsResponseToCompletions converts chat completions API response back to completions format.
// This ensures the completions endpoint returns data in the expected format.
//
// Parameters:
// - rawJSON: The raw JSON bytes of the chat completions response
//
// Returns:
// - []byte: The converted completions response
func convertChatCompletionsResponseToCompletions(rawJSON []byte) []byte {
root := gjson.ParseBytes(rawJSON)
// Base completions response structure
out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`
// Copy basic fields
if id := root.Get("id"); id.Exists() {
out, _ = sjson.Set(out, "id", id.String())
}
if created := root.Get("created"); created.Exists() {
out, _ = sjson.Set(out, "created", created.Int())
}
if model := root.Get("model"); model.Exists() {
out, _ = sjson.Set(out, "model", model.String())
}
if usage := root.Get("usage"); usage.Exists() {
out, _ = sjson.SetRaw(out, "usage", usage.Raw)
}
// Convert choices from chat completions to completions format
var choices []interface{}
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
chatChoices.ForEach(func(_, choice gjson.Result) bool {
completionsChoice := map[string]interface{}{
"index": choice.Get("index").Int(),
}
// Extract text content from message.content
if message := choice.Get("message"); message.Exists() {
if content := message.Get("content"); content.Exists() {
completionsChoice["text"] = content.String()
}
} else if delta := choice.Get("delta"); delta.Exists() {
// For streaming responses, use delta.content
if content := delta.Get("content"); content.Exists() {
completionsChoice["text"] = content.String()
}
}
// Copy finish_reason
if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
completionsChoice["finish_reason"] = finishReason.String()
}
// Copy logprobs if present
if logprobs := choice.Get("logprobs"); logprobs.Exists() {
completionsChoice["logprobs"] = logprobs.Value()
}
choices = append(choices, completionsChoice)
return true
})
}
if len(choices) > 0 {
choicesJSON, _ := json.Marshal(choices)
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
}
return []byte(out)
}
// convertChatCompletionsStreamChunkToCompletions converts a streaming chat completions chunk to completions format.
// This handles the real-time conversion of streaming response chunks and filters out empty text responses.
//
// Parameters:
// - chunkData: The raw JSON bytes of a single chat completions stream chunk
//
// Returns:
// - []byte: The converted completions stream chunk, or nil if should be filtered out
func convertChatCompletionsStreamChunkToCompletions(chunkData []byte) []byte {
root := gjson.ParseBytes(chunkData)
// Check if this chunk has any meaningful content
hasContent := false
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
chatChoices.ForEach(func(_, choice gjson.Result) bool {
// Check if delta has content or finish_reason
if delta := choice.Get("delta"); delta.Exists() {
if content := delta.Get("content"); content.Exists() && content.String() != "" {
hasContent = true
return false // Break out of forEach
}
}
// Also check for finish_reason to ensure we don't skip final chunks
if finishReason := choice.Get("finish_reason"); finishReason.Exists() && finishReason.String() != "" && finishReason.String() != "null" {
hasContent = true
return false // Break out of forEach
}
return true
})
}
// If no meaningful content, return nil to indicate this chunk should be skipped
if !hasContent {
return nil
}
// Base completions stream response structure
out := `{"id":"","object":"text_completion","created":0,"model":"","choices":[]}`
// Copy basic fields
if id := root.Get("id"); id.Exists() {
out, _ = sjson.Set(out, "id", id.String())
}
if created := root.Get("created"); created.Exists() {
out, _ = sjson.Set(out, "created", created.Int())
}
if model := root.Get("model"); model.Exists() {
out, _ = sjson.Set(out, "model", model.String())
}
// Convert choices from chat completions delta to completions format
var choices []interface{}
if chatChoices := root.Get("choices"); chatChoices.Exists() && chatChoices.IsArray() {
chatChoices.ForEach(func(_, choice gjson.Result) bool {
completionsChoice := map[string]interface{}{
"index": choice.Get("index").Int(),
}
// Extract text content from delta.content
if delta := choice.Get("delta"); delta.Exists() {
if content := delta.Get("content"); content.Exists() && content.String() != "" {
completionsChoice["text"] = content.String()
} else {
completionsChoice["text"] = ""
}
} else {
completionsChoice["text"] = ""
}
// Copy finish_reason
if finishReason := choice.Get("finish_reason"); finishReason.Exists() && finishReason.String() != "null" {
completionsChoice["finish_reason"] = finishReason.String()
}
// Copy logprobs if present
if logprobs := choice.Get("logprobs"); logprobs.Exists() {
completionsChoice["logprobs"] = logprobs.Value()
}
choices = append(choices, completionsChoice)
return true
})
}
if len(choices) > 0 {
choicesJSON, _ := json.Marshal(choices)
out, _ = sjson.SetRaw(out, "choices", string(choicesJSON))
}
return []byte(out)
}
// handleNonStreamingResponse handles non-streaming chat completion responses
// for Gemini models. It selects a client from the pool, sends the request, and
// aggregates the response before sending it back to the client in OpenAI format.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json")
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
_, _ = c.Writer.Write(resp)
cliCancel()
}
// handleStreamingResponse handles streaming responses for Gemini models.
// It establishes a streaming connection with the backend service and forwards
// the response chunks to the client in real-time using Server-Sent Events.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
}
// handleCompletionsNonStreamingResponse handles non-streaming completions responses.
// It converts completions request to chat completions format, sends to backend,
// then converts the response back to completions format before sending to client.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json")
// Convert completions request to chat completions format
chatCompletionsJSON := convertCompletionsRequestToChatCompletions(rawJSON)
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
cliCancel(errMsg.Error)
return
}
completionsResp := convertChatCompletionsResponseToCompletions(resp)
_, _ = c.Writer.Write(completionsResp)
cliCancel()
}
// handleCompletionsStreamingResponse handles streaming completions responses.
// It converts completions request to chat completions format, streams from backend,
// then converts each response chunk back to completions format before sending to client.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
// Convert completions request to chat completions format
chatCompletionsJSON := convertCompletionsRequestToChatCompletions(rawJSON)
modelName := gjson.GetBytes(chatCompletionsJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
for {
select {
case <-c.Request.Context().Done():
cliCancel(c.Request.Context().Err())
return
case chunk, isOk := <-dataChan:
if !isOk {
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel()
return
}
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
if converted != nil {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted))
flusher.Flush()
}
case errMsg, isOk := <-errChan:
if !isOk {
continue
}
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
flusher.Flush()
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cliCancel(execErr)
return
case <-time.After(500 * time.Millisecond):
}
}
}
func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return
case chunk, ok := <-data:
if !ok {
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cancel(nil)
return
}
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
flusher.Flush()
case errMsg, ok := <-errs:
if !ok {
continue
}
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
flusher.Flush()
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cancel(execErr)
return
case <-time.After(500 * time.Millisecond):
}
}
}

View File

@@ -0,0 +1,194 @@
// Package openai provides HTTP handlers for OpenAIResponses API endpoints.
// This package implements the OpenAIResponses-compatible API interface, including model listing
// and chat completion functionality. It supports both streaming and non-streaming responses,
// and manages a pool of clients to interact with backend services.
// The handlers translate OpenAIResponses API requests to the appropriate backend format and
// convert responses back to OpenAIResponses-compatible format.
package openai
import (
"bytes"
"context"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/tidwall/gjson"
)
// OpenAIResponsesAPIHandler contains the handlers for OpenAIResponses API endpoints.
// It holds a pool of clients to interact with the backend service.
type OpenAIResponsesAPIHandler struct {
*handlers.BaseAPIHandler
}
// NewOpenAIResponsesAPIHandler creates a new OpenAIResponses API handlers instance.
// It takes an BaseAPIHandler instance as input and returns an OpenAIResponsesAPIHandler.
//
// Parameters:
// - apiHandlers: The base API handlers instance
//
// Returns:
// - *OpenAIResponsesAPIHandler: A new OpenAIResponses API handlers instance
func NewOpenAIResponsesAPIHandler(apiHandlers *handlers.BaseAPIHandler) *OpenAIResponsesAPIHandler {
return &OpenAIResponsesAPIHandler{
BaseAPIHandler: apiHandlers,
}
}
// HandlerType returns the identifier for this handler implementation.
func (h *OpenAIResponsesAPIHandler) HandlerType() string {
return OpenaiResponse
}
// Models returns the OpenAIResponses-compatible model metadata supported by this handler.
func (h *OpenAIResponsesAPIHandler) Models() []map[string]any {
// Get dynamic models from the global registry
modelRegistry := registry.GetGlobalRegistry()
return modelRegistry.GetAvailableModels("openai")
}
// OpenAIResponsesModels handles the /v1/models endpoint.
// It returns a list of available AI models with their capabilities
// and specifications in OpenAIResponses-compatible format.
func (h *OpenAIResponsesAPIHandler) OpenAIResponsesModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"object": "list",
"data": h.Models(),
})
}
// Responses handles the /v1/responses endpoint.
// It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler based on the model provider.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
func (h *OpenAIResponsesAPIHandler) Responses(c *gin.Context) {
rawJSON, err := c.GetRawData()
// If data retrieval fails, return a 400 Bad Request error.
if err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
// Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJSON, "stream")
if streamResult.Type == gjson.True {
h.handleStreamingResponse(c, rawJSON)
} else {
h.handleNonStreamingResponse(c, rawJSON)
}
}
// handleNonStreamingResponse handles non-streaming chat completion responses
// for Gemini models. It selects a client from the pool, sends the request, and
// aggregates the response before sending it back to the client in OpenAIResponses format.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAIResponses-compatible request
func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json")
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
defer func() {
cliCancel()
}()
resp, errMsg := h.ExecuteWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
return
}
_, _ = c.Writer.Write(resp)
return
// no legacy fallback
}
// handleStreamingResponse handles streaming responses for Gemini models.
// It establishes a streaming connection with the backend service and forwards
// the response chunks to the client in real-time using Server-Sent Events.
//
// Parameters:
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAIResponses-compatible request
func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
// New core execution path
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
return
}
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return
case chunk, ok := <-data:
if !ok {
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
cancel(nil)
return
}
if bytes.HasPrefix(chunk, []byte("event:")) {
_, _ = c.Writer.Write([]byte("\n"))
}
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
case errMsg, ok := <-errs:
if !ok {
continue
}
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
flusher.Flush()
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cancel(execErr)
return
case <-time.After(500 * time.Millisecond):
}
}
}

View File

@@ -0,0 +1,92 @@
// Package middleware provides HTTP middleware components for the CLI Proxy API server.
// This file contains the request logging middleware that captures comprehensive
// request and response data when enabled through configuration.
package middleware
import (
"bytes"
"io"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
)
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
// It captures detailed information about the request and response, including headers and body,
// and uses the provided RequestLogger to record this data. If logging is disabled in the
// logger, the middleware has minimal overhead.
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return func(c *gin.Context) {
// Early return if logging is disabled (zero overhead)
if !logger.IsEnabled() {
c.Next()
return
}
// Capture request information
requestInfo, err := captureRequestInfo(c)
if err != nil {
// Log error but continue processing
// In a real implementation, you might want to use a proper logger here
c.Next()
return
}
// Create response writer wrapper
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
c.Writer = wrapper
// Process the request
c.Next()
// Finalize logging after request processing
if err = wrapper.Finalize(c); err != nil {
// Log error but don't interrupt the response
// In a real implementation, you might want to use a proper logger here
}
}
}
// captureRequestInfo extracts relevant information from the incoming HTTP request.
// It captures the URL, method, headers, and body. The request body is read and then
// restored so that it can be processed by subsequent handlers.
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
// Capture URL
url := c.Request.URL.String()
if c.Request.URL.Path != "" {
url = c.Request.URL.Path
if c.Request.URL.RawQuery != "" {
url += "?" + c.Request.URL.RawQuery
}
}
// Capture method
method := c.Request.Method
// Capture headers
headers := make(map[string][]string)
for key, values := range c.Request.Header {
headers[key] = values
}
// Capture request body
var body []byte
if c.Request.Body != nil {
// Read the body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
return nil, err
}
// Restore the body for the actual request processing
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
body = bodyBytes
}
return &RequestInfo{
URL: url,
Method: method,
Headers: headers,
Body: body,
}, nil
}

View File

@@ -0,0 +1,309 @@
// Package middleware provides Gin HTTP middleware for the CLI Proxy API server.
// It includes a sophisticated response writer wrapper designed to capture and log request and response data,
// including support for streaming responses, without impacting latency.
package middleware
import (
"bytes"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
)
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct {
URL string // URL is the request URL.
Method string // Method is the HTTP method (e.g., GET, POST).
Headers map[string][]string // Headers contains the request headers.
Body []byte // Body is the raw request body.
}
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
// It is designed to handle both standard and streaming responses, ensuring that logging operations do not block the client response.
type ResponseWriterWrapper struct {
gin.ResponseWriter
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
logger logging.RequestLogger // logger is the instance of the request logger service.
requestInfo *RequestInfo // requestInfo holds the details of the original request.
statusCode int // statusCode stores the HTTP status code of the response.
headers map[string][]string // headers stores the response headers.
}
// NewResponseWriterWrapper creates and initializes a new ResponseWriterWrapper.
// It takes the original gin.ResponseWriter, a logger instance, and request information.
//
// Parameters:
// - w: The original gin.ResponseWriter to wrap.
// - logger: The logging service to use for recording requests.
// - requestInfo: The pre-captured information about the incoming request.
//
// Returns:
// - A pointer to a new ResponseWriterWrapper.
func NewResponseWriterWrapper(w gin.ResponseWriter, logger logging.RequestLogger, requestInfo *RequestInfo) *ResponseWriterWrapper {
return &ResponseWriterWrapper{
ResponseWriter: w,
body: &bytes.Buffer{},
logger: logger,
requestInfo: requestInfo,
headers: make(map[string][]string),
}
}
// Write wraps the underlying ResponseWriter's Write method to capture response data.
// For non-streaming responses, it writes to an internal buffer. For streaming responses,
// it sends data chunks to a non-blocking channel for asynchronous logging.
// CRITICAL: This method prioritizes writing to the client to ensure zero latency,
// handling logging operations subsequently.
func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
// Ensure headers are captured before first write
// This is critical because Write() may trigger WriteHeader() internally
w.ensureHeadersCaptured()
// CRITICAL: Write to client first (zero latency)
n, err := w.ResponseWriter.Write(data)
// THEN: Handle logging based on response type
if w.isStreaming {
// For streaming responses: Send to async logging channel (non-blocking)
if w.chunkChannel != nil {
select {
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
default: // Channel full, skip logging to avoid blocking
}
}
} else {
// For non-streaming responses: Buffer complete response
w.body.Write(data)
}
return n, err
}
// WriteHeader wraps the underlying ResponseWriter's WriteHeader method.
// It captures the status code, detects if the response is streaming based on the Content-Type header,
// and initializes the appropriate logging mechanism (standard or streaming).
func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
w.statusCode = statusCode
// Capture response headers using the new method
w.captureCurrentHeaders()
// Detect streaming based on Content-Type
contentType := w.ResponseWriter.Header().Get("Content-Type")
w.isStreaming = w.detectStreaming(contentType)
// If streaming, initialize streaming log writer
if w.isStreaming && w.logger.IsEnabled() {
streamWriter, err := w.logger.LogStreamingRequest(
w.requestInfo.URL,
w.requestInfo.Method,
w.requestInfo.Headers,
w.requestInfo.Body,
)
if err == nil {
w.streamWriter = streamWriter
w.chunkChannel = make(chan []byte, 100) // Buffered channel for async writes
doneChan := make(chan struct{})
w.streamDone = doneChan
// Start async chunk processor
go w.processStreamingChunks(doneChan)
// Write status immediately
_ = streamWriter.WriteStatus(statusCode, w.headers)
}
}
// Call original WriteHeader
w.ResponseWriter.WriteHeader(statusCode)
}
// ensureHeadersCaptured is a helper function to make sure response headers are captured.
// It is safe to call this method multiple times; it will always refresh the headers
// with the latest state from the underlying ResponseWriter.
func (w *ResponseWriterWrapper) ensureHeadersCaptured() {
// Always capture the current headers to ensure we have the latest state
w.captureCurrentHeaders()
}
// captureCurrentHeaders reads all headers from the underlying ResponseWriter and stores them
// in the wrapper's headers map. It creates copies of the header values to prevent race conditions.
func (w *ResponseWriterWrapper) captureCurrentHeaders() {
// Initialize headers map if needed
if w.headers == nil {
w.headers = make(map[string][]string)
}
// Capture all current headers from the underlying ResponseWriter
for key, values := range w.ResponseWriter.Header() {
// Make a copy of the values slice to avoid reference issues
headerValues := make([]string, len(values))
copy(headerValues, values)
w.headers[key] = headerValues
}
}
// detectStreaming determines if a response should be treated as a streaming response.
// It checks for a "text/event-stream" Content-Type or a '"stream": true'
// field in the original request body.
func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
// Check Content-Type for Server-Sent Events
if strings.Contains(contentType, "text/event-stream") {
return true
}
// Check request body for streaming indicators
if w.requestInfo.Body != nil {
bodyStr := string(w.requestInfo.Body)
if strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`) {
return true
}
}
return false
}
// processStreamingChunks runs in a separate goroutine to process response chunks from the chunkChannel.
// It asynchronously writes each chunk to the streaming log writer.
func (w *ResponseWriterWrapper) processStreamingChunks(done chan struct{}) {
if done == nil {
return
}
defer close(done)
if w.streamWriter == nil || w.chunkChannel == nil {
return
}
for chunk := range w.chunkChannel {
w.streamWriter.WriteChunkAsync(chunk)
}
}
// Finalize completes the logging process for the request and response.
// For streaming responses, it closes the chunk channel and the stream writer.
// For non-streaming responses, it logs the complete request and response details,
// including any API-specific request/response data stored in the Gin context.
func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
if !w.logger.IsEnabled() {
return nil
}
if w.isStreaming {
// Close streaming channel and writer
if w.chunkChannel != nil {
close(w.chunkChannel)
w.chunkChannel = nil
}
if w.streamDone != nil {
<-w.streamDone
w.streamDone = nil
}
if w.streamWriter != nil {
err := w.streamWriter.Close()
w.streamWriter = nil
return err
}
} else {
// Capture final status code and headers if not already captured
finalStatusCode := w.statusCode
if finalStatusCode == 0 {
// Get status from underlying ResponseWriter if available
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
finalStatusCode = statusWriter.Status()
} else {
finalStatusCode = 200 // Default
}
}
// Ensure we have the latest headers before finalizing
w.ensureHeadersCaptured()
// Use the captured headers as the final headers
finalHeaders := make(map[string][]string)
for key, values := range w.headers {
// Make a copy of the values slice to avoid reference issues
headerValues := make([]string, len(values))
copy(headerValues, values)
finalHeaders[key] = headerValues
}
var apiRequestBody []byte
apiRequest, isExist := c.Get("API_REQUEST")
if isExist {
var ok bool
apiRequestBody, ok = apiRequest.([]byte)
if !ok {
apiRequestBody = nil
}
}
var apiResponseBody []byte
apiResponse, isExist := c.Get("API_RESPONSE")
if isExist {
var ok bool
apiResponseBody, ok = apiResponse.([]byte)
if !ok {
apiResponseBody = nil
}
}
var slicesAPIResponseError []*interfaces.ErrorMessage
apiResponseError, isExist := c.Get("API_RESPONSE_ERROR")
if isExist {
var ok bool
slicesAPIResponseError, ok = apiResponseError.([]*interfaces.ErrorMessage)
if !ok {
slicesAPIResponseError = nil
}
}
// Log complete non-streaming response
return w.logger.LogRequest(
w.requestInfo.URL,
w.requestInfo.Method,
w.requestInfo.Headers,
w.requestInfo.Body,
finalStatusCode,
finalHeaders,
w.body.Bytes(),
apiRequestBody,
apiResponseBody,
slicesAPIResponseError,
)
}
return nil
}
// Status returns the HTTP response status code captured by the wrapper.
// It defaults to 200 if WriteHeader has not been called.
func (w *ResponseWriterWrapper) Status() int {
if w.statusCode == 0 {
return 200 // Default status code
}
return w.statusCode
}
// Size returns the size of the response body in bytes for non-streaming responses.
// For streaming responses, it returns -1, as the total size is unknown.
func (w *ResponseWriterWrapper) Size() int {
if w.isStreaming {
return -1 // Unknown size for streaming responses
}
return w.body.Len()
}
// Written returns true if the response header has been written (i.e., a status code has been set).
func (w *ResponseWriterWrapper) Written() bool {
return w.statusCode != 0
}

516
internal/api/server.go Normal file
View File

@@ -0,0 +1,516 @@
// Package api provides the HTTP API server implementation for the CLI Proxy API.
// It includes the main server struct, routing setup, middleware for CORS and authentication,
// and integration with various AI API handlers (OpenAI, Claude, Gemini).
// The server supports hot-reloading of clients and configuration.
package api
import (
"context"
"errors"
"fmt"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/claude"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/gemini"
managementHandlers "github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/openai"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/middleware"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
type serverOptionConfig struct {
extraMiddleware []gin.HandlerFunc
engineConfigurator func(*gin.Engine)
routerConfigurator func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)
requestLoggerFactory func(*config.Config, string) logging.RequestLogger
}
// ServerOption customises HTTP server construction.
type ServerOption func(*serverOptionConfig)
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
return logging.NewFileRequestLogger(cfg.RequestLog, "logs", filepath.Dir(configPath))
}
// WithMiddleware appends additional Gin middleware during server construction.
func WithMiddleware(mw ...gin.HandlerFunc) ServerOption {
return func(cfg *serverOptionConfig) {
cfg.extraMiddleware = append(cfg.extraMiddleware, mw...)
}
}
// WithEngineConfigurator allows callers to mutate the Gin engine prior to middleware setup.
func WithEngineConfigurator(fn func(*gin.Engine)) ServerOption {
return func(cfg *serverOptionConfig) {
cfg.engineConfigurator = fn
}
}
// WithRouterConfigurator appends a callback after default routes are registered.
func WithRouterConfigurator(fn func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)) ServerOption {
return func(cfg *serverOptionConfig) {
cfg.routerConfigurator = fn
}
}
// WithRequestLoggerFactory customises request logger creation.
func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption {
return func(cfg *serverOptionConfig) {
cfg.requestLoggerFactory = factory
}
}
// Server represents the main API server.
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
type Server struct {
// engine is the Gin web framework engine instance.
engine *gin.Engine
// server is the underlying HTTP server.
server *http.Server
// handlers contains the API handlers for processing requests.
handlers *handlers.BaseAPIHandler
// cfg holds the current server configuration.
cfg *config.Config
// accessManager handles request authentication providers.
accessManager *sdkaccess.Manager
// requestLogger is the request logger instance for dynamic configuration updates.
requestLogger logging.RequestLogger
loggerToggle func(bool)
// configFilePath is the absolute path to the YAML config file for persistence.
configFilePath string
// management handler
mgmt *managementHandlers.Handler
}
// NewServer creates and initializes a new API server instance.
// It sets up the Gin engine, middleware, routes, and handlers.
//
// Parameters:
// - cfg: The server configuration
// - authManager: core runtime auth manager
// - accessManager: request authentication manager
//
// Returns:
// - *Server: A new server instance
func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdkaccess.Manager, configFilePath string, opts ...ServerOption) *Server {
optionState := &serverOptionConfig{
requestLoggerFactory: defaultRequestLoggerFactory,
}
for i := range opts {
opts[i](optionState)
}
// Set gin mode
if !cfg.Debug {
gin.SetMode(gin.ReleaseMode)
}
// Create gin engine
engine := gin.New()
if optionState.engineConfigurator != nil {
optionState.engineConfigurator(engine)
}
// Add middleware
engine.Use(logging.GinLogrusLogger())
engine.Use(logging.GinLogrusRecovery())
for _, mw := range optionState.extraMiddleware {
engine.Use(mw)
}
// Add request logging middleware (positioned after recovery, before auth)
// Resolve logs directory relative to the configuration file directory.
var requestLogger logging.RequestLogger
var toggle func(bool)
if optionState.requestLoggerFactory != nil {
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
}
if requestLogger != nil {
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
toggle = setter.SetEnabled
}
}
engine.Use(corsMiddleware())
// Create server instance
s := &Server{
engine: engine,
handlers: handlers.NewBaseAPIHandlers(cfg, authManager),
cfg: cfg,
accessManager: accessManager,
requestLogger: requestLogger,
loggerToggle: toggle,
configFilePath: configFilePath,
}
s.applyAccessConfig(cfg)
// Initialize management handler
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
// Setup routes
s.setupRoutes()
if optionState.routerConfigurator != nil {
optionState.routerConfigurator(engine, s.handlers, cfg)
}
// Create HTTP server
s.server = &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Port),
Handler: engine,
}
return s
}
// setupRoutes configures the API routes for the server.
// It defines the endpoints and associates them with their respective handlers.
func (s *Server) setupRoutes() {
openaiHandlers := openai.NewOpenAIAPIHandler(s.handlers)
geminiHandlers := gemini.NewGeminiAPIHandler(s.handlers)
geminiCLIHandlers := gemini.NewGeminiCLIAPIHandler(s.handlers)
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(s.handlers)
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(s.handlers)
// OpenAI compatible API routes
v1 := s.engine.Group("/v1")
v1.Use(AuthMiddleware(s.accessManager))
{
v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers))
v1.POST("/chat/completions", openaiHandlers.ChatCompletions)
v1.POST("/completions", openaiHandlers.Completions)
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
v1.POST("/responses", openaiResponsesHandlers.Responses)
}
// Gemini compatible API routes
v1beta := s.engine.Group("/v1beta")
v1beta.Use(AuthMiddleware(s.accessManager))
{
v1beta.GET("/models", geminiHandlers.GeminiModels)
v1beta.POST("/models/:action", geminiHandlers.GeminiHandler)
v1beta.GET("/models/:action", geminiHandlers.GeminiGetHandler)
}
// Root endpoint
s.engine.GET("/", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"message": "CLI Proxy API Server",
"version": "1.0.0",
"endpoints": []string{
"POST /v1/chat/completions",
"POST /v1/completions",
"GET /v1/models",
},
})
})
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
// OAuth callback endpoints (reuse main server port)
// These endpoints receive provider redirects and persist
// the short-lived code/state for the waiting goroutine.
s.engine.GET("/anthropic/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
errStr := c.Query("error")
// Persist to a temporary file keyed by state
if state != "" {
file := fmt.Sprintf("%s/.oauth-anthropic-%s.oauth", s.cfg.AuthDir, state)
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
})
s.engine.GET("/codex/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
errStr := c.Query("error")
if state != "" {
file := fmt.Sprintf("%s/.oauth-codex-%s.oauth", s.cfg.AuthDir, state)
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
})
s.engine.GET("/google/callback", func(c *gin.Context) {
code := c.Query("code")
state := c.Query("state")
errStr := c.Query("error")
if state != "" {
file := fmt.Sprintf("%s/.oauth-gemini-%s.oauth", s.cfg.AuthDir, state)
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
}
c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
})
// Management API routes (delegated to management handlers)
// New logic: if remote-management-key is empty, do not expose any management endpoint (404).
if s.cfg.RemoteManagement.SecretKey != "" {
mgmt := s.engine.Group("/v0/management")
mgmt.Use(s.mgmt.Middleware())
{
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
mgmt.GET("/config", s.mgmt.GetConfig)
mgmt.GET("/debug", s.mgmt.GetDebug)
mgmt.PUT("/debug", s.mgmt.PutDebug)
mgmt.PATCH("/debug", s.mgmt.PutDebug)
mgmt.GET("/proxy-url", s.mgmt.GetProxyURL)
mgmt.PUT("/proxy-url", s.mgmt.PutProxyURL)
mgmt.PATCH("/proxy-url", s.mgmt.PutProxyURL)
mgmt.DELETE("/proxy-url", s.mgmt.DeleteProxyURL)
mgmt.GET("/quota-exceeded/switch-project", s.mgmt.GetSwitchProject)
mgmt.PUT("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
mgmt.PATCH("/quota-exceeded/switch-project", s.mgmt.PutSwitchProject)
mgmt.GET("/quota-exceeded/switch-preview-model", s.mgmt.GetSwitchPreviewModel)
mgmt.PUT("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
mgmt.PATCH("/quota-exceeded/switch-preview-model", s.mgmt.PutSwitchPreviewModel)
mgmt.GET("/api-keys", s.mgmt.GetAPIKeys)
mgmt.PUT("/api-keys", s.mgmt.PutAPIKeys)
mgmt.PATCH("/api-keys", s.mgmt.PatchAPIKeys)
mgmt.DELETE("/api-keys", s.mgmt.DeleteAPIKeys)
mgmt.GET("/generative-language-api-key", s.mgmt.GetGlKeys)
mgmt.PUT("/generative-language-api-key", s.mgmt.PutGlKeys)
mgmt.PATCH("/generative-language-api-key", s.mgmt.PatchGlKeys)
mgmt.DELETE("/generative-language-api-key", s.mgmt.DeleteGlKeys)
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
mgmt.PATCH("/claude-api-key", s.mgmt.PatchClaudeKey)
mgmt.DELETE("/claude-api-key", s.mgmt.DeleteClaudeKey)
mgmt.GET("/codex-api-key", s.mgmt.GetCodexKeys)
mgmt.PUT("/codex-api-key", s.mgmt.PutCodexKeys)
mgmt.PATCH("/codex-api-key", s.mgmt.PatchCodexKey)
mgmt.DELETE("/codex-api-key", s.mgmt.DeleteCodexKey)
mgmt.GET("/openai-compatibility", s.mgmt.GetOpenAICompat)
mgmt.PUT("/openai-compatibility", s.mgmt.PutOpenAICompat)
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
mgmt.POST("/gemini-web-token", s.mgmt.CreateGeminiWebToken)
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
}
}
}
// unifiedModelsHandler creates a unified handler for the /v1/models endpoint
// that routes to different handlers based on the User-Agent header.
// If User-Agent starts with "claude-cli", it routes to Claude handler,
// otherwise it routes to OpenAI handler.
func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, claudeHandler *claude.ClaudeCodeAPIHandler) gin.HandlerFunc {
return func(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
// Route to Claude handler if User-Agent starts with "claude-cli"
if strings.HasPrefix(userAgent, "claude-cli") {
// log.Debugf("Routing /v1/models to Claude handler for User-Agent: %s", userAgent)
claudeHandler.ClaudeModels(c)
} else {
// log.Debugf("Routing /v1/models to OpenAI handler for User-Agent: %s", userAgent)
openaiHandler.OpenAIModels(c)
}
}
}
// Start begins listening for and serving HTTP requests.
// It's a blocking call and will only return on an unrecoverable error.
//
// Returns:
// - error: An error if the server fails to start
func (s *Server) Start() error {
log.Debugf("Starting API server on %s", s.server.Addr)
// Start the HTTP server.
if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to start HTTP server: %v", err)
}
return nil
}
// Stop gracefully shuts down the API server without interrupting any
// active connections.
//
// Parameters:
// - ctx: The context for graceful shutdown
//
// Returns:
// - error: An error if the server fails to stop
func (s *Server) Stop(ctx context.Context) error {
log.Debug("Stopping API server...")
// Shutdown the HTTP server.
if err := s.server.Shutdown(ctx); err != nil {
return fmt.Errorf("failed to shutdown HTTP server: %v", err)
}
log.Debug("API server stopped")
return nil
}
// corsMiddleware returns a Gin middleware handler that adds CORS headers
// to every response, allowing cross-origin requests.
//
// Returns:
// - gin.HandlerFunc: The CORS middleware handler
func corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
c.Header("Access-Control-Allow-Origin", "*")
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
c.Header("Access-Control-Allow-Headers", "*")
if c.Request.Method == "OPTIONS" {
c.AbortWithStatus(http.StatusNoContent)
return
}
c.Next()
}
}
func (s *Server) applyAccessConfig(cfg *config.Config) {
if s == nil || s.accessManager == nil {
return
}
providers, err := sdkaccess.BuildProviders(cfg)
if err != nil {
log.Errorf("failed to update request auth providers: %v", err)
return
}
s.accessManager.SetProviders(providers)
}
// UpdateClients updates the server's client list and configuration.
// This method is called when the configuration or authentication tokens change.
//
// Parameters:
// - clients: The new slice of AI service clients
// - cfg: The new application configuration
func (s *Server) UpdateClients(cfg *config.Config) {
// Update request logger enabled state if it has changed
if s.requestLogger != nil && s.cfg.RequestLog != cfg.RequestLog {
if s.loggerToggle != nil {
s.loggerToggle(cfg.RequestLog)
} else if toggler, ok := s.requestLogger.(interface{ SetEnabled(bool) }); ok {
toggler.SetEnabled(cfg.RequestLog)
}
log.Debugf("request logging updated from %t to %t", s.cfg.RequestLog, cfg.RequestLog)
}
// Update log level dynamically when debug flag changes
if s.cfg.Debug != cfg.Debug {
util.SetLogLevel(cfg)
log.Debugf("debug mode updated from %t to %t", s.cfg.Debug, cfg.Debug)
}
s.cfg = cfg
s.handlers.UpdateClients(cfg)
if s.mgmt != nil {
s.mgmt.SetConfig(cfg)
s.mgmt.SetAuthManager(s.handlers.AuthManager)
}
s.applyAccessConfig(cfg)
// Count client sources from configuration and auth directory
authFiles := util.CountAuthFiles(cfg.AuthDir)
glAPIKeyCount := len(cfg.GlAPIKey)
claudeAPIKeyCount := len(cfg.ClaudeKey)
codexAPIKeyCount := len(cfg.CodexKey)
openAICompatCount := 0
for i := range cfg.OpenAICompatibility {
openAICompatCount += len(cfg.OpenAICompatibility[i].APIKeys)
}
total := authFiles + glAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
log.Infof("server clients and configuration updated: %d clients (%d auth files + %d GL API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
total,
authFiles,
glAPIKeyCount,
claudeAPIKeyCount,
codexAPIKeyCount,
openAICompatCount,
)
}
// (management handlers moved to internal/api/handlers/management)
// AuthMiddleware returns a Gin middleware handler that authenticates requests
// using the configured authentication providers. When no providers are available,
// it allows all requests (legacy behaviour).
func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
return func(c *gin.Context) {
if manager == nil {
c.Next()
return
}
result, err := manager.Authenticate(c.Request.Context(), c.Request)
if err == nil {
if result != nil {
c.Set("apiKey", result.Principal)
c.Set("accessProvider", result.Provider)
if len(result.Metadata) > 0 {
c.Set("accessMetadata", result.Metadata)
}
}
c.Next()
return
}
switch {
case errors.Is(err, sdkaccess.ErrNoCredentials):
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Missing API key"})
case errors.Is(err, sdkaccess.ErrInvalidCredential):
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Invalid API key"})
default:
log.Errorf("authentication middleware error: %v", err)
c.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "Authentication service error"})
}
}
}
// legacy clientsToSlice removed; handlers no longer consume legacy client slices