Refactor API handlers organization and simplify error response handling

- Modularized handlers into dedicated packages (`gemini`, `claude`, `cli`) for better structure.
- Centralized `ErrorResponse` and `ErrorDetail` types under `handlers` package for reuse.
- Updated all handlers to utilize the shared `ErrorResponse` model.
- Introduced specialization of handler structs (`GeminiAPIHandlers`, `ClaudeCodeAPIHandlers`, `GeminiCLIAPIHandlers`) for improved clarity and separation of concerns.
- Refactored `getClient` logic with additional properties and better state management.

Refactor `translator` package by modularizing code for `claude` and `gemini`

- Moved Claude-specific logic (`PrepareClaudeRequest`, `ConvertCliToClaude`) to `translator/claude/code`.
- Moved Gemini-specific logic (`FixCLIToolResponse`) to `translator/gemini/cli` for better package structure.
- Updated affected handler imports and method references.

Add comprehensive package-level documentation across key modules

- Introduced detailed package-level documentation for core modules: `auth`, `client`, `cmd`, `handlers`, `util`, `watcher`, `config`, `translator`, and `api`.
- Enhanced code readability and maintainability by clarifying the purpose and functionality of each package.
- Aligned documentation style and tone with existing codebase conventions.

Refactor API handlers and translator modules for improved clarity and consistency

- Standardized handler struct names (`GeminiAPIHandlers`, `ClaudeCodeAPIHandlers`, `GeminiCLIAPIHandlers`, `OpenAIAPIHandlers`) and updated related comments.
- Fixed unnecessary `else` blocks in streaming logic for cleaner error handling.
- Renamed variables for better readability (`responseIdResult` to `responseIDResult`, `activationUrl` to `activationURL`, etc.).
- Addressed minor inconsistencies in API handler comments and SSE header initialization.
- Improved modularization of `claude` and `gemini` translator components.

Standardize configuration field naming for consistency across modules

- Renamed `ProxyUrl` to `ProxyURL`, `ApiKeys` to `APIKeys`, and `ConfigQuotaExceeded` to `QuotaExceeded`.
- Updated all relevant references and comments in `config`, `auth`, `api`, `util`, and `watcher`.
- Ensured consistent casing for `GlAPIKey` debug logs.
This commit is contained in:
Luis Pater
2025-08-05 17:37:00 +08:00
parent 00f33f5f3a
commit 1483c31c73
23 changed files with 1484 additions and 1267 deletions

View File

@@ -1,3 +1,6 @@
// Package main provides the entry point for the CLI Proxy API server.
// This server acts as a proxy that provides OpenAI/Gemini/Claude compatible API interfaces
// for CLI models, allowing CLI models to be used with tools and libraries designed for standard AI APIs.
package main package main
import ( import (

View File

@@ -1,10 +1,17 @@
package api // Package claude provides HTTP handlers for Claude API code-related functionality.
// This package implements Claude-compatible streaming chat completions with sophisticated
// client rotation and quota management systems to ensure high availability and optimal
// resource utilization across multiple backend clients. It handles request translation
// between Claude API format and the underlying Gemini backend, providing seamless
// API compatibility while maintaining robust error handling and connection management.
package claude
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/api/translator" "github.com/luispater/CLIProxyAPI/internal/api/handlers"
"github.com/luispater/CLIProxyAPI/internal/api/translator/claude/code"
"github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net/http" "net/http"
@@ -12,16 +19,30 @@ import (
"time" "time"
) )
// ClaudeCodeAPIHandlers contains the handlers for Claude API endpoints.
// It holds a pool of clients to interact with the backend service.
type ClaudeCodeAPIHandlers struct {
*handlers.APIHandlers
}
// NewClaudeCodeAPIHandlers creates a new Claude API handlers instance.
// It takes an APIHandlers instance as input and returns a ClaudeCodeAPIHandlers.
func NewClaudeCodeAPIHandlers(apiHandlers *handlers.APIHandlers) *ClaudeCodeAPIHandlers {
return &ClaudeCodeAPIHandlers{
APIHandlers: apiHandlers,
}
}
// ClaudeMessages handles Claude-compatible streaming chat completions. // ClaudeMessages handles Claude-compatible streaming chat completions.
// This function implements a sophisticated client rotation and quota management system // This function implements a sophisticated client rotation and quota management system
// to ensure high availability and optimal resource utilization across multiple backend clients. // to ensure high availability and optimal resource utilization across multiple backend clients.
func (h *APIHandlers) ClaudeMessages(c *gin.Context) { func (h *ClaudeCodeAPIHandlers) ClaudeMessages(c *gin.Context) {
// Extract raw JSON data from the incoming request // Extract raw JSON data from the incoming request
rawJson, err := c.GetRawData() rawJSON, err := c.GetRawData()
// If data retrieval fails, return a 400 Bad Request error. // If data retrieval fails, return a 400 Bad Request error.
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{ c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err), Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error", Type: "invalid_request_error",
}, },
@@ -41,8 +62,8 @@ func (h *APIHandlers) ClaudeMessages(c *gin.Context) {
// This is crucial for streaming as it allows immediate sending of data chunks // This is crucial for streaming as it allows immediate sending of data chunks
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if !ok { if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{ c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: "Streaming not supported", Message: "Streaming not supported",
Type: "server_error", Type: "server_error",
}, },
@@ -52,7 +73,7 @@ func (h *APIHandlers) ClaudeMessages(c *gin.Context) {
// Parse and prepare the Claude request, extracting model name, system instructions, // Parse and prepare the Claude request, extracting model name, system instructions,
// conversation contents, and available tools from the raw JSON // conversation contents, and available tools from the raw JSON
modelName, systemInstruction, contents, tools := translator.PrepareClaudeRequest(rawJson) modelName, systemInstruction, contents, tools := code.PrepareClaudeRequest(rawJSON)
// Map Claude model names to corresponding Gemini models // Map Claude model names to corresponding Gemini models
// This allows the proxy to handle Claude API calls using Gemini backend // This allows the proxy to handle Claude API calls using Gemini backend
@@ -79,7 +100,7 @@ func (h *APIHandlers) ClaudeMessages(c *gin.Context) {
outLoop: outLoop:
for { for {
var errorResponse *client.ErrorMessage var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName) cliClient, errorResponse = h.GetClient(modelName)
if errorResponse != nil { if errorResponse != nil {
c.Status(errorResponse.StatusCode) c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error) _, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -105,7 +126,7 @@ outLoop:
includeThoughts = !strings.Contains(userAgent[0], "claude-cli") includeThoughts = !strings.Contains(userAgent[0], "claude-cli")
} }
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, systemInstruction, contents, tools, includeThoughts) respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools, includeThoughts)
// Track response state for proper Claude format conversion // Track response state for proper Claude format conversion
hasFirstResponse := false hasFirstResponse := false
@@ -139,16 +160,15 @@ outLoop:
flusher.Flush() flusher.Flush()
cliCancel() cliCancel()
return return
} else {
// Convert the backend response to Claude-compatible format
// This translation layer ensures API compatibility
claudeFormat := translator.ConvertCliToClaude(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex)
if claudeFormat != "" {
_, _ = c.Writer.Write([]byte(claudeFormat))
flusher.Flush() // Immediately send the chunk to the client
}
hasFirstResponse = true
} }
// Convert the backend response to Claude-compatible format
// This translation layer ensures API compatibility
claudeFormat := code.ConvertCliToClaude(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex)
if claudeFormat != "" {
_, _ = c.Writer.Write([]byte(claudeFormat))
flusher.Flush() // Immediately send the chunk to the client
}
hasFirstResponse = true
// Case 3: Handle errors from the backend // Case 3: Handle errors from the backend
// This manages various error conditions and implements retry logic // This manages various error conditions and implements retry logic
@@ -156,7 +176,7 @@ outLoop:
if okError { if okError {
// Special handling for quota exceeded errors // Special handling for quota exceeded errors
// If configured, attempt to switch to a different project/client // If configured, attempt to switch to a different project/client
if errInfo.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { if errInfo.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
continue outLoop // Restart the client selection process continue outLoop // Restart the client selection process
} else { } else {
// Forward other errors directly to the client // Forward other errors directly to the client

View File

@@ -1,10 +1,15 @@
package api // Package cli provides HTTP handlers for Gemini CLI API functionality.
// This package implements handlers that process CLI-specific requests for Gemini API operations,
// including content generation and streaming content generation endpoints.
// The handlers restrict access to localhost only and manage communication with the backend service.
package cli
import ( import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
"github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/util" "github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -16,10 +21,26 @@ import (
"time" "time"
) )
func (h *APIHandlers) CLIHandler(c *gin.Context) { // GeminiCLIAPIHandlers contains the handlers for Gemini CLI API endpoints.
// It holds a pool of clients to interact with the backend service.
type GeminiCLIAPIHandlers struct {
*handlers.APIHandlers
}
// NewGeminiCLIAPIHandlers creates a new Gemini CLI API handlers instance.
// It takes an APIHandlers instance as input and returns a GeminiCLIAPIHandlers.
func NewGeminiCLIAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiCLIAPIHandlers {
return &GeminiCLIAPIHandlers{
APIHandlers: apiHandlers,
}
}
// CLIHandler handles CLI-specific requests for Gemini API operations.
// It restricts access to localhost only and routes requests to appropriate internal handlers.
func (h *GeminiCLIAPIHandlers) CLIHandler(c *gin.Context) {
if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") { if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") {
c.JSON(http.StatusForbidden, ErrorResponse{ c.JSON(http.StatusForbidden, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: "CLI reply only allow local access", Message: "CLI reply only allow local access",
Type: "forbidden", Type: "forbidden",
}, },
@@ -27,18 +48,18 @@ func (h *APIHandlers) CLIHandler(c *gin.Context) {
return return
} }
rawJson, _ := c.GetRawData() rawJSON, _ := c.GetRawData()
requestRawURI := c.Request.URL.Path requestRawURI := c.Request.URL.Path
if requestRawURI == "/v1internal:generateContent" { if requestRawURI == "/v1internal:generateContent" {
h.internalGenerateContent(c, rawJson) h.internalGenerateContent(c, rawJSON)
} else if requestRawURI == "/v1internal:streamGenerateContent" { } else if requestRawURI == "/v1internal:streamGenerateContent" {
h.internalStreamGenerateContent(c, rawJson) h.internalStreamGenerateContent(c, rawJSON)
} else { } else {
reqBody := bytes.NewBuffer(rawJson) reqBody := bytes.NewBuffer(rawJSON)
req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody) req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{ c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err), Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error", Type: "invalid_request_error",
}, },
@@ -49,15 +70,15 @@ func (h *APIHandlers) CLIHandler(c *gin.Context) {
req.Header[key] = value req.Header[key] = value
} }
httpClient, err := util.SetProxy(h.cfg, &http.Client{}) httpClient, err := util.SetProxy(h.Cfg, &http.Client{})
if err != nil { if err != nil {
log.Fatalf("set proxy failed: %v", err) log.Fatalf("set proxy failed: %v", err)
} }
resp, err := httpClient.Do(req) resp, err := httpClient.Do(req)
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{ c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err), Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error", Type: "invalid_request_error",
}, },
@@ -73,8 +94,8 @@ func (h *APIHandlers) CLIHandler(c *gin.Context) {
}() }()
bodyBytes, _ := io.ReadAll(resp.Body) bodyBytes, _ := io.ReadAll(resp.Body)
c.JSON(http.StatusBadRequest, ErrorResponse{ c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: string(bodyBytes), Message: string(bodyBytes),
Type: "invalid_request_error", Type: "invalid_request_error",
}, },
@@ -98,8 +119,8 @@ func (h *APIHandlers) CLIHandler(c *gin.Context) {
} }
} }
func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) { func (h *GeminiCLIAPIHandlers) internalStreamGenerateContent(c *gin.Context, rawJSON []byte) {
alt := h.getAlt(c) alt := h.GetAlt(c)
if alt == "" { if alt == "" {
c.Header("Content-Type", "text/event-stream") c.Header("Content-Type", "text/event-stream")
@@ -111,8 +132,8 @@ func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []by
// Get the http.Flusher interface to manually flush the response. // Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if !ok { if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{ c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: "Streaming not supported", Message: "Streaming not supported",
Type: "server_error", Type: "server_error",
}, },
@@ -120,7 +141,7 @@ func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []by
return return
} }
modelResult := gjson.GetBytes(rawJson, "model") modelResult := gjson.GetBytes(rawJSON, "model")
modelName := modelResult.String() modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background()) cliCtx, cliCancel := context.WithCancel(context.Background())
@@ -135,7 +156,7 @@ func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []by
outLoop: outLoop:
for { for {
var errorResponse *client.ErrorMessage var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName) cliClient, errorResponse = h.GetClient(modelName)
if errorResponse != nil { if errorResponse != nil {
c.Status(errorResponse.StatusCode) c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error) _, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -150,7 +171,7 @@ outLoop:
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
} }
// Send the message and receive response chunks and errors via channels. // Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, "") respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, "")
hasFirstResponse := false hasFirstResponse := false
for { for {
select { select {
@@ -166,20 +187,19 @@ outLoop:
if !okStream { if !okStream {
cliCancel() cliCancel()
return return
} else {
hasFirstResponse = true
if cliClient.GetGenerativeLanguageAPIKey() != "" {
chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
}
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
flusher.Flush()
} }
hasFirstResponse = true
if cliClient.GetGenerativeLanguageAPIKey() != "" {
chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
}
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
flusher.Flush()
// Handle errors from the backend. // Handle errors from the backend.
case err, okError := <-errChan: case err, okError := <-errChan:
if okError { if okError {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
continue outLoop continue outLoop
} else { } else {
c.Status(err.StatusCode) c.Status(err.StatusCode)
@@ -200,10 +220,10 @@ outLoop:
} }
} }
func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) { func (h *GeminiCLIAPIHandlers) internalGenerateContent(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
modelResult := gjson.GetBytes(rawJson, "model") modelResult := gjson.GetBytes(rawJSON, "model")
modelName := modelResult.String() modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background()) cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client var cliClient *client.Client
@@ -215,7 +235,7 @@ func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) {
for { for {
var errorResponse *client.ErrorMessage var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName) cliClient, errorResponse = h.GetClient(modelName)
if errorResponse != nil { if errorResponse != nil {
c.Status(errorResponse.StatusCode) c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error) _, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -229,9 +249,9 @@ func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
} }
resp, err := cliClient.SendRawMessage(cliCtx, rawJson, "") resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, "")
if err != nil { if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
continue continue
} else { } else {
c.Status(err.StatusCode) c.Status(err.StatusCode)

View File

@@ -1,10 +1,16 @@
package api // Package gemini provides HTTP handlers for Gemini API endpoints.
// This package implements handlers for managing Gemini model operations including
// model listing, content generation, streaming content generation, and token counting.
// It serves as a proxy layer between clients and the Gemini backend service,
// handling request translation, client management, and response processing.
package gemini
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/api/translator" "github.com/luispater/CLIProxyAPI/internal/api/handlers"
"github.com/luispater/CLIProxyAPI/internal/api/translator/gemini/cli"
"github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -14,7 +20,23 @@ import (
"time" "time"
) )
func (h *APIHandlers) GeminiModels(c *gin.Context) { // GeminiAPIHandlers contains the handlers for Gemini API endpoints.
// It holds a pool of clients to interact with the backend service.
type GeminiAPIHandlers struct {
*handlers.APIHandlers
}
// NewGeminiAPIHandlers creates a new Gemini API handlers instance.
// It takes an APIHandlers instance as input and returns a GeminiAPIHandlers.
func NewGeminiAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiAPIHandlers {
return &GeminiAPIHandlers{
APIHandlers: apiHandlers,
}
}
// GeminiModels handles the Gemini models listing endpoint.
// It returns a JSON response containing available Gemini models and their specifications.
func (h *GeminiAPIHandlers) GeminiModels(c *gin.Context) {
c.Status(http.StatusOK) c.Status(http.StatusOK)
c.Header("Content-Type", "application/json; charset=UTF-8") c.Header("Content-Type", "application/json; charset=UTF-8")
_, _ = c.Writer.Write([]byte(`{"models":[{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini `)) _, _ = c.Writer.Write([]byte(`{"models":[{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini `))
@@ -30,13 +52,15 @@ func (h *APIHandlers) GeminiModels(c *gin.Context) {
_, _ = c.Writer.Write([]byte(`e":2,"thinking":true}],"nextPageToken":""}`)) _, _ = c.Writer.Write([]byte(`e":2,"thinking":true}],"nextPageToken":""}`))
} }
func (h *APIHandlers) GeminiGetHandler(c *gin.Context) { // GeminiGetHandler handles GET requests for specific Gemini model information.
// It returns detailed information about a specific Gemini model based on the action parameter.
func (h *GeminiAPIHandlers) GeminiGetHandler(c *gin.Context) {
var request struct { var request struct {
Action string `uri:"action" binding:"required"` Action string `uri:"action" binding:"required"`
} }
if err := c.ShouldBindUri(&request); err != nil { if err := c.ShouldBindUri(&request); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{ c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err), Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error", Type: "invalid_request_error",
}, },
@@ -68,13 +92,15 @@ func (h *APIHandlers) GeminiGetHandler(c *gin.Context) {
} }
} }
func (h *APIHandlers) GeminiHandler(c *gin.Context) { // GeminiHandler handles POST requests for Gemini API operations.
// It routes requests to appropriate handlers based on the action parameter (model:method format).
func (h *GeminiAPIHandlers) GeminiHandler(c *gin.Context) {
var request struct { var request struct {
Action string `uri:"action" binding:"required"` Action string `uri:"action" binding:"required"`
} }
if err := c.ShouldBindUri(&request); err != nil { if err := c.ShouldBindUri(&request); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{ c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err), Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error", Type: "invalid_request_error",
}, },
@@ -83,8 +109,8 @@ func (h *APIHandlers) GeminiHandler(c *gin.Context) {
} }
action := strings.Split(request.Action, ":") action := strings.Split(request.Action, ":")
if len(action) != 2 { if len(action) != 2 {
c.JSON(http.StatusNotFound, ErrorResponse{ c.JSON(http.StatusNotFound, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: fmt.Sprintf("%s not found.", c.Request.URL.Path), Message: fmt.Sprintf("%s not found.", c.Request.URL.Path),
Type: "invalid_request_error", Type: "invalid_request_error",
}, },
@@ -94,20 +120,20 @@ func (h *APIHandlers) GeminiHandler(c *gin.Context) {
modelName := action[0] modelName := action[0]
method := action[1] method := action[1]
rawJson, _ := c.GetRawData() rawJSON, _ := c.GetRawData()
rawJson, _ = sjson.SetBytes(rawJson, "model", []byte(modelName)) rawJSON, _ = sjson.SetBytes(rawJSON, "model", []byte(modelName))
if method == "generateContent" { if method == "generateContent" {
h.geminiGenerateContent(c, rawJson) h.geminiGenerateContent(c, rawJSON)
} else if method == "streamGenerateContent" { } else if method == "streamGenerateContent" {
h.geminiStreamGenerateContent(c, rawJson) h.geminiStreamGenerateContent(c, rawJSON)
} else if method == "countTokens" { } else if method == "countTokens" {
h.geminiCountTokens(c, rawJson) h.geminiCountTokens(c, rawJSON)
} }
} }
func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte) { func (h *GeminiAPIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJSON []byte) {
alt := h.getAlt(c) alt := h.GetAlt(c)
if alt == "" { if alt == "" {
c.Header("Content-Type", "text/event-stream") c.Header("Content-Type", "text/event-stream")
@@ -119,8 +145,8 @@ func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte
// Get the http.Flusher interface to manually flush the response. // Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if !ok { if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{ c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: "Streaming not supported", Message: "Streaming not supported",
Type: "server_error", Type: "server_error",
}, },
@@ -128,7 +154,7 @@ func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte
return return
} }
modelResult := gjson.GetBytes(rawJson, "model") modelResult := gjson.GetBytes(rawJSON, "model")
modelName := modelResult.String() modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background()) cliCtx, cliCancel := context.WithCancel(context.Background())
@@ -143,7 +169,7 @@ func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte
outLoop: outLoop:
for { for {
var errorResponse *client.ErrorMessage var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName) cliClient, errorResponse = h.GetClient(modelName)
if errorResponse != nil { if errorResponse != nil {
c.Status(errorResponse.StatusCode) c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error) _, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -153,21 +179,21 @@ outLoop:
} }
template := "" template := ""
parsed := gjson.Parse(string(rawJson)) parsed := gjson.Parse(string(rawJSON))
contents := parsed.Get("request.contents") contents := parsed.Get("request.contents")
if contents.Exists() { if contents.Exists() {
template = string(rawJson) template = string(rawJSON)
} else { } else {
template = `{"project":"","request":{},"model":""}` template = `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJson)) template, _ = sjson.SetRaw(template, "request", string(rawJSON))
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
template, _ = sjson.Delete(template, "request.model") template, _ = sjson.Delete(template, "request.model")
} }
template, errFixCLIToolResponse := translator.FixCLIToolResponse(template) template, errFixCLIToolResponse := cli.FixCLIToolResponse(template)
if errFixCLIToolResponse != nil { if errFixCLIToolResponse != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse{ c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: errFixCLIToolResponse.Error(), Message: errFixCLIToolResponse.Error(),
Type: "server_error", Type: "server_error",
}, },
@@ -181,7 +207,7 @@ outLoop:
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
template, _ = sjson.Delete(template, "request.system_instruction") template, _ = sjson.Delete(template, "request.system_instruction")
} }
rawJson = []byte(template) rawJSON = []byte(template)
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey) log.Debugf("Request use generative language API Key: %s", glAPIKey)
@@ -190,7 +216,7 @@ outLoop:
} }
// Send the message and receive response chunks and errors via channels. // Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, alt) respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, alt)
for { for {
select { select {
// Handle client disconnection. // Handle client disconnection.
@@ -205,41 +231,40 @@ outLoop:
if !okStream { if !okStream {
cliCancel() cliCancel()
return return
} else { }
if cliClient.GetGenerativeLanguageAPIKey() == "" { if cliClient.GetGenerativeLanguageAPIKey() == "" {
if alt == "" { if alt == "" {
responseResult := gjson.GetBytes(chunk, "response") responseResult := gjson.GetBytes(chunk, "response")
if responseResult.Exists() { if responseResult.Exists() {
chunk = []byte(responseResult.Raw) chunk = []byte(responseResult.Raw)
} }
} else { } else {
chunkTemplate := "[]" chunkTemplate := "[]"
responseResult := gjson.ParseBytes(chunk) responseResult := gjson.ParseBytes(chunk)
if responseResult.IsArray() { if responseResult.IsArray() {
responseResultItems := responseResult.Array() responseResultItems := responseResult.Array()
for i := 0; i < len(responseResultItems); i++ { for i := 0; i < len(responseResultItems); i++ {
responseResultItem := responseResultItems[i] responseResultItem := responseResultItems[i]
if responseResultItem.Get("response").Exists() { if responseResultItem.Get("response").Exists() {
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw) chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
}
} }
} }
chunk = []byte(chunkTemplate)
} }
chunk = []byte(chunkTemplate)
} }
if alt == "" {
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
} else {
_, _ = c.Writer.Write(chunk)
}
flusher.Flush()
} }
if alt == "" {
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
} else {
_, _ = c.Writer.Write(chunk)
}
flusher.Flush()
// Handle errors from the backend. // Handle errors from the backend.
case err, okError := <-errChan: case err, okError := <-errChan:
if okError { if okError {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
log.Debugf("quota exceeded, switch client") log.Debugf("quota exceeded, switch client")
continue outLoop continue outLoop
} else { } else {
@@ -258,12 +283,12 @@ outLoop:
} }
} }
func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) { func (h *GeminiAPIHandlers) geminiCountTokens(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
alt := h.getAlt(c) alt := h.GetAlt(c)
// orgRawJson := rawJson // orgrawJSON := rawJSON
modelResult := gjson.GetBytes(rawJson, "model") modelResult := gjson.GetBytes(rawJSON, "model")
modelName := modelResult.String() modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background()) cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client var cliClient *client.Client
@@ -275,7 +300,7 @@ func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {
for { for {
var errorResponse *client.ErrorMessage var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName, false) cliClient, errorResponse = h.GetClient(modelName, false)
if errorResponse != nil { if errorResponse != nil {
c.Status(errorResponse.StatusCode) c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error) _, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -289,27 +314,27 @@ func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
template := `{"request":{}}` template := `{"request":{}}`
if gjson.GetBytes(rawJson, "generateContentRequest").Exists() { if gjson.GetBytes(rawJSON, "generateContentRequest").Exists() {
template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJson, "generateContentRequest").Raw) template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJSON, "generateContentRequest").Raw)
template, _ = sjson.Delete(template, "generateContentRequest") template, _ = sjson.Delete(template, "generateContentRequest")
} else if gjson.GetBytes(rawJson, "contents").Exists() { } else if gjson.GetBytes(rawJSON, "contents").Exists() {
template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJson, "contents").Raw) template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJSON, "contents").Raw)
template, _ = sjson.Delete(template, "contents") template, _ = sjson.Delete(template, "contents")
} }
rawJson = []byte(template) rawJSON = []byte(template)
} }
resp, err := cliClient.SendRawTokenCount(cliCtx, rawJson, alt) resp, err := cliClient.SendRawTokenCount(cliCtx, rawJSON, alt)
if err != nil { if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
continue continue
} else { } else {
c.Status(err.StatusCode) c.Status(err.StatusCode)
_, _ = c.Writer.Write([]byte(err.Error.Error())) _, _ = c.Writer.Write([]byte(err.Error.Error()))
cliCancel() cliCancel()
// log.Debugf(err.Error.Error()) // log.Debugf(err.Error.Error())
// log.Debugf(string(rawJson)) // log.Debugf(string(rawJSON))
// log.Debugf(string(orgRawJson)) // log.Debugf(string(orgrawJSON))
} }
break break
} else { } else {
@@ -326,12 +351,12 @@ func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {
} }
} }
func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) { func (h *GeminiAPIHandlers) geminiGenerateContent(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
alt := h.getAlt(c) alt := h.GetAlt(c)
modelResult := gjson.GetBytes(rawJson, "model") modelResult := gjson.GetBytes(rawJSON, "model")
modelName := modelResult.String() modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background()) cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client var cliClient *client.Client
@@ -343,7 +368,7 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
for { for {
var errorResponse *client.ErrorMessage var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName) cliClient, errorResponse = h.GetClient(modelName)
if errorResponse != nil { if errorResponse != nil {
c.Status(errorResponse.StatusCode) c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error) _, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -352,21 +377,21 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
} }
template := "" template := ""
parsed := gjson.Parse(string(rawJson)) parsed := gjson.Parse(string(rawJSON))
contents := parsed.Get("request.contents") contents := parsed.Get("request.contents")
if contents.Exists() { if contents.Exists() {
template = string(rawJson) template = string(rawJSON)
} else { } else {
template = `{"project":"","request":{},"model":""}` template = `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJson)) template, _ = sjson.SetRaw(template, "request", string(rawJSON))
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
template, _ = sjson.Delete(template, "request.model") template, _ = sjson.Delete(template, "request.model")
} }
template, errFixCLIToolResponse := translator.FixCLIToolResponse(template) template, errFixCLIToolResponse := cli.FixCLIToolResponse(template)
if errFixCLIToolResponse != nil { if errFixCLIToolResponse != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse{ c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: errFixCLIToolResponse.Error(), Message: errFixCLIToolResponse.Error(),
Type: "server_error", Type: "server_error",
}, },
@@ -380,16 +405,16 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
template, _ = sjson.Delete(template, "request.system_instruction") template, _ = sjson.Delete(template, "request.system_instruction")
} }
rawJson = []byte(template) rawJSON = []byte(template)
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey) log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else { } else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
} }
resp, err := cliClient.SendRawMessage(cliCtx, rawJson, alt) resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, alt)
if err != nil { if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
continue continue
} else { } else {
c.Status(err.StatusCode) c.Status(err.StatusCode)
@@ -410,16 +435,3 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
} }
} }
} }
func (h *APIHandlers) getAlt(c *gin.Context) string {
var alt string
var hasAlt bool
alt, hasAlt = c.GetQuery("alt")
if !hasAlt {
alt, _ = c.GetQuery("$alt")
}
if alt == "sse" {
return ""
}
return alt
}

View File

@@ -0,0 +1,122 @@
// Package handlers provides core API handler functionality for the CLI Proxy API server.
// It includes common types, client management, load balancing, and error handling
// shared across all API endpoint handlers (OpenAI, Claude, Gemini).
package handlers
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"sync"
)
// ErrorResponse represents a standard error response format for the API.
// It contains a single ErrorDetail field.
type ErrorResponse struct {
Error ErrorDetail `json:"error"`
}
// ErrorDetail provides specific information about an error that occurred.
// It includes a human-readable message, an error type, and an optional error code.
type ErrorDetail struct {
// A human-readable message providing more details about the error.
Message string `json:"message"`
// The type of error that occurred (e.g., "invalid_request_error").
Type string `json:"type"`
// A short code identifying the error, if applicable.
Code string `json:"code,omitempty"`
}
// APIHandlers contains the handlers for API endpoints.
// It holds a pool of clients to interact with the backend service.
type APIHandlers struct {
CliClients []*client.Client
Cfg *config.Config
Mutex *sync.Mutex
LastUsedClientIndex int
}
// NewAPIHandlers creates a new API handlers instance.
// It takes a slice of clients and a debug flag as input.
func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers {
return &APIHandlers{
CliClients: cliClients,
Cfg: cfg,
Mutex: &sync.Mutex{},
LastUsedClientIndex: 0,
}
}
// UpdateClients updates the handlers' client list and configuration
func (h *APIHandlers) UpdateClients(clients []*client.Client, cfg *config.Config) {
h.CliClients = clients
h.Cfg = cfg
}
// GetClient returns an available client from the pool using round-robin load balancing.
// It checks for quota limits and tries to find an unlocked client for immediate use.
// The modelName parameter is used to check quota status for specific models.
func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) {
if len(h.CliClients) == 0 {
return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")}
}
var cliClient *client.Client
// Lock the mutex to update the last used client index
h.Mutex.Lock()
startIndex := h.LastUsedClientIndex
if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 {
currentIndex := (startIndex + 1) % len(h.CliClients)
h.LastUsedClientIndex = currentIndex
}
h.Mutex.Unlock()
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.CliClients); i++ {
cliClient = h.CliClients[(startIndex+1+i)%len(h.CliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
return nil, &client.ErrorMessage{StatusCode: 429, Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)}
}
locked := false
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.CliClients[0]
cliClient.RequestMutex.Lock()
}
return cliClient, nil
}
// GetAlt extracts the 'alt' parameter from the request query string.
// It checks both 'alt' and '$alt' parameters and returns the appropriate value.
func (h *APIHandlers) GetAlt(c *gin.Context) string {
var alt string
var hasAlt bool
alt, hasAlt = c.GetQuery("alt")
if !hasAlt {
alt, _ = c.GetQuery("$alt")
}
if alt == "sse" {
return ""
}
return alt
}

View File

@@ -1,50 +1,42 @@
package api // Package openai provides HTTP handlers for OpenAI API endpoints.
// This package implements the OpenAI-compatible API interface, including model listing
// and chat completion functionality. It supports both streaming and non-streaming responses,
// and manages a pool of clients to interact with backend services.
// The handlers translate OpenAI API requests to the appropriate backend format and
// convert responses back to OpenAI-compatible format.
package openai
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/luispater/CLIProxyAPI/internal/api/translator" "github.com/luispater/CLIProxyAPI/internal/api/handlers"
"github.com/luispater/CLIProxyAPI/internal/api/translator/openai"
"github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"net/http" "net/http"
"sync"
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
var ( // OpenAIAPIHandlers contains the handlers for OpenAI API endpoints.
mutex = &sync.Mutex{}
lastUsedClientIndex = 0
)
// APIHandlers contains the handlers for API endpoints.
// It holds a pool of clients to interact with the backend service. // It holds a pool of clients to interact with the backend service.
type APIHandlers struct { type OpenAIAPIHandlers struct {
cliClients []*client.Client *handlers.APIHandlers
cfg *config.Config
} }
// NewAPIHandlers creates a new API handlers instance. // NewOpenAIAPIHandlers creates a new OpenAI API handlers instance.
// It takes a slice of clients and a debug flag as input. // It takes an APIHandlers instance as input and returns an OpenAIAPIHandlers.
func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers { func NewOpenAIAPIHandlers(apiHandlers *handlers.APIHandlers) *OpenAIAPIHandlers {
return &APIHandlers{ return &OpenAIAPIHandlers{
cliClients: cliClients, APIHandlers: apiHandlers,
cfg: cfg,
} }
} }
// UpdateClients updates the handlers' client list and configuration
func (h *APIHandlers) UpdateClients(clients []*client.Client, cfg *config.Config) {
h.cliClients = clients
h.cfg = cfg
}
// Models handles the /v1/models endpoint. // Models handles the /v1/models endpoint.
// It returns a hardcoded list of available AI models. // It returns a hardcoded list of available AI models.
func (h *APIHandlers) Models(c *gin.Context) { func (h *OpenAIAPIHandlers) Models(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"data": []map[string]any{ "data": []map[string]any{
{ {
@@ -91,63 +83,15 @@ func (h *APIHandlers) Models(c *gin.Context) {
}) })
} }
func (h *APIHandlers) getClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) {
if len(h.cliClients) == 0 {
return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")}
}
var cliClient *client.Client
// Lock the mutex to update the last used client index
mutex.Lock()
startIndex := lastUsedClientIndex
if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 {
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
}
mutex.Unlock()
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
return nil, &client.ErrorMessage{StatusCode: 429, Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)}
}
locked := false
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
return cliClient, nil
}
// ChatCompletions handles the /v1/chat/completions endpoint. // ChatCompletions handles the /v1/chat/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response // It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler. // and calls the appropriate handler.
func (h *APIHandlers) ChatCompletions(c *gin.Context) { func (h *OpenAIAPIHandlers) ChatCompletions(c *gin.Context) {
rawJson, err := c.GetRawData() rawJSON, err := c.GetRawData()
// If data retrieval fails, return a 400 Bad Request error. // If data retrieval fails, return a 400 Bad Request error.
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{ c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err), Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error", Type: "invalid_request_error",
}, },
@@ -156,21 +100,21 @@ func (h *APIHandlers) ChatCompletions(c *gin.Context) {
} }
// Check if the client requested a streaming response. // Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJson, "stream") streamResult := gjson.GetBytes(rawJSON, "stream")
if streamResult.Type == gjson.True { if streamResult.Type == gjson.True {
h.handleStreamingResponse(c, rawJson) h.handleStreamingResponse(c, rawJSON)
} else { } else {
h.handleNonStreamingResponse(c, rawJson) h.handleNonStreamingResponse(c, rawJSON)
} }
} }
// handleNonStreamingResponse handles non-streaming chat completion responses. // handleNonStreamingResponse handles non-streaming chat completion responses.
// It selects a client from the pool, sends the request, and aggregates the response // It selects a client from the pool, sends the request, and aggregates the response
// before sending it back to the client. // before sending it back to the client.
func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) { func (h *OpenAIAPIHandlers) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
modelName, systemInstruction, contents, tools := translator.PrepareRequest(rawJson) modelName, systemInstruction, contents, tools := openai.PrepareRequest(rawJSON)
cliCtx, cliCancel := context.WithCancel(context.Background()) cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client var cliClient *client.Client
defer func() { defer func() {
@@ -181,7 +125,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
for { for {
var errorResponse *client.ErrorMessage var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName) cliClient, errorResponse = h.GetClient(modelName)
if errorResponse != nil { if errorResponse != nil {
c.Status(errorResponse.StatusCode) c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error) _, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -197,9 +141,9 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
} }
resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, systemInstruction, contents, tools) resp, err := cliClient.SendMessage(cliCtx, rawJSON, modelName, systemInstruction, contents, tools)
if err != nil { if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
continue continue
} else { } else {
c.Status(err.StatusCode) c.Status(err.StatusCode)
@@ -208,7 +152,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
} }
break break
} else { } else {
openAIFormat := translator.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey) openAIFormat := openai.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey)
if openAIFormat != "" { if openAIFormat != "" {
_, _ = c.Writer.Write([]byte(openAIFormat)) _, _ = c.Writer.Write([]byte(openAIFormat))
} }
@@ -219,7 +163,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
} }
// handleStreamingResponse handles streaming responses // handleStreamingResponse handles streaming responses
func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) { func (h *OpenAIAPIHandlers) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "text/event-stream") c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache") c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
@@ -228,8 +172,8 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
// Get the http.Flusher interface to manually flush the response. // Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if !ok { if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{ c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: ErrorDetail{ Error: handlers.ErrorDetail{
Message: "Streaming not supported", Message: "Streaming not supported",
Type: "server_error", Type: "server_error",
}, },
@@ -238,7 +182,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
} }
// Prepare the request for the backend client. // Prepare the request for the backend client.
modelName, systemInstruction, contents, tools := translator.PrepareRequest(rawJson) modelName, systemInstruction, contents, tools := openai.PrepareRequest(rawJSON)
cliCtx, cliCancel := context.WithCancel(context.Background()) cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client var cliClient *client.Client
defer func() { defer func() {
@@ -251,7 +195,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
outLoop: outLoop:
for { for {
var errorResponse *client.ErrorMessage var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName) cliClient, errorResponse = h.GetClient(modelName)
if errorResponse != nil { if errorResponse != nil {
c.Status(errorResponse.StatusCode) c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error) _, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -268,7 +212,7 @@ outLoop:
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
} }
// Send the message and receive response chunks and errors via channels. // Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, systemInstruction, contents, tools) respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools)
hasFirstResponse := false hasFirstResponse := false
for { for {
select { select {
@@ -287,19 +231,18 @@ outLoop:
flusher.Flush() flusher.Flush()
cliCancel() cliCancel()
return return
} else { }
// Convert the chunk to OpenAI format and send it to the client. // Convert the chunk to OpenAI format and send it to the client.
hasFirstResponse = true hasFirstResponse = true
openAIFormat := translator.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey) openAIFormat := openai.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey)
if openAIFormat != "" { if openAIFormat != "" {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat) _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
flusher.Flush() flusher.Flush()
}
} }
// Handle errors from the backend. // Handle errors from the backend.
case err, okError := <-errChan: case err, okError := <-errChan:
if okError { if okError {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
continue outLoop continue outLoop
} else { } else {
c.Status(err.StatusCode) c.Status(err.StatusCode)

View File

@@ -1,18 +0,0 @@
package api
// ErrorResponse represents a standard error response format for the API.
// It contains a single ErrorDetail field.
type ErrorResponse struct {
Error ErrorDetail `json:"error"`
}
// ErrorDetail provides specific information about an error that occurred.
// It includes a human-readable message, an error type, and an optional error code.
type ErrorDetail struct {
// A human-readable message providing more details about the error.
Message string `json:"message"`
// The type of error that occurred (e.g., "invalid_request_error").
Type string `json:"type"`
// A short code identifying the error, if applicable.
Code string `json:"code,omitempty"`
}

View File

@@ -1,3 +1,7 @@
// Package api provides the HTTP API server implementation for the CLI Proxy API.
// It includes the main server struct, routing setup, middleware for CORS and authentication,
// and integration with various AI API handlers (OpenAI, Claude, Gemini).
// The server supports hot-reloading of clients and configuration.
package api package api
import ( import (
@@ -5,6 +9,11 @@ import (
"errors" "errors"
"fmt" "fmt"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
"github.com/luispater/CLIProxyAPI/internal/api/handlers/claude"
"github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini"
"github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini/cli"
"github.com/luispater/CLIProxyAPI/internal/api/handlers/openai"
"github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config" "github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -17,7 +26,7 @@ import (
type Server struct { type Server struct {
engine *gin.Engine engine *gin.Engine
server *http.Server server *http.Server
handlers *APIHandlers handlers *handlers.APIHandlers
cfg *config.Config cfg *config.Config
} }
@@ -29,9 +38,6 @@ func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
// Create handlers
handlers := NewAPIHandlers(cliClients, cfg)
// Create gin engine // Create gin engine
engine := gin.New() engine := gin.New()
@@ -43,7 +49,7 @@ func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
// Create server instance // Create server instance
s := &Server{ s := &Server{
engine: engine, engine: engine,
handlers: handlers, handlers: handlers.NewAPIHandlers(cliClients, cfg),
cfg: cfg, cfg: cfg,
} }
@@ -62,22 +68,27 @@ func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
// setupRoutes configures the API routes for the server. // setupRoutes configures the API routes for the server.
// It defines the endpoints and associates them with their respective handlers. // It defines the endpoints and associates them with their respective handlers.
func (s *Server) setupRoutes() { func (s *Server) setupRoutes() {
openaiHandlers := openai.NewOpenAIAPIHandlers(s.handlers)
geminiHandlers := gemini.NewGeminiAPIHandlers(s.handlers)
geminiCLIHandlers := cli.NewGeminiCLIAPIHandlers(s.handlers)
claudeCodeHandlers := claude.NewClaudeCodeAPIHandlers(s.handlers)
// OpenAI compatible API routes // OpenAI compatible API routes
v1 := s.engine.Group("/v1") v1 := s.engine.Group("/v1")
v1.Use(AuthMiddleware(s.cfg)) v1.Use(AuthMiddleware(s.cfg))
{ {
v1.GET("/models", s.handlers.Models) v1.GET("/models", openaiHandlers.Models)
v1.POST("/chat/completions", s.handlers.ChatCompletions) v1.POST("/chat/completions", openaiHandlers.ChatCompletions)
v1.POST("/messages", s.handlers.ClaudeMessages) v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
} }
// Gemini compatible API routes // Gemini compatible API routes
v1beta := s.engine.Group("/v1beta") v1beta := s.engine.Group("/v1beta")
v1beta.Use(AuthMiddleware(s.cfg)) v1beta.Use(AuthMiddleware(s.cfg))
{ {
v1beta.GET("/models", s.handlers.GeminiModels) v1beta.GET("/models", geminiHandlers.GeminiModels)
v1beta.POST("/models/:action", s.handlers.GeminiHandler) v1beta.POST("/models/:action", geminiHandlers.GeminiHandler)
v1beta.GET("/models/:action", s.handlers.GeminiGetHandler) v1beta.GET("/models/:action", geminiHandlers.GeminiGetHandler)
} }
// Root endpoint // Root endpoint
@@ -91,7 +102,7 @@ func (s *Server) setupRoutes() {
}, },
}) })
}) })
s.engine.POST("/v1internal:method", s.handlers.CLIHandler) s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
} }
@@ -150,7 +161,7 @@ func (s *Server) UpdateClients(clients []*client.Client, cfg *config.Config) {
// using API keys. If no API keys are configured, it allows all requests. // using API keys. If no API keys are configured, it allows all requests.
func AuthMiddleware(cfg *config.Config) gin.HandlerFunc { func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if len(cfg.ApiKeys) == 0 { if len(cfg.APIKeys) == 0 {
c.Next() c.Next()
return return
} }
@@ -181,9 +192,9 @@ func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
// Find the API key in the in-memory list // Find the API key in the in-memory list
var foundKey string var foundKey string
for i := range cfg.ApiKeys { for i := range cfg.APIKeys {
if cfg.ApiKeys[i] == apiKey || cfg.ApiKeys[i] == authHeaderGoogle || cfg.ApiKeys[i] == authHeaderAnthropic || cfg.ApiKeys[i] == apiKeyQuery { if cfg.APIKeys[i] == apiKey || cfg.APIKeys[i] == authHeaderGoogle || cfg.APIKeys[i] == authHeaderAnthropic || cfg.APIKeys[i] == apiKeyQuery {
foundKey = cfg.ApiKeys[i] foundKey = cfg.APIKeys[i]
break break
} }
} }

View File

@@ -0,0 +1,169 @@
// Package code provides request translation functionality for Claude API.
// It handles parsing and transforming Claude API requests into the internal client format,
// extracting model information, system instructions, message contents, and tool declarations.
// The package also performs JSON data cleaning and transformation to ensure compatibility
// between Claude API format and the internal client's expected format.
package code
import (
"bytes"
"encoding/json"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"strings"
)
// PrepareClaudeRequest parses and transforms a Claude API request into internal client format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the internal client.
func PrepareClaudeRequest(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
var pathsToDelete []string
root := gjson.ParseBytes(rawJSON)
walk(root, "", "additionalProperties", &pathsToDelete)
walk(root, "", "$schema", &pathsToDelete)
var err error
for _, p := range pathsToDelete {
rawJSON, err = sjson.DeleteBytes(rawJSON, p)
if err != nil {
continue
}
}
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// log.Debug(string(rawJSON))
modelName := "gemini-2.5-pro"
modelResult := gjson.GetBytes(rawJSON, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}
contents := make([]client.Content, 0)
var systemInstruction *client.Content
systemResult := gjson.GetBytes(rawJSON, "system")
if systemResult.IsArray() {
systemResults := systemResult.Array()
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
for i := 0; i < len(systemResults); i++ {
systemPromptResult := systemResults[i]
systemTypePromptResult := systemPromptResult.Get("type")
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
systemPrompt := systemPromptResult.Get("text").String()
systemPart := client.Part{Text: systemPrompt}
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
}
}
if len(systemInstruction.Parts) == 0 {
systemInstruction = nil
}
}
messagesResult := gjson.GetBytes(rawJSON, "messages")
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
for i := 0; i < len(messageResults); i++ {
messageResult := messageResults[i]
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
}
role := roleResult.String()
if role == "assistant" {
role = "model"
}
clientContent := client.Content{Role: role, Parts: []client.Part{}}
contentsResult := messageResult.Get("content")
if contentsResult.IsArray() {
contentResults := contentsResult.Array()
for j := 0; j < len(contentResults); j++ {
contentResult := contentResults[j]
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String()
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String()
var args map[string]any
if err = json.Unmarshal([]byte(functionArgs), &args); err == nil {
clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
})
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" {
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").String()
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
}
}
}
contents = append(contents, clientContent)
} else if contentsResult.Type == gjson.String {
prompt := contentsResult.String()
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
}
}
}
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw
inputSchema, _ = sjson.Delete(inputSchema, "additionalProperties")
inputSchema, _ = sjson.Delete(inputSchema, "$schema")
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parameters", inputSchema)
var toolDeclaration any
if err = json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
return modelName, systemInstruction, contents, tools
}
func walk(value gjson.Result, path, field string, pathsToDelete *[]string) {
switch value.Type {
case gjson.JSON:
value.ForEach(func(key, val gjson.Result) bool {
var childPath string
if path == "" {
childPath = key.String()
} else {
childPath = path + "." + key.String()
}
if key.String() == field {
*pathsToDelete = append(*pathsToDelete, childPath)
}
walk(val, childPath, field, pathsToDelete)
return true
})
case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null:
}
}

View File

@@ -0,0 +1,206 @@
// Package code provides response translation functionality for Claude API.
// This package handles the conversion of backend client responses into Claude-compatible
// Server-Sent Events (SSE) format, implementing a sophisticated state machine that manages
// different response types including text content, thinking processes, and function calls.
// The translation ensures proper sequencing of SSE events and maintains state across
// multiple response chunks to provide a seamless streaming experience.
package code
import (
"bytes"
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"time"
)
// ConvertCliToClaude performs sophisticated streaming response format conversion.
// This function implements a complex state machine that translates backend client responses
// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types
// and handles state transitions between content blocks, thinking processes, and function calls.
//
// Response type states: 0=none, 1=content, 2=thinking, 3=function
// The function maintains state across multiple calls to ensure proper SSE event sequencing.
func ConvertCliToClaude(rawJSON []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string {
// Normalize the response format for different API key types
// Generative Language API keys have a different response structure
if isGlAPIKey {
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
}
// Track whether tools are being used in this response chunk
usedTool := false
output := ""
// Initialize the streaming session with a message_start event
// This is only sent for the very first response chunk
if !hasFirstResponse {
output = "event: message_start\n"
// Create the initial message structure with default values
// This follows the Claude API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
// Override default values with actual response metadata if available
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
}
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
}
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
}
// Process the response parts array from the backend client
// Each part can contain text content, thinking content, or function calls
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partResults := partsResult.Array()
for i := 0; i < len(partResults); i++ {
partResult := partResults[i]
// Extract the different types of content from each part
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
// Handle text content (both regular content and thinking)
if partTextResult.Exists() {
// Process thinking content (internal reasoning)
if partResult.Get("thought").Bool() {
// Continue existing thinking block
if *responseType == 2 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
} else {
// Transition from another state to thinking
// First, close any existing content block
if *responseType != 0 {
if *responseType == 2 {
output = output + "event: content_block_delta\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
}
// Start a new thinking content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, *responseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
*responseType = 2 // Set state to thinking
}
} else {
// Process regular text content (user-visible output)
// Continue existing text block
if *responseType == 1 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
} else {
// Transition from another state to text content
// First, close any existing content block
if *responseType != 0 {
if *responseType == 2 {
output = output + "event: content_block_delta\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
}
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, *responseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
*responseType = 1 // Set state to content
}
}
} else if functionCallResult.Exists() {
// Handle function/tool calls from the AI model
// This processes tool usage requests and formats them for Claude API compatibility
usedTool = true
fcName := functionCallResult.Get("name").String()
// Handle state transitions when switching to function calls
// Close any existing function call block first
if *responseType == 3 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
*responseType = 0
}
// Special handling for thinking state transition
if *responseType == 2 {
output = output + "event: content_block_delta\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
output = output + "\n\n\n"
}
// Close any other existing content block
if *responseType != 0 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
}
// Start a new tool use content block
// This creates the structure for a function call in Claude format
output = output + "event: content_block_start\n"
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, *responseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n"
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, *responseIndex), "delta.partial_json", fcArgsResult.Raw)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
}
*responseType = 3
}
}
}
usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata")
if usageResult.Exists() && bytes.Contains(rawJSON, []byte(`"finishReason"`)) {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
output = output + "event: message_delta\n"
output = output + `data: `
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
if usedTool {
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
}
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
output = output + template + "\n\n\n"
}
}
return output
}

View File

@@ -0,0 +1,185 @@
// Package cli provides request translation functionality for Gemini CLI API.
// It handles the conversion and formatting of CLI tool responses, specifically
// transforming between different JSON formats to ensure proper conversation flow
// and API compatibility. The package focuses on intelligently grouping function
// calls with their corresponding responses, converting from linear format to
// grouped format where function calls and responses are properly associated.
package cli
import (
"encoding/json"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct {
ModelContent map[string]interface{}
FunctionCalls []gjson.Result
ResponsesNeeded int
}
// FixCLIToolResponse performs sophisticated tool response format conversion and grouping.
// This function transforms the CLI tool response format by intelligently grouping function calls
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
// It converts from a linear format (1.json) to a grouped format (2.json) where function calls
// and their responses are properly associated and structured.
func FixCLIToolResponse(input string) (string, error) {
// Parse the input JSON to extract the conversation structure
parsed := gjson.Parse(input)
// Extract the contents array which contains the conversation messages
contents := parsed.Get("request.contents")
if !contents.Exists() {
// log.Debugf(input)
return input, fmt.Errorf("contents not found in input")
}
// Initialize data structures for processing and grouping
var newContents []interface{} // Final processed contents array
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched
// Process each content object in the conversation
// This iterates through messages and groups function calls with their responses
contents.ForEach(func(key, value gjson.Result) bool {
role := value.Get("role").String()
parts := value.Get("parts")
// Check if this content has function responses
var responsePartsInThisContent []gjson.Result
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionResponse").Exists() {
responsePartsInThisContent = append(responsePartsInThisContent, part)
}
return true
})
// If this content has function responses, collect them
if len(responsePartsInThisContent) > 0 {
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
// Check if any pending groups can be satisfied
for i := len(pendingGroups) - 1; i >= 0; i-- {
group := pendingGroups[i]
if len(collectedResponses) >= group.ResponsesNeeded {
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content
var responseParts []interface{}
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue
}
responseParts = append(responseParts, responseMap)
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
}
// Remove this group as it's been satisfied
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
break
}
}
return true // Skip adding this content, responses are merged
}
// If this is a model with function calls, create a new group
if role == "model" {
var functionCallsInThisModel []gjson.Result
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
functionCallsInThisModel = append(functionCallsInThisModel, part)
}
return true
})
if len(functionCallsInThisModel) > 0 {
// Add the model content
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
// Create a new group for tracking responses
group := &FunctionCallGroup{
ModelContent: contentMap,
FunctionCalls: functionCallsInThisModel,
ResponsesNeeded: len(functionCallsInThisModel),
}
pendingGroups = append(pendingGroups, group)
} else {
// Regular model content without function calls
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
}
} else {
// Non-model content (user, etc.)
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
}
return true
})
// Handle any remaining pending groups with remaining responses
for _, group := range pendingGroups {
if len(collectedResponses) >= group.ResponsesNeeded {
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
var responseParts []interface{}
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue
}
responseParts = append(responseParts, responseMap)
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
}
}
}
// Update the original JSON with the new contents
result := input
newContentsJSON, _ := json.Marshal(newContents)
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
return result, nil
}

View File

@@ -1,3 +1,6 @@
// Package translator provides data translation and format conversion utilities
// for the CLI Proxy API. It includes MIME type mappings and other translation
// functions used across different API endpoints.
package translator package translator
// MimeTypes is a comprehensive map of file extensions to their corresponding MIME types. // MimeTypes is a comprehensive map of file extensions to their corresponding MIME types.

View File

@@ -0,0 +1,226 @@
// Package openai provides request translation functionality for OpenAI API.
// It handles the conversion of OpenAI-compatible request formats to the internal
// format expected by the backend client, including parsing messages, roles,
// content types (text, image, file), and tool calls.
package openai
import (
"encoding/json"
"github.com/luispater/CLIProxyAPI/internal/api/translator"
"strings"
"github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// PrepareRequest translates a raw JSON request from an OpenAI-compatible format
// to the internal format expected by the backend client. It parses messages,
// roles, content types (text, image, file), and tool calls.
func PrepareRequest(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
modelName := "gemini-2.5-pro"
modelResult := gjson.GetBytes(rawJSON, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}
// Initialize data structures for processing conversation messages
// contents: stores the processed conversation history
// systemInstruction: stores system-level instructions separate from conversation
contents := make([]client.Content, 0)
var systemInstruction *client.Content
messagesResult := gjson.GetBytes(rawJSON, "messages")
// Pre-process tool responses to create a lookup map
// This first pass collects all tool responses so they can be matched with their corresponding calls
toolItems := make(map[string]*client.FunctionResponse)
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
}
contentResult := messageResult.Get("content")
// Extract tool responses for later matching with function calls
if roleResult.String() == "tool" {
toolCallID := messageResult.Get("tool_call_id").String()
if toolCallID != "" {
var responseData string
// Handle both string and object-based tool response formats
if contentResult.Type == gjson.String {
responseData = contentResult.String()
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
responseData = contentResult.Get("text").String()
}
// Clean up tool call ID by removing timestamp suffix
// This normalizes IDs for consistent matching between calls and responses
toolCallIDs := strings.Split(toolCallID, "-")
strings.Join(toolCallIDs, "-")
newToolCallID := strings.Join(toolCallIDs[:len(toolCallIDs)-1], "-")
// Create function response object with normalized ID and response data
functionResponse := client.FunctionResponse{Name: newToolCallID, Response: map[string]interface{}{"result": responseData}}
toolItems[toolCallID] = &functionResponse
}
}
}
}
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
contentResult := messageResult.Get("content")
if roleResult.Type != gjson.String {
continue
}
switch roleResult.String() {
// System messages are converted to a user message followed by a model's acknowledgment.
case "system":
if contentResult.Type == gjson.String {
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}
} else if contentResult.IsObject() {
// Handle object-based system messages.
if contentResult.Get("type").String() == "text" {
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}}
}
}
// User messages can contain simple text or a multi-part body.
case "user":
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsArray() {
// Handle multi-part user messages (text, images, files).
contentItemResults := contentResult.Array()
parts := make([]client.Part, 0)
for j := 0; j < len(contentItemResults); j++ {
contentItemResult := contentItemResults[j]
contentTypeResult := contentItemResult.Get("type")
switch contentTypeResult.String() {
case "text":
parts = append(parts, client.Part{Text: contentItemResult.Get("text").String()})
case "image_url":
// Parse data URI for images.
imageURL := contentItemResult.Get("image_url.url").String()
if len(imageURL) > 5 {
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
if len(imageURLs) == 2 && len(imageURLs[1]) > 7 {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: imageURLs[0],
Data: imageURLs[1][7:],
}})
}
}
case "file":
// Handle file attachments by determining MIME type from extension.
filename := contentItemResult.Get("file.filename").String()
fileData := contentItemResult.Get("file.file_data").String()
ext := ""
if split := strings.Split(filename, "."); len(split) > 1 {
ext = split[len(split)-1]
}
if mimeType, ok := translator.MimeTypes[ext]; ok {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: mimeType,
Data: fileData,
}})
} else {
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
}
}
}
contents = append(contents, client.Content{Role: "user", Parts: parts})
}
// Assistant messages can contain text responses or tool calls
// In the internal format, assistant messages are converted to "model" role
case "assistant":
if contentResult.Type == gjson.String {
// Simple text response from the assistant
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
// Handle complex tool calls made by the assistant
// This processes function calls and matches them with their responses
functionIDs := make([]string, 0)
toolCallsResult := messageResult.Get("tool_calls")
if toolCallsResult.IsArray() {
parts := make([]client.Part, 0)
tcsResult := toolCallsResult.Array()
// Process each tool call in the assistant's message
for j := 0; j < len(tcsResult); j++ {
tcResult := tcsResult[j]
// Extract function call details
functionID := tcResult.Get("id").String()
functionIDs = append(functionIDs, functionID)
functionName := tcResult.Get("function.name").String()
functionArgs := tcResult.Get("function.arguments").String()
// Parse function arguments from JSON string to map
var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
parts = append(parts, client.Part{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
})
}
}
// Add the model's function calls to the conversation
if len(parts) > 0 {
contents = append(contents, client.Content{
Role: "model", Parts: parts,
})
// Create a separate tool response message with the collected responses
// This matches function calls with their corresponding responses
toolParts := make([]client.Part, 0)
for _, functionID := range functionIDs {
if functionResponse, ok := toolItems[functionID]; ok {
toolParts = append(toolParts, client.Part{FunctionResponse: functionResponse})
}
}
// Add the tool responses as a separate message in the conversation
contents = append(contents, client.Content{Role: "tool", Parts: toolParts})
}
}
}
}
}
}
// Translate the tool declarations from the request.
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
if toolResult.Get("type").String() == "function" {
functionTypeResult := toolResult.Get("function")
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
var functionDeclaration any
if err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
}
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
return modelName, systemInstruction, contents, tools
}

View File

@@ -0,0 +1,197 @@
// Package openai provides response translation functionality for converting between
// different API response formats and OpenAI-compatible formats. It handles both
// streaming and non-streaming responses, transforming backend client responses
// into OpenAI Server-Sent Events (SSE) format and standard JSON response formats.
// The package supports content translation, function calls, usage metadata,
// and various response attributes while maintaining compatibility with OpenAI API
// specifications.
package openai
import (
"fmt"
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertCliToOpenAI translates a single chunk of a streaming response from the
// backend client format to the OpenAI Server-Sent Events (SSE) format.
// It returns an empty string if the chunk contains no useful data.
func ConvertCliToOpenAI(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string {
if isGlAPIKey {
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
}
// Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
// Extract and set the model version.
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
// Extract and set the creation timestamp.
if createTimeResult := gjson.GetBytes(rawJSON, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
if err == nil {
unixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
} else {
template, _ = sjson.Set(template, "created", unixTimestamp)
}
// Extract and set the response ID.
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
template, _ = sjson.Set(template, "id", responseIDResult.String())
}
// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
}
// Process the main content part of the response.
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partResults := partsResult.Array()
for i := 0; i < len(partResults); i++ {
partResult := partResults[i]
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() {
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
}
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallTemplate)
}
}
}
return template
}
// ConvertCliToOpenAINonStream aggregates response from the backend client
// convert a single, non-streaming OpenAI-compatible JSON response.
func ConvertCliToOpenAINonStream(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string {
if isGlAPIKey {
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
}
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
if createTimeResult := gjson.GetBytes(rawJSON, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
if err == nil {
unixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
} else {
template, _ = sjson.Set(template, "created", unixTimestamp)
}
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
template, _ = sjson.Set(template, "id", responseIDResult.String())
}
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
}
// Process the main content part of the response.
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partsResults := partsResult.Array()
for i := 0; i < len(partsResults); i++ {
partResult := partsResults[i]
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() {
// Append text content, distinguishing between regular content and reasoning.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
} else if functionCallResult.Exists() {
// Append function call content to the tool_calls array.
toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls")
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
}
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
} else {
// If no usable content is found, return an empty string.
return ""
}
}
}
return template
}

View File

@@ -1,545 +0,0 @@
package translator
import (
"bytes"
"encoding/json"
"fmt"
"github.com/tidwall/sjson"
"strings"
"github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// PrepareRequest translates a raw JSON request from an OpenAI-compatible format
// to the internal format expected by the backend client. It parses messages,
// roles, content types (text, image, file), and tool calls.
func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
modelName := "gemini-2.5-pro"
modelResult := gjson.GetBytes(rawJson, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}
// Initialize data structures for processing conversation messages
// contents: stores the processed conversation history
// systemInstruction: stores system-level instructions separate from conversation
contents := make([]client.Content, 0)
var systemInstruction *client.Content
messagesResult := gjson.GetBytes(rawJson, "messages")
// Pre-process tool responses to create a lookup map
// This first pass collects all tool responses so they can be matched with their corresponding calls
toolItems := make(map[string]*client.FunctionResponse)
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
}
contentResult := messageResult.Get("content")
// Extract tool responses for later matching with function calls
if roleResult.String() == "tool" {
toolCallID := messageResult.Get("tool_call_id").String()
if toolCallID != "" {
var responseData string
// Handle both string and object-based tool response formats
if contentResult.Type == gjson.String {
responseData = contentResult.String()
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
responseData = contentResult.Get("text").String()
}
// Clean up tool call ID by removing timestamp suffix
// This normalizes IDs for consistent matching between calls and responses
toolCallIDs := strings.Split(toolCallID, "-")
strings.Join(toolCallIDs, "-")
newToolCallID := strings.Join(toolCallIDs[:len(toolCallIDs)-1], "-")
// Create function response object with normalized ID and response data
functionResponse := client.FunctionResponse{Name: newToolCallID, Response: map[string]interface{}{"result": responseData}}
toolItems[toolCallID] = &functionResponse
}
}
}
}
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
contentResult := messageResult.Get("content")
if roleResult.Type != gjson.String {
continue
}
switch roleResult.String() {
// System messages are converted to a user message followed by a model's acknowledgment.
case "system":
if contentResult.Type == gjson.String {
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}
} else if contentResult.IsObject() {
// Handle object-based system messages.
if contentResult.Get("type").String() == "text" {
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}}
}
}
// User messages can contain simple text or a multi-part body.
case "user":
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsArray() {
// Handle multi-part user messages (text, images, files).
contentItemResults := contentResult.Array()
parts := make([]client.Part, 0)
for j := 0; j < len(contentItemResults); j++ {
contentItemResult := contentItemResults[j]
contentTypeResult := contentItemResult.Get("type")
switch contentTypeResult.String() {
case "text":
parts = append(parts, client.Part{Text: contentItemResult.Get("text").String()})
case "image_url":
// Parse data URI for images.
imageURL := contentItemResult.Get("image_url.url").String()
if len(imageURL) > 5 {
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
if len(imageURLs) == 2 && len(imageURLs[1]) > 7 {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: imageURLs[0],
Data: imageURLs[1][7:],
}})
}
}
case "file":
// Handle file attachments by determining MIME type from extension.
filename := contentItemResult.Get("file.filename").String()
fileData := contentItemResult.Get("file.file_data").String()
ext := ""
if split := strings.Split(filename, "."); len(split) > 1 {
ext = split[len(split)-1]
}
if mimeType, ok := MimeTypes[ext]; ok {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: mimeType,
Data: fileData,
}})
} else {
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
}
}
}
contents = append(contents, client.Content{Role: "user", Parts: parts})
}
// Assistant messages can contain text responses or tool calls
// In the internal format, assistant messages are converted to "model" role
case "assistant":
if contentResult.Type == gjson.String {
// Simple text response from the assistant
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
// Handle complex tool calls made by the assistant
// This processes function calls and matches them with their responses
functionIDs := make([]string, 0)
toolCallsResult := messageResult.Get("tool_calls")
if toolCallsResult.IsArray() {
parts := make([]client.Part, 0)
tcsResult := toolCallsResult.Array()
// Process each tool call in the assistant's message
for j := 0; j < len(tcsResult); j++ {
tcResult := tcsResult[j]
// Extract function call details
functionID := tcResult.Get("id").String()
functionIDs = append(functionIDs, functionID)
functionName := tcResult.Get("function.name").String()
functionArgs := tcResult.Get("function.arguments").String()
// Parse function arguments from JSON string to map
var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
parts = append(parts, client.Part{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
})
}
}
// Add the model's function calls to the conversation
if len(parts) > 0 {
contents = append(contents, client.Content{
Role: "model", Parts: parts,
})
// Create a separate tool response message with the collected responses
// This matches function calls with their corresponding responses
toolParts := make([]client.Part, 0)
for _, functionID := range functionIDs {
if functionResponse, ok := toolItems[functionID]; ok {
toolParts = append(toolParts, client.Part{FunctionResponse: functionResponse})
}
}
// Add the tool responses as a separate message in the conversation
contents = append(contents, client.Content{Role: "tool", Parts: toolParts})
}
}
}
}
}
}
// Translate the tool declarations from the request.
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJson, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
if toolResult.Get("type").String() == "function" {
functionTypeResult := toolResult.Get("function")
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
var functionDeclaration any
if err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
}
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
return modelName, systemInstruction, contents, tools
}
// FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct {
ModelContent map[string]interface{}
FunctionCalls []gjson.Result
ResponsesNeeded int
}
// FixCLIToolResponse performs sophisticated tool response format conversion and grouping.
// This function transforms the CLI tool response format by intelligently grouping function calls
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
// It converts from a linear format (1.json) to a grouped format (2.json) where function calls
// and their responses are properly associated and structured.
func FixCLIToolResponse(input string) (string, error) {
// Parse the input JSON to extract the conversation structure
parsed := gjson.Parse(input)
// Extract the contents array which contains the conversation messages
contents := parsed.Get("request.contents")
if !contents.Exists() {
// log.Debugf(input)
return input, fmt.Errorf("contents not found in input")
}
// Initialize data structures for processing and grouping
var newContents []interface{} // Final processed contents array
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched
// Process each content object in the conversation
// This iterates through messages and groups function calls with their responses
contents.ForEach(func(key, value gjson.Result) bool {
role := value.Get("role").String()
parts := value.Get("parts")
// Check if this content has function responses
var responsePartsInThisContent []gjson.Result
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionResponse").Exists() {
responsePartsInThisContent = append(responsePartsInThisContent, part)
}
return true
})
// If this content has function responses, collect them
if len(responsePartsInThisContent) > 0 {
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
// Check if any pending groups can be satisfied
for i := len(pendingGroups) - 1; i >= 0; i-- {
group := pendingGroups[i]
if len(collectedResponses) >= group.ResponsesNeeded {
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content
var responseParts []interface{}
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue
}
responseParts = append(responseParts, responseMap)
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
}
// Remove this group as it's been satisfied
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
break
}
}
return true // Skip adding this content, responses are merged
}
// If this is a model with function calls, create a new group
if role == "model" {
var functionCallsInThisModel []gjson.Result
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
functionCallsInThisModel = append(functionCallsInThisModel, part)
}
return true
})
if len(functionCallsInThisModel) > 0 {
// Add the model content
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
// Create a new group for tracking responses
group := &FunctionCallGroup{
ModelContent: contentMap,
FunctionCalls: functionCallsInThisModel,
ResponsesNeeded: len(functionCallsInThisModel),
}
pendingGroups = append(pendingGroups, group)
} else {
// Regular model content without function calls
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
}
} else {
// Non-model content (user, etc.)
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
}
return true
})
// Handle any remaining pending groups with remaining responses
for _, group := range pendingGroups {
if len(collectedResponses) >= group.ResponsesNeeded {
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
var responseParts []interface{}
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue
}
responseParts = append(responseParts, responseMap)
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
}
}
}
// Update the original JSON with the new contents
result := input
newContentsJSON, _ := json.Marshal(newContents)
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
return result, nil
}
func PrepareClaudeRequest(rawJson []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
var pathsToDelete []string
root := gjson.ParseBytes(rawJson)
walk(root, "", "additionalProperties", &pathsToDelete)
walk(root, "", "$schema", &pathsToDelete)
var err error
for _, p := range pathsToDelete {
rawJson, err = sjson.DeleteBytes(rawJson, p)
if err != nil {
continue
}
}
rawJson = bytes.Replace(rawJson, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// log.Debug(string(rawJson))
modelName := "gemini-2.5-pro"
modelResult := gjson.GetBytes(rawJson, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}
contents := make([]client.Content, 0)
var systemInstruction *client.Content
systemResult := gjson.GetBytes(rawJson, "system")
if systemResult.IsArray() {
systemResults := systemResult.Array()
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
for i := 0; i < len(systemResults); i++ {
systemPromptResult := systemResults[i]
systemTypePromptResult := systemPromptResult.Get("type")
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
systemPrompt := systemPromptResult.Get("text").String()
systemPart := client.Part{Text: systemPrompt}
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
}
}
if len(systemInstruction.Parts) == 0 {
systemInstruction = nil
}
}
messagesResult := gjson.GetBytes(rawJson, "messages")
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
for i := 0; i < len(messageResults); i++ {
messageResult := messageResults[i]
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
}
role := roleResult.String()
if role == "assistant" {
role = "model"
}
clientContent := client.Content{Role: role, Parts: []client.Part{}}
contentsResult := messageResult.Get("content")
if contentsResult.IsArray() {
contentResults := contentsResult.Array()
for j := 0; j < len(contentResults); j++ {
contentResult := contentResults[j]
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String()
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String()
var args map[string]any
if err = json.Unmarshal([]byte(functionArgs), &args); err == nil {
clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
})
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" {
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").String()
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
}
}
}
contents = append(contents, clientContent)
} else if contentsResult.Type == gjson.String {
prompt := contentsResult.String()
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
}
}
}
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJson, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw
inputSchema, _ = sjson.Delete(inputSchema, "additionalProperties")
inputSchema, _ = sjson.Delete(inputSchema, "$schema")
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parameters", inputSchema)
var toolDeclaration any
if err = json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
return modelName, systemInstruction, contents, tools
}
func walk(value gjson.Result, path, field string, pathsToDelete *[]string) {
switch value.Type {
case gjson.JSON:
value.ForEach(func(key, val gjson.Result) bool {
var childPath string
if path == "" {
childPath = key.String()
} else {
childPath = path + "." + key.String()
}
if key.String() == field {
*pathsToDelete = append(*pathsToDelete, childPath)
}
walk(val, childPath, field, pathsToDelete)
return true
})
case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null:
}
}

View File

@@ -1,382 +0,0 @@
package translator
import (
"bytes"
"fmt"
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertCliToOpenAI translates a single chunk of a streaming response from the
// backend client format to the OpenAI Server-Sent Events (SSE) format.
// It returns an empty string if the chunk contains no useful data.
func ConvertCliToOpenAI(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string {
if isGlAPIKey {
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
}
// Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
// Extract and set the model version.
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
// Extract and set the creation timestamp.
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
if err == nil {
unixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
} else {
template, _ = sjson.Set(template, "created", unixTimestamp)
}
// Extract and set the response ID.
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
template, _ = sjson.Set(template, "id", responseIdResult.String())
}
// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
}
// Process the main content part of the response.
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partResults := partsResult.Array()
for i := 0; i < len(partResults); i++ {
partResult := partResults[i]
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() {
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
}
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallTemplate)
}
}
}
return template
}
// ConvertCliToOpenAINonStream aggregates response from the backend client
// convert a single, non-streaming OpenAI-compatible JSON response.
func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string {
if isGlAPIKey {
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
}
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
if err == nil {
unixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
} else {
template, _ = sjson.Set(template, "created", unixTimestamp)
}
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
template, _ = sjson.Set(template, "id", responseIdResult.String())
}
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
}
// Process the main content part of the response.
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partsResults := partsResult.Array()
for i := 0; i < len(partsResults); i++ {
partResult := partsResults[i]
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() {
// Append text content, distinguishing between regular content and reasoning.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
} else if functionCallResult.Exists() {
// Append function call content to the tool_calls array.
toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls")
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
}
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
} else {
// If no usable content is found, return an empty string.
return ""
}
}
}
return template
}
// ConvertCliToClaude performs sophisticated streaming response format conversion.
// This function implements a complex state machine that translates backend client responses
// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types
// and handles state transitions between content blocks, thinking processes, and function calls.
//
// Response type states: 0=none, 1=content, 2=thinking, 3=function
// The function maintains state across multiple calls to ensure proper SSE event sequencing.
func ConvertCliToClaude(rawJson []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string {
// Normalize the response format for different API key types
// Generative Language API keys have a different response structure
if isGlAPIKey {
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
}
// Track whether tools are being used in this response chunk
usedTool := false
output := ""
// Initialize the streaming session with a message_start event
// This is only sent for the very first response chunk
if !hasFirstResponse {
output = "event: message_start\n"
// Create the initial message structure with default values
// This follows the Claude API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
// Override default values with actual response metadata if available
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
}
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIdResult.String())
}
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
}
// Process the response parts array from the backend client
// Each part can contain text content, thinking content, or function calls
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partResults := partsResult.Array()
for i := 0; i < len(partResults); i++ {
partResult := partResults[i]
// Extract the different types of content from each part
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
// Handle text content (both regular content and thinking)
if partTextResult.Exists() {
// Process thinking content (internal reasoning)
if partResult.Get("thought").Bool() {
// Continue existing thinking block
if *responseType == 2 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
} else {
// Transition from another state to thinking
// First, close any existing content block
if *responseType != 0 {
if *responseType == 2 {
output = output + "event: content_block_delta\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
}
// Start a new thinking content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, *responseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
*responseType = 2 // Set state to thinking
}
} else {
// Process regular text content (user-visible output)
// Continue existing text block
if *responseType == 1 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
} else {
// Transition from another state to text content
// First, close any existing content block
if *responseType != 0 {
if *responseType == 2 {
output = output + "event: content_block_delta\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
}
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, *responseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
*responseType = 1 // Set state to content
}
}
} else if functionCallResult.Exists() {
// Handle function/tool calls from the AI model
// This processes tool usage requests and formats them for Claude API compatibility
usedTool = true
fcName := functionCallResult.Get("name").String()
// Handle state transitions when switching to function calls
// Close any existing function call block first
if *responseType == 3 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
*responseType = 0
}
// Special handling for thinking state transition
if *responseType == 2 {
output = output + "event: content_block_delta\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
output = output + "\n\n\n"
}
// Close any other existing content block
if *responseType != 0 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
}
// Start a new tool use content block
// This creates the structure for a function call in Claude format
output = output + "event: content_block_start\n"
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, *responseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n"
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, *responseIndex), "delta.partial_json", fcArgsResult.Raw)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
}
*responseType = 3
}
}
}
usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
if usageResult.Exists() && bytes.Contains(rawJson, []byte(`"finishReason"`)) {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
output = output + "event: message_delta\n"
output = output + `data: `
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
if usedTool {
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
}
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
output = output + template + "\n\n\n"
}
}
return output
}

View File

@@ -1,3 +1,6 @@
// Package auth provides OAuth2 authentication functionality for Google Cloud APIs.
// It handles the complete OAuth2 flow including token storage, web-based authentication,
// proxy support, and automatic token refresh. The package supports both SOCKS5 and HTTP/HTTPS proxies.
package auth package auth
import ( import (
@@ -39,7 +42,7 @@ var (
// initiating a new web-based OAuth flow if necessary, and refreshing tokens. // initiating a new web-based OAuth flow if necessary, and refreshing tokens.
func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) { func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) {
// Configure proxy settings for the HTTP client if a proxy URL is provided. // Configure proxy settings for the HTTP client if a proxy URL is provided.
proxyURL, err := url.Parse(cfg.ProxyUrl) proxyURL, err := url.Parse(cfg.ProxyURL)
if err == nil { if err == nil {
var transport *http.Transport var transport *http.Transport
if proxyURL.Scheme == "socks5" { if proxyURL.Scheme == "socks5" {

View File

@@ -1,3 +1,7 @@
// Package client provides HTTP client functionality for interacting with Google Cloud AI APIs.
// It handles OAuth2 authentication, token management, request/response processing,
// streaming communication, quota management, and automatic model fallback.
// The package supports both direct API key authentication and OAuth2 flows.
package client package client
import ( import (
@@ -29,7 +33,7 @@ const (
pluginVersion = "0.1.9" pluginVersion = "0.1.9"
glEndPoint = "https://generativelanguage.googleapis.com" glEndPoint = "https://generativelanguage.googleapis.com"
glApiVersion = "v1beta" glAPIVersion = "v1beta"
) )
var ( var (
@@ -64,30 +68,37 @@ func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Confi
} }
} }
// SetProjectID updates the project ID for the client's token storage.
func (c *Client) SetProjectID(projectID string) { func (c *Client) SetProjectID(projectID string) {
c.tokenStorage.ProjectID = projectID c.tokenStorage.ProjectID = projectID
} }
// SetIsAuto configures whether the client should operate in automatic mode.
func (c *Client) SetIsAuto(auto bool) { func (c *Client) SetIsAuto(auto bool) {
c.tokenStorage.Auto = auto c.tokenStorage.Auto = auto
} }
// SetIsChecked sets the checked status for the client's token storage.
func (c *Client) SetIsChecked(checked bool) { func (c *Client) SetIsChecked(checked bool) {
c.tokenStorage.Checked = checked c.tokenStorage.Checked = checked
} }
// IsChecked returns whether the client's token storage has been checked.
func (c *Client) IsChecked() bool { func (c *Client) IsChecked() bool {
return c.tokenStorage.Checked return c.tokenStorage.Checked
} }
// IsAuto returns whether the client is operating in automatic mode.
func (c *Client) IsAuto() bool { func (c *Client) IsAuto() bool {
return c.tokenStorage.Auto return c.tokenStorage.Auto
} }
// GetEmail returns the email address associated with the client's token storage.
func (c *Client) GetEmail() string { func (c *Client) GetEmail() string {
return c.tokenStorage.Email return c.tokenStorage.Email
} }
// GetProjectID returns the Google Cloud project ID from the client's token storage.
func (c *Client) GetProjectID() string { func (c *Client) GetProjectID() string {
if c.tokenStorage != nil { if c.tokenStorage != nil {
return c.tokenStorage.ProjectID return c.tokenStorage.ProjectID
@@ -95,6 +106,7 @@ func (c *Client) GetProjectID() string {
return "" return ""
} }
// GetGenerativeLanguageAPIKey returns the generative language API key if configured.
func (c *Client) GetGenerativeLanguageAPIKey() string { func (c *Client) GetGenerativeLanguageAPIKey() string {
return c.glAPIKey return c.glAPIKey
} }
@@ -267,10 +279,10 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
} else { } else {
if endpoint == "countTokens" { if endpoint == "countTokens" {
modelResult := gjson.GetBytes(jsonBody, "model") modelResult := gjson.GetBytes(jsonBody, "model")
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint) url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint)
} else { } else {
modelResult := gjson.GetBytes(jsonBody, "model") modelResult := gjson.GetBytes(jsonBody, "model")
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint) url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint)
if alt == "" && stream { if alt == "" && stream {
url = url + "?alt=sse" url = url + "?alt=sse"
} else { } else {
@@ -333,7 +345,7 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
} }
// SendMessage handles a single conversational turn, including tool calls. // SendMessage handles a single conversational turn, including tool calls.
func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) { func (c *Client) SendMessage(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
request := GenerateContentRequest{ request := GenerateContentRequest{
Contents: contents, Contents: contents,
GenerationConfig: GenerationConfig{ GenerationConfig: GenerationConfig{
@@ -357,7 +369,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
// log.Debug(string(byteRequestBody)) // log.Debug(string(byteRequestBody))
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort") reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort")
if reasoningEffortResult.String() == "none" { if reasoningEffortResult.String() == "none" {
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts") byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0) byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
@@ -373,17 +385,17 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1) byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
} }
temperatureResult := gjson.GetBytes(rawJson, "temperature") temperatureResult := gjson.GetBytes(rawJSON, "temperature")
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number { if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num) byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
} }
topPResult := gjson.GetBytes(rawJson, "top_p") topPResult := gjson.GetBytes(rawJSON, "top_p")
if topPResult.Exists() && topPResult.Type == gjson.Number { if topPResult.Exists() && topPResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num) byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
} }
topKResult := gjson.GetBytes(rawJson, "top_k") topKResult := gjson.GetBytes(rawJSON, "top_k")
if topKResult.Exists() && topKResult.Type == gjson.Number { if topKResult.Exists() && topKResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num) byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
} }
@@ -430,7 +442,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
// This function implements a sophisticated streaming system that supports tool calls, reasoning modes, // This function implements a sophisticated streaming system that supports tool calls, reasoning modes,
// quota management, and automatic model fallback. It returns two channels for asynchronous communication: // quota management, and automatic model fallback. It returns two channels for asynchronous communication:
// one for streaming response data and another for error handling. // one for streaming response data and another for error handling.
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) { func (c *Client) SendMessageStream(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) {
// Define the data prefix used in Server-Sent Events streaming format // Define the data prefix used in Server-Sent Events streaming format
dataTag := []byte("data: ") dataTag := []byte("data: ")
@@ -486,7 +498,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
// Parse and configure reasoning effort levels from the original request // Parse and configure reasoning effort levels from the original request
// This maps Claude-style reasoning effort parameters to Gemini's thinking budget system // This maps Claude-style reasoning effort parameters to Gemini's thinking budget system
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort") reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort")
if reasoningEffortResult.String() == "none" { if reasoningEffortResult.String() == "none" {
// Disable thinking entirely for fastest responses // Disable thinking entirely for fastest responses
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts") byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
@@ -510,21 +522,21 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
// Configure temperature parameter for response randomness control // Configure temperature parameter for response randomness control
// Temperature affects the creativity vs consistency trade-off in responses // Temperature affects the creativity vs consistency trade-off in responses
temperatureResult := gjson.GetBytes(rawJson, "temperature") temperatureResult := gjson.GetBytes(rawJSON, "temperature")
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number { if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num) byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
} }
// Configure top-p parameter for nucleus sampling // Configure top-p parameter for nucleus sampling
// Controls the cumulative probability threshold for token selection // Controls the cumulative probability threshold for token selection
topPResult := gjson.GetBytes(rawJson, "top_p") topPResult := gjson.GetBytes(rawJSON, "top_p")
if topPResult.Exists() && topPResult.Type == gjson.Number { if topPResult.Exists() && topPResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num) byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
} }
// Configure top-k parameter for limiting token candidates // Configure top-k parameter for limiting token candidates
// Restricts the model to consider only the top K most likely tokens // Restricts the model to consider only the top K most likely tokens
topKResult := gjson.GetBytes(rawJson, "top_k") topKResult := gjson.GetBytes(rawJSON, "top_k")
if topKResult.Exists() && topKResult.Type == gjson.Number { if topKResult.Exists() && topKResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num) byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
} }
@@ -608,8 +620,8 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
} }
// SendRawTokenCount handles a token count. // SendRawTokenCount handles a token count.
func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) { func (c *Client) SendRawTokenCount(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) {
modelResult := gjson.GetBytes(rawJson, "model") modelResult := gjson.GetBytes(rawJSON, "model")
model := modelResult.String() model := modelResult.String()
modelName := model modelName := model
for { for {
@@ -618,7 +630,7 @@ func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt stri
modelName = c.getPreviewModel(model) modelName = c.getPreviewModel(model)
if modelName != "" { if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
rawJson, _ = sjson.SetBytes(rawJson, "model", modelName) rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
continue continue
} }
} }
@@ -628,7 +640,7 @@ func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt stri
} }
} }
respBody, err := c.APIRequest(ctx, "countTokens", rawJson, alt, false) respBody, err := c.APIRequest(ctx, "countTokens", rawJSON, alt, false)
if err != nil { if err != nil {
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
@@ -649,12 +661,12 @@ func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt stri
} }
// SendRawMessage handles a single conversational turn, including tool calls. // SendRawMessage handles a single conversational turn, including tool calls.
func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) { func (c *Client) SendRawMessage(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) {
if c.glAPIKey == "" { if c.glAPIKey == "" {
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID()) rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID())
} }
modelResult := gjson.GetBytes(rawJson, "model") modelResult := gjson.GetBytes(rawJSON, "model")
model := modelResult.String() model := modelResult.String()
modelName := model modelName := model
for { for {
@@ -663,7 +675,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string)
modelName = c.getPreviewModel(model) modelName = c.getPreviewModel(model)
if modelName != "" { if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
rawJson, _ = sjson.SetBytes(rawJson, "model", modelName) rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
continue continue
} }
} }
@@ -673,7 +685,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string)
} }
} }
respBody, err := c.APIRequest(ctx, "generateContent", rawJson, alt, false) respBody, err := c.APIRequest(ctx, "generateContent", rawJSON, alt, false)
if err != nil { if err != nil {
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
@@ -694,7 +706,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string)
} }
// SendRawMessageStream handles a single conversational turn, including tool calls. // SendRawMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) { func (c *Client) SendRawMessageStream(ctx context.Context, rawJSON []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) {
dataTag := []byte("data: ") dataTag := []byte("data: ")
errChan := make(chan *ErrorMessage) errChan := make(chan *ErrorMessage)
dataChan := make(chan []byte) dataChan := make(chan []byte)
@@ -703,10 +715,10 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt s
defer close(dataChan) defer close(dataChan)
if c.glAPIKey == "" { if c.glAPIKey == "" {
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID()) rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID())
} }
modelResult := gjson.GetBytes(rawJson, "model") modelResult := gjson.GetBytes(rawJSON, "model")
model := modelResult.String() model := modelResult.String()
modelName := model modelName := model
var stream io.ReadCloser var stream io.ReadCloser
@@ -716,7 +728,7 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt s
modelName = c.getPreviewModel(model) modelName = c.getPreviewModel(model)
if modelName != "" { if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName) log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
rawJson, _ = sjson.SetBytes(rawJson, "model", modelName) rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
continue continue
} }
} }
@@ -727,7 +739,7 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt s
return return
} }
var err *ErrorMessage var err *ErrorMessage
stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, alt, true) stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJSON, alt, true)
if err != nil { if err != nil {
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
@@ -774,6 +786,8 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt s
return dataChan, errChan return dataChan, errChan
} }
// isModelQuotaExceeded checks if the specified model has exceeded its quota
// within the last 30 minutes.
func (c *Client) isModelQuotaExceeded(model string) bool { func (c *Client) isModelQuotaExceeded(model string) bool {
if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey { if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
duration := time.Now().Sub(*lastExceededTime) duration := time.Now().Sub(*lastExceededTime)
@@ -785,6 +799,8 @@ func (c *Client) isModelQuotaExceeded(model string) bool {
return false return false
} }
// getPreviewModel returns an available preview model for the given base model,
// or an empty string if no preview models are available or all are quota exceeded.
func (c *Client) getPreviewModel(model string) string { func (c *Client) getPreviewModel(model string) string {
if models, hasKey := previewModels[model]; hasKey { if models, hasKey := previewModels[model]; hasKey {
for i := 0; i < len(models); i++ { for i := 0; i < len(models); i++ {
@@ -796,6 +812,8 @@ func (c *Client) getPreviewModel(model string) string {
return "" return ""
} }
// IsModelQuotaExceeded returns true if the specified model has exceeded its quota
// and no fallback options are available.
func (c *Client) IsModelQuotaExceeded(model string) bool { func (c *Client) IsModelQuotaExceeded(model string) bool {
if c.isModelQuotaExceeded(model) { if c.isModelQuotaExceeded(model) {
if c.cfg.QuotaExceeded.SwitchPreviewModel { if c.cfg.QuotaExceeded.SwitchPreviewModel {
@@ -824,20 +842,20 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
if err != nil { if err != nil {
// If a 403 Forbidden error occurs, it likely means the API is not enabled. // If a 403 Forbidden error occurs, it likely means the API is not enabled.
if err.StatusCode == 403 { if err.StatusCode == 403 {
errJson := err.Error.Error() errJSON := err.Error.Error()
// Check for a specific error code and extract the activation URL. // Check for a specific error code and extract the activation URL.
if gjson.Get(errJson, "error.code").Int() == 403 { if gjson.Get(errJSON, "error.code").Int() == 403 {
activationUrl := gjson.Get(errJson, "error.details.0.metadata.activationUrl").String() activationURL := gjson.Get(errJSON, "error.details.0.metadata.activationUrl").String()
if activationUrl != "" { if activationURL != "" {
log.Warnf( log.Warnf(
"\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s", "\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s",
activationUrl, activationURL,
os.Args[0], os.Args[0],
c.tokenStorage.ProjectID, c.tokenStorage.ProjectID,
) )
} }
} }
log.Warnf("\n\nPlease copy this message and create an issue.\n\n%s\n\n", errJson) log.Warnf("\n\nPlease copy this message and create an issue.\n\n%s\n\n", errJSON)
return false, nil return false, nil
} }
return false, err.Error return false, err.Error

View File

@@ -1,3 +1,6 @@
// Package cmd provides command-line interface functionality for the CLI Proxy API.
// It implements the main application commands including login/authentication
// and server startup, handling the complete user onboarding and service lifecycle.
package cmd package cmd
import ( import (

View File

@@ -1,3 +1,8 @@
// Package cmd provides the main service execution functionality for the CLIProxyAPI.
// It contains the core logic for starting and managing the API proxy service,
// including authentication client management, server initialization, and graceful shutdown handling.
// The package handles loading authentication tokens, creating client pools, starting the API server,
// and monitoring configuration changes through file watchers.
package cmd package cmd
import ( import (

View File

@@ -1,3 +1,7 @@
// Package config provides configuration management for the CLI Proxy API server.
// It handles loading and parsing YAML configuration files, and provides structured
// access to application settings including server port, authentication directory,
// debug settings, proxy configuration, and API keys.
package config package config
import ( import (
@@ -14,17 +18,19 @@ type Config struct {
AuthDir string `yaml:"auth-dir"` AuthDir string `yaml:"auth-dir"`
// Debug enables or disables debug-level logging and other debug features. // Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug"` Debug bool `yaml:"debug"`
// ProxyUrl is the URL of an optional proxy server to use for outbound requests. // ProxyURL is the URL of an optional proxy server to use for outbound requests.
ProxyUrl string `yaml:"proxy-url"` ProxyURL string `yaml:"proxy-url"`
// ApiKeys is a list of keys for authenticating clients to this proxy server. // APIKeys is a list of keys for authenticating clients to this proxy server.
ApiKeys []string `yaml:"api-keys"` APIKeys []string `yaml:"api-keys"`
// QuotaExceeded defines the behavior when a quota is exceeded. // QuotaExceeded defines the behavior when a quota is exceeded.
QuotaExceeded ConfigQuotaExceeded `yaml:"quota-exceeded"` QuotaExceeded QuotaExceeded `yaml:"quota-exceeded"`
// GlAPIKey is the API key for the generative language API. // GlAPIKey is the API key for the generative language API.
GlAPIKey []string `yaml:"generative-language-api-key"` GlAPIKey []string `yaml:"generative-language-api-key"`
} }
type ConfigQuotaExceeded struct { // QuotaExceeded defines the behavior when API quota limits are exceeded.
// It provides configuration options for automatic failover mechanisms.
type QuotaExceeded struct {
// SwitchProject indicates whether to automatically switch to another project when a quota is exceeded. // SwitchProject indicates whether to automatically switch to another project when a quota is exceeded.
SwitchProject bool `yaml:"switch-project"` SwitchProject bool `yaml:"switch-project"`
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded. // SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.

View File

@@ -1,3 +1,6 @@
// Package util provides utility functions for the CLI Proxy API server.
// It includes helper functions for proxy configuration, HTTP client setup,
// and other common operations used across the application.
package util package util
import ( import (
@@ -9,9 +12,12 @@ import (
"net/url" "net/url"
) )
// SetProxy configures the provided HTTP client with proxy settings from the configuration.
// It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport
// to route requests through the configured proxy server.
func SetProxy(cfg *config.Config, httpClient *http.Client) (*http.Client, error) { func SetProxy(cfg *config.Config, httpClient *http.Client) (*http.Client, error) {
var transport *http.Transport var transport *http.Transport
proxyURL, errParse := url.Parse(cfg.ProxyUrl) proxyURL, errParse := url.Parse(cfg.ProxyURL)
if errParse == nil { if errParse == nil {
if proxyURL.Scheme == "socks5" { if proxyURL.Scheme == "socks5" {
username := proxyURL.User.Username() username := proxyURL.User.Username()

View File

@@ -1,3 +1,7 @@
// Package watcher provides file system monitoring functionality for the CLI Proxy API.
// It watches configuration files and authentication directories for changes,
// automatically reloading clients and configuration when files are modified.
// The package handles cross-platform file system events and supports hot-reloading.
package watcher package watcher
import ( import (
@@ -156,11 +160,11 @@ func (w *Watcher) reloadConfig() {
if oldConfig.Debug != newConfig.Debug { if oldConfig.Debug != newConfig.Debug {
log.Debugf(" debug: %t -> %t", oldConfig.Debug, newConfig.Debug) log.Debugf(" debug: %t -> %t", oldConfig.Debug, newConfig.Debug)
} }
if oldConfig.ProxyUrl != newConfig.ProxyUrl { if oldConfig.ProxyURL != newConfig.ProxyURL {
log.Debugf(" proxy-url: %s -> %s", oldConfig.ProxyUrl, newConfig.ProxyUrl) log.Debugf(" proxy-url: %s -> %s", oldConfig.ProxyURL, newConfig.ProxyURL)
} }
if len(oldConfig.ApiKeys) != len(newConfig.ApiKeys) { if len(oldConfig.APIKeys) != len(newConfig.APIKeys) {
log.Debugf(" api-keys count: %d -> %d", len(oldConfig.ApiKeys), len(newConfig.ApiKeys)) log.Debugf(" api-keys count: %d -> %d", len(oldConfig.APIKeys), len(newConfig.APIKeys))
} }
if len(oldConfig.GlAPIKey) != len(newConfig.GlAPIKey) { if len(oldConfig.GlAPIKey) != len(newConfig.GlAPIKey) {
log.Debugf(" generative-language-api-key count: %d -> %d", len(oldConfig.GlAPIKey), len(newConfig.GlAPIKey)) log.Debugf(" generative-language-api-key count: %d -> %d", len(oldConfig.GlAPIKey), len(newConfig.GlAPIKey))
@@ -248,7 +252,7 @@ func (w *Watcher) reloadClients() {
log.Debugf("auth directory scan complete - found %d .json files, %d successful authentications", authFileCount, successfulAuthCount) log.Debugf("auth directory scan complete - found %d .json files, %d successful authentications", authFileCount, successfulAuthCount)
// Add clients for Generative Language API keys if configured // Add clients for Generative Language API keys if configured
glApiKeyCount := 0 glAPIKeyCount := 0
if len(cfg.GlAPIKey) > 0 { if len(cfg.GlAPIKey) > 0 {
log.Debugf("processing %d Generative Language API keys", len(cfg.GlAPIKey)) log.Debugf("processing %d Generative Language API keys", len(cfg.GlAPIKey))
for i := 0; i < len(cfg.GlAPIKey); i++ { for i := 0; i < len(cfg.GlAPIKey); i++ {
@@ -261,9 +265,9 @@ func (w *Watcher) reloadClients() {
log.Debugf(" initializing with Generative Language API key %d...", i+1) log.Debugf(" initializing with Generative Language API key %d...", i+1)
cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i]) cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
newClients = append(newClients, cliClient) newClients = append(newClients, cliClient)
glApiKeyCount++ glAPIKeyCount++
} }
log.Debugf("successfully initialized %d Generative Language API key clients", glApiKeyCount) log.Debugf("successfully initialized %d Generative Language API key clients", glAPIKeyCount)
} }
// Update the client list // Update the client list
@@ -272,7 +276,7 @@ func (w *Watcher) reloadClients() {
w.clientsMutex.Unlock() w.clientsMutex.Unlock()
log.Infof("client reload complete - old: %d clients, new: %d clients (%d auth files + %d GL API keys)", log.Infof("client reload complete - old: %d clients, new: %d clients (%d auth files + %d GL API keys)",
oldClientCount, len(newClients), successfulAuthCount, glApiKeyCount) oldClientCount, len(newClients), successfulAuthCount, glAPIKeyCount)
// Trigger the callback to update the server // Trigger the callback to update the server
if w.reloadCallback != nil { if w.reloadCallback != nil {