Mirror of https://github.com/router-for-me/CLIProxyAPI.git, synced 2026-02-03 04:50:52 +08:00
Refactor API handlers organization and simplify error response handling
- Modularized handlers into dedicated packages (`gemini`, `claude`, `cli`) for better structure.
- Centralized `ErrorResponse` and `ErrorDetail` types under the `handlers` package for reuse.
- Updated all handlers to use the shared `ErrorResponse` model.
- Introduced specialized handler structs (`GeminiAPIHandlers`, `ClaudeCodeAPIHandlers`, `GeminiCLIAPIHandlers`) for improved clarity and separation of concerns.
- Refactored the `getClient` logic with additional properties and better state management.

Refactor `translator` package by modularizing code for `claude` and `gemini`

- Moved Claude-specific logic (`PrepareClaudeRequest`, `ConvertCliToClaude`) to `translator/claude/code`.
- Moved Gemini-specific logic (`FixCLIToolResponse`) to `translator/gemini/cli` for better package structure.
- Updated affected handler imports and method references.

Add comprehensive package-level documentation across key modules

- Introduced detailed package-level documentation for core modules: `auth`, `client`, `cmd`, `handlers`, `util`, `watcher`, `config`, `translator`, and `api`.
- Enhanced code readability and maintainability by clarifying the purpose and functionality of each package.
- Aligned documentation style and tone with existing codebase conventions.

Refactor API handlers and translator modules for improved clarity and consistency

- Standardized handler struct names (`GeminiAPIHandlers`, `ClaudeCodeAPIHandlers`, `GeminiCLIAPIHandlers`, `OpenAIAPIHandlers`) and updated related comments.
- Removed unnecessary `else` blocks in streaming logic for cleaner error handling.
- Renamed variables for better readability (`responseIdResult` to `responseIDResult`, `activationUrl` to `activationURL`, etc.).
- Addressed minor inconsistencies in API handler comments and SSE header initialization.
- Improved modularization of `claude` and `gemini` translator components.

Standardize configuration field naming for consistency across modules

- Renamed `ProxyUrl` to `ProxyURL`, `ApiKeys` to `APIKeys`, and `ConfigQuotaExceeded` to `QuotaExceeded`.
- Updated all relevant references and comments in `config`, `auth`, `api`, `util`, and `watcher`.
- Ensured consistent casing in `GlAPIKey` debug logs.
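The structural pattern behind these changes is visible throughout the diff below: each protocol-specific package (`gemini`, `claude`, `cli`, `openai`) wraps the shared `handlers.APIHandlers` by embedding it, and reuses the centralized `handlers.ErrorResponse`/`handlers.ErrorDetail` types. A minimal sketch of that pattern, with a hypothetical `example` package and `Ping` endpoint standing in for the real handlers:

```go
package example // hypothetical package; the real ones are gemini, claude, cli, and openai

import (
	"net/http"

	"github.com/gin-gonic/gin"
	"github.com/luispater/CLIProxyAPI/internal/api/handlers"
)

// ExampleAPIHandlers embeds the shared APIHandlers, inheriting its client pool,
// configuration, and the GetClient/GetAlt helpers introduced in this commit.
type ExampleAPIHandlers struct {
	*handlers.APIHandlers
}

// NewExampleAPIHandlers mirrors the New*APIHandlers constructors added in this commit.
func NewExampleAPIHandlers(apiHandlers *handlers.APIHandlers) *ExampleAPIHandlers {
	return &ExampleAPIHandlers{APIHandlers: apiHandlers}
}

// Ping is an illustrative endpoint showing the shared error model in use.
func (h *ExampleAPIHandlers) Ping(c *gin.Context) {
	if _, err := c.GetRawData(); err != nil {
		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Invalid request: " + err.Error(),
				Type: "invalid_request_error",
			},
		})
		return
	}
	c.JSON(http.StatusOK, gin.H{"status": "ok"})
}
```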
@@ -1,3 +1,6 @@
+// Package main provides the entry point for the CLI Proxy API server.
+// This server acts as a proxy that provides OpenAI/Gemini/Claude compatible API interfaces
+// for CLI models, allowing CLI models to be used with tools and libraries designed for standard AI APIs.
 package main

 import (

@@ -1,10 +1,17 @@
-package api
+// Package claude provides HTTP handlers for Claude API code-related functionality.
+// This package implements Claude-compatible streaming chat completions with sophisticated
+// client rotation and quota management systems to ensure high availability and optimal
+// resource utilization across multiple backend clients. It handles request translation
+// between Claude API format and the underlying Gemini backend, providing seamless
+// API compatibility while maintaining robust error handling and connection management.
+package claude

 import (
 	"context"
 	"fmt"
 	"github.com/gin-gonic/gin"
-	"github.com/luispater/CLIProxyAPI/internal/api/translator"
+	"github.com/luispater/CLIProxyAPI/internal/api/handlers"
+	"github.com/luispater/CLIProxyAPI/internal/api/translator/claude/code"
 	"github.com/luispater/CLIProxyAPI/internal/client"
 	log "github.com/sirupsen/logrus"
 	"net/http"
@@ -12,16 +19,30 @@ import (
 	"time"
 )

+// ClaudeCodeAPIHandlers contains the handlers for Claude API endpoints.
+// It holds a pool of clients to interact with the backend service.
+type ClaudeCodeAPIHandlers struct {
+	*handlers.APIHandlers
+}
+
+// NewClaudeCodeAPIHandlers creates a new Claude API handlers instance.
+// It takes an APIHandlers instance as input and returns a ClaudeCodeAPIHandlers.
+func NewClaudeCodeAPIHandlers(apiHandlers *handlers.APIHandlers) *ClaudeCodeAPIHandlers {
+	return &ClaudeCodeAPIHandlers{
+		APIHandlers: apiHandlers,
+	}
+}
+
 // ClaudeMessages handles Claude-compatible streaming chat completions.
 // This function implements a sophisticated client rotation and quota management system
 // to ensure high availability and optimal resource utilization across multiple backend clients.
-func (h *APIHandlers) ClaudeMessages(c *gin.Context) {
+func (h *ClaudeCodeAPIHandlers) ClaudeMessages(c *gin.Context) {
 	// Extract raw JSON data from the incoming request
-	rawJson, err := c.GetRawData()
+	rawJSON, err := c.GetRawData()
 	// If data retrieval fails, return a 400 Bad Request error.
 	if err != nil {
-		c.JSON(http.StatusBadRequest, ErrorResponse{
-			Error: ErrorDetail{
+		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
+			Error: handlers.ErrorDetail{
 				Message: fmt.Sprintf("Invalid request: %v", err),
 				Type: "invalid_request_error",
 			},
@@ -41,8 +62,8 @@ func (h *APIHandlers) ClaudeMessages(c *gin.Context) {
 	// This is crucial for streaming as it allows immediate sending of data chunks
 	flusher, ok := c.Writer.(http.Flusher)
 	if !ok {
-		c.JSON(http.StatusInternalServerError, ErrorResponse{
-			Error: ErrorDetail{
+		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
+			Error: handlers.ErrorDetail{
 				Message: "Streaming not supported",
 				Type: "server_error",
 			},
@@ -52,7 +73,7 @@ func (h *APIHandlers) ClaudeMessages(c *gin.Context) {

 	// Parse and prepare the Claude request, extracting model name, system instructions,
 	// conversation contents, and available tools from the raw JSON
-	modelName, systemInstruction, contents, tools := translator.PrepareClaudeRequest(rawJson)
+	modelName, systemInstruction, contents, tools := code.PrepareClaudeRequest(rawJSON)

 	// Map Claude model names to corresponding Gemini models
 	// This allows the proxy to handle Claude API calls using Gemini backend
@@ -79,7 +100,7 @@ func (h *APIHandlers) ClaudeMessages(c *gin.Context) {
 outLoop:
 	for {
 		var errorResponse *client.ErrorMessage
-		cliClient, errorResponse = h.getClient(modelName)
+		cliClient, errorResponse = h.GetClient(modelName)
 		if errorResponse != nil {
 			c.Status(errorResponse.StatusCode)
 			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -105,7 +126,7 @@ outLoop:
 			includeThoughts = !strings.Contains(userAgent[0], "claude-cli")
 		}

-		respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, systemInstruction, contents, tools, includeThoughts)
+		respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools, includeThoughts)

 		// Track response state for proper Claude format conversion
 		hasFirstResponse := false
@@ -139,16 +160,15 @@ outLoop:
 					flusher.Flush()
 					cliCancel()
 					return
-				} else {
-					// Convert the backend response to Claude-compatible format
-					// This translation layer ensures API compatibility
-					claudeFormat := translator.ConvertCliToClaude(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex)
-					if claudeFormat != "" {
-						_, _ = c.Writer.Write([]byte(claudeFormat))
-						flusher.Flush() // Immediately send the chunk to the client
-					}
-					hasFirstResponse = true
 				}
+				// Convert the backend response to Claude-compatible format
+				// This translation layer ensures API compatibility
+				claudeFormat := code.ConvertCliToClaude(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex)
+				if claudeFormat != "" {
+					_, _ = c.Writer.Write([]byte(claudeFormat))
+					flusher.Flush() // Immediately send the chunk to the client
+				}
+				hasFirstResponse = true

 			// Case 3: Handle errors from the backend
 			// This manages various error conditions and implements retry logic
@@ -156,7 +176,7 @@ outLoop:
 				if okError {
 					// Special handling for quota exceeded errors
 					// If configured, attempt to switch to a different project/client
-					if errInfo.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
+					if errInfo.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
 						continue outLoop // Restart the client selection process
 					} else {
 						// Forward other errors directly to the client

@@ -1,10 +1,15 @@
-package api
+// Package cli provides HTTP handlers for Gemini CLI API functionality.
+// This package implements handlers that process CLI-specific requests for Gemini API operations,
+// including content generation and streaming content generation endpoints.
+// The handlers restrict access to localhost only and manage communication with the backend service.
+package cli

 import (
 	"bytes"
 	"context"
 	"fmt"
 	"github.com/gin-gonic/gin"
+	"github.com/luispater/CLIProxyAPI/internal/api/handlers"
 	"github.com/luispater/CLIProxyAPI/internal/client"
 	"github.com/luispater/CLIProxyAPI/internal/util"
 	log "github.com/sirupsen/logrus"
@@ -16,10 +21,26 @@ import (
 	"time"
 )

-func (h *APIHandlers) CLIHandler(c *gin.Context) {
+// GeminiCLIAPIHandlers contains the handlers for Gemini CLI API endpoints.
+// It holds a pool of clients to interact with the backend service.
+type GeminiCLIAPIHandlers struct {
+	*handlers.APIHandlers
+}
+
+// NewGeminiCLIAPIHandlers creates a new Gemini CLI API handlers instance.
+// It takes an APIHandlers instance as input and returns a GeminiCLIAPIHandlers.
+func NewGeminiCLIAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiCLIAPIHandlers {
+	return &GeminiCLIAPIHandlers{
+		APIHandlers: apiHandlers,
+	}
+}
+
+// CLIHandler handles CLI-specific requests for Gemini API operations.
+// It restricts access to localhost only and routes requests to appropriate internal handlers.
+func (h *GeminiCLIAPIHandlers) CLIHandler(c *gin.Context) {
 	if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") {
-		c.JSON(http.StatusForbidden, ErrorResponse{
-			Error: ErrorDetail{
+		c.JSON(http.StatusForbidden, handlers.ErrorResponse{
+			Error: handlers.ErrorDetail{
 				Message: "CLI reply only allow local access",
 				Type: "forbidden",
 			},
@@ -27,18 +48,18 @@ func (h *APIHandlers) CLIHandler(c *gin.Context) {
 		return
 	}

-	rawJson, _ := c.GetRawData()
+	rawJSON, _ := c.GetRawData()
 	requestRawURI := c.Request.URL.Path
 	if requestRawURI == "/v1internal:generateContent" {
-		h.internalGenerateContent(c, rawJson)
+		h.internalGenerateContent(c, rawJSON)
 	} else if requestRawURI == "/v1internal:streamGenerateContent" {
-		h.internalStreamGenerateContent(c, rawJson)
+		h.internalStreamGenerateContent(c, rawJSON)
 	} else {
-		reqBody := bytes.NewBuffer(rawJson)
+		reqBody := bytes.NewBuffer(rawJSON)
 		req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody)
 		if err != nil {
-			c.JSON(http.StatusBadRequest, ErrorResponse{
-				Error: ErrorDetail{
+			c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
+				Error: handlers.ErrorDetail{
 					Message: fmt.Sprintf("Invalid request: %v", err),
 					Type: "invalid_request_error",
 				},
@@ -49,15 +70,15 @@ func (h *APIHandlers) CLIHandler(c *gin.Context) {
 			req.Header[key] = value
 		}

-		httpClient, err := util.SetProxy(h.cfg, &http.Client{})
+		httpClient, err := util.SetProxy(h.Cfg, &http.Client{})
 		if err != nil {
 			log.Fatalf("set proxy failed: %v", err)
 		}

 		resp, err := httpClient.Do(req)
 		if err != nil {
-			c.JSON(http.StatusBadRequest, ErrorResponse{
-				Error: ErrorDetail{
+			c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
+				Error: handlers.ErrorDetail{
 					Message: fmt.Sprintf("Invalid request: %v", err),
 					Type: "invalid_request_error",
 				},
@@ -73,8 +94,8 @@ func (h *APIHandlers) CLIHandler(c *gin.Context) {
 			}()
 			bodyBytes, _ := io.ReadAll(resp.Body)

-			c.JSON(http.StatusBadRequest, ErrorResponse{
-				Error: ErrorDetail{
+			c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
+				Error: handlers.ErrorDetail{
 					Message: string(bodyBytes),
 					Type: "invalid_request_error",
 				},
@@ -98,8 +119,8 @@ func (h *APIHandlers) CLIHandler(c *gin.Context) {
 	}
 }

-func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) {
-	alt := h.getAlt(c)
+func (h *GeminiCLIAPIHandlers) internalStreamGenerateContent(c *gin.Context, rawJSON []byte) {
+	alt := h.GetAlt(c)

 	if alt == "" {
 		c.Header("Content-Type", "text/event-stream")
@@ -111,8 +132,8 @@ func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []by
 	// Get the http.Flusher interface to manually flush the response.
 	flusher, ok := c.Writer.(http.Flusher)
 	if !ok {
-		c.JSON(http.StatusInternalServerError, ErrorResponse{
-			Error: ErrorDetail{
+		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
+			Error: handlers.ErrorDetail{
 				Message: "Streaming not supported",
 				Type: "server_error",
 			},
@@ -120,7 +141,7 @@ func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []by
 		return
 	}

-	modelResult := gjson.GetBytes(rawJson, "model")
+	modelResult := gjson.GetBytes(rawJSON, "model")
 	modelName := modelResult.String()

 	cliCtx, cliCancel := context.WithCancel(context.Background())
@@ -135,7 +156,7 @@ func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []by
 outLoop:
 	for {
 		var errorResponse *client.ErrorMessage
-		cliClient, errorResponse = h.getClient(modelName)
+		cliClient, errorResponse = h.GetClient(modelName)
 		if errorResponse != nil {
 			c.Status(errorResponse.StatusCode)
 			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -150,7 +171,7 @@ outLoop:
 			log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
 		}
 		// Send the message and receive response chunks and errors via channels.
-		respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, "")
+		respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, "")
 		hasFirstResponse := false
 		for {
 			select {
@@ -166,20 +187,19 @@ outLoop:
 				if !okStream {
 					cliCancel()
 					return
-				} else {
-					hasFirstResponse = true
-					if cliClient.GetGenerativeLanguageAPIKey() != "" {
-						chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
-					}
-					_, _ = c.Writer.Write([]byte("data: "))
-					_, _ = c.Writer.Write(chunk)
-					_, _ = c.Writer.Write([]byte("\n\n"))
-					flusher.Flush()
 				}
+				hasFirstResponse = true
+				if cliClient.GetGenerativeLanguageAPIKey() != "" {
+					chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
+				}
+				_, _ = c.Writer.Write([]byte("data: "))
+				_, _ = c.Writer.Write(chunk)
+				_, _ = c.Writer.Write([]byte("\n\n"))
+				flusher.Flush()
 			// Handle errors from the backend.
 			case err, okError := <-errChan:
 				if okError {
-					if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
+					if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
 						continue outLoop
 					} else {
 						c.Status(err.StatusCode)
@@ -200,10 +220,10 @@ outLoop:
 	}
 }

-func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) {
+func (h *GeminiCLIAPIHandlers) internalGenerateContent(c *gin.Context, rawJSON []byte) {
 	c.Header("Content-Type", "application/json")

-	modelResult := gjson.GetBytes(rawJson, "model")
+	modelResult := gjson.GetBytes(rawJSON, "model")
 	modelName := modelResult.String()
 	cliCtx, cliCancel := context.WithCancel(context.Background())
 	var cliClient *client.Client
@@ -215,7 +235,7 @@ func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) {

 	for {
 		var errorResponse *client.ErrorMessage
-		cliClient, errorResponse = h.getClient(modelName)
+		cliClient, errorResponse = h.GetClient(modelName)
 		if errorResponse != nil {
 			c.Status(errorResponse.StatusCode)
 			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -229,9 +249,9 @@ func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) {
 			log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
 		}

-		resp, err := cliClient.SendRawMessage(cliCtx, rawJson, "")
+		resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, "")
 		if err != nil {
-			if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
+			if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
 				continue
 			} else {
 				c.Status(err.StatusCode)

@@ -1,10 +1,16 @@
-package api
+// Package gemini provides HTTP handlers for Gemini API endpoints.
+// This package implements handlers for managing Gemini model operations including
+// model listing, content generation, streaming content generation, and token counting.
+// It serves as a proxy layer between clients and the Gemini backend service,
+// handling request translation, client management, and response processing.
+package gemini

 import (
 	"context"
 	"fmt"
 	"github.com/gin-gonic/gin"
-	"github.com/luispater/CLIProxyAPI/internal/api/translator"
+	"github.com/luispater/CLIProxyAPI/internal/api/handlers"
+	"github.com/luispater/CLIProxyAPI/internal/api/translator/gemini/cli"
 	"github.com/luispater/CLIProxyAPI/internal/client"
 	log "github.com/sirupsen/logrus"
 	"github.com/tidwall/gjson"
@@ -14,7 +20,23 @@ import (
 	"time"
 )

-func (h *APIHandlers) GeminiModels(c *gin.Context) {
+// GeminiAPIHandlers contains the handlers for Gemini API endpoints.
+// It holds a pool of clients to interact with the backend service.
+type GeminiAPIHandlers struct {
+	*handlers.APIHandlers
+}
+
+// NewGeminiAPIHandlers creates a new Gemini API handlers instance.
+// It takes an APIHandlers instance as input and returns a GeminiAPIHandlers.
+func NewGeminiAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiAPIHandlers {
+	return &GeminiAPIHandlers{
+		APIHandlers: apiHandlers,
+	}
+}
+
+// GeminiModels handles the Gemini models listing endpoint.
+// It returns a JSON response containing available Gemini models and their specifications.
+func (h *GeminiAPIHandlers) GeminiModels(c *gin.Context) {
 	c.Status(http.StatusOK)
 	c.Header("Content-Type", "application/json; charset=UTF-8")
 	_, _ = c.Writer.Write([]byte(`{"models":[{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini `))
@@ -30,13 +52,15 @@ func (h *APIHandlers) GeminiModels(c *gin.Context) {
 	_, _ = c.Writer.Write([]byte(`e":2,"thinking":true}],"nextPageToken":""}`))
 }

-func (h *APIHandlers) GeminiGetHandler(c *gin.Context) {
+// GeminiGetHandler handles GET requests for specific Gemini model information.
+// It returns detailed information about a specific Gemini model based on the action parameter.
+func (h *GeminiAPIHandlers) GeminiGetHandler(c *gin.Context) {
 	var request struct {
 		Action string `uri:"action" binding:"required"`
 	}
 	if err := c.ShouldBindUri(&request); err != nil {
-		c.JSON(http.StatusBadRequest, ErrorResponse{
-			Error: ErrorDetail{
+		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
+			Error: handlers.ErrorDetail{
 				Message: fmt.Sprintf("Invalid request: %v", err),
 				Type: "invalid_request_error",
 			},
@@ -68,13 +92,15 @@ func (h *APIHandlers) GeminiGetHandler(c *gin.Context) {
 	}
 }

-func (h *APIHandlers) GeminiHandler(c *gin.Context) {
+// GeminiHandler handles POST requests for Gemini API operations.
+// It routes requests to appropriate handlers based on the action parameter (model:method format).
+func (h *GeminiAPIHandlers) GeminiHandler(c *gin.Context) {
 	var request struct {
 		Action string `uri:"action" binding:"required"`
 	}
 	if err := c.ShouldBindUri(&request); err != nil {
-		c.JSON(http.StatusBadRequest, ErrorResponse{
-			Error: ErrorDetail{
+		c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
+			Error: handlers.ErrorDetail{
 				Message: fmt.Sprintf("Invalid request: %v", err),
 				Type: "invalid_request_error",
 			},
@@ -83,8 +109,8 @@ func (h *APIHandlers) GeminiHandler(c *gin.Context) {
 	}
 	action := strings.Split(request.Action, ":")
 	if len(action) != 2 {
-		c.JSON(http.StatusNotFound, ErrorResponse{
-			Error: ErrorDetail{
+		c.JSON(http.StatusNotFound, handlers.ErrorResponse{
+			Error: handlers.ErrorDetail{
 				Message: fmt.Sprintf("%s not found.", c.Request.URL.Path),
 				Type: "invalid_request_error",
 			},
@@ -94,20 +120,20 @@ func (h *APIHandlers) GeminiHandler(c *gin.Context) {

 	modelName := action[0]
 	method := action[1]
-	rawJson, _ := c.GetRawData()
-	rawJson, _ = sjson.SetBytes(rawJson, "model", []byte(modelName))
+	rawJSON, _ := c.GetRawData()
+	rawJSON, _ = sjson.SetBytes(rawJSON, "model", []byte(modelName))

 	if method == "generateContent" {
-		h.geminiGenerateContent(c, rawJson)
+		h.geminiGenerateContent(c, rawJSON)
 	} else if method == "streamGenerateContent" {
-		h.geminiStreamGenerateContent(c, rawJson)
+		h.geminiStreamGenerateContent(c, rawJSON)
 	} else if method == "countTokens" {
-		h.geminiCountTokens(c, rawJson)
+		h.geminiCountTokens(c, rawJSON)
 	}
 }

-func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte) {
-	alt := h.getAlt(c)
+func (h *GeminiAPIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJSON []byte) {
+	alt := h.GetAlt(c)

 	if alt == "" {
 		c.Header("Content-Type", "text/event-stream")
@@ -119,8 +145,8 @@ func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte
 	// Get the http.Flusher interface to manually flush the response.
 	flusher, ok := c.Writer.(http.Flusher)
 	if !ok {
-		c.JSON(http.StatusInternalServerError, ErrorResponse{
-			Error: ErrorDetail{
+		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
+			Error: handlers.ErrorDetail{
 				Message: "Streaming not supported",
 				Type: "server_error",
 			},
@@ -128,7 +154,7 @@ func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte
 		return
 	}

-	modelResult := gjson.GetBytes(rawJson, "model")
+	modelResult := gjson.GetBytes(rawJSON, "model")
 	modelName := modelResult.String()

 	cliCtx, cliCancel := context.WithCancel(context.Background())
@@ -143,7 +169,7 @@ func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte
 outLoop:
 	for {
 		var errorResponse *client.ErrorMessage
-		cliClient, errorResponse = h.getClient(modelName)
+		cliClient, errorResponse = h.GetClient(modelName)
 		if errorResponse != nil {
 			c.Status(errorResponse.StatusCode)
 			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -153,21 +179,21 @@ outLoop:
 		}

 		template := ""
-		parsed := gjson.Parse(string(rawJson))
+		parsed := gjson.Parse(string(rawJSON))
 		contents := parsed.Get("request.contents")
 		if contents.Exists() {
-			template = string(rawJson)
+			template = string(rawJSON)
 		} else {
 			template = `{"project":"","request":{},"model":""}`
-			template, _ = sjson.SetRaw(template, "request", string(rawJson))
+			template, _ = sjson.SetRaw(template, "request", string(rawJSON))
 			template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
 			template, _ = sjson.Delete(template, "request.model")
 		}

-		template, errFixCLIToolResponse := translator.FixCLIToolResponse(template)
+		template, errFixCLIToolResponse := cli.FixCLIToolResponse(template)
 		if errFixCLIToolResponse != nil {
-			c.JSON(http.StatusInternalServerError, ErrorResponse{
-				Error: ErrorDetail{
+			c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
+				Error: handlers.ErrorDetail{
 					Message: errFixCLIToolResponse.Error(),
 					Type: "server_error",
 				},
@@ -181,7 +207,7 @@ outLoop:
 			template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
 			template, _ = sjson.Delete(template, "request.system_instruction")
 		}
-		rawJson = []byte(template)
+		rawJSON = []byte(template)

 		if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
 			log.Debugf("Request use generative language API Key: %s", glAPIKey)
@@ -190,7 +216,7 @@ outLoop:
 		}

 		// Send the message and receive response chunks and errors via channels.
-		respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, alt)
+		respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, alt)
 		for {
 			select {
 			// Handle client disconnection.
@@ -205,41 +231,40 @@ outLoop:
 				if !okStream {
 					cliCancel()
 					return
-				} else {
+				}
 				if cliClient.GetGenerativeLanguageAPIKey() == "" {
 					if alt == "" {
 						responseResult := gjson.GetBytes(chunk, "response")
 						if responseResult.Exists() {
 							chunk = []byte(responseResult.Raw)
 						}
 					} else {
 						chunkTemplate := "[]"
 						responseResult := gjson.ParseBytes(chunk)
 						if responseResult.IsArray() {
 							responseResultItems := responseResult.Array()
 							for i := 0; i < len(responseResultItems); i++ {
 								responseResultItem := responseResultItems[i]
 								if responseResultItem.Get("response").Exists() {
 									chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
-									}
 								}
 							}
-							chunk = []byte(chunkTemplate)
 						}
+						chunk = []byte(chunkTemplate)
 					}
-					if alt == "" {
-						_, _ = c.Writer.Write([]byte("data: "))
-						_, _ = c.Writer.Write(chunk)
-						_, _ = c.Writer.Write([]byte("\n\n"))
-					} else {
-						_, _ = c.Writer.Write(chunk)
-					}
-					flusher.Flush()
 				}
+				if alt == "" {
+					_, _ = c.Writer.Write([]byte("data: "))
+					_, _ = c.Writer.Write(chunk)
+					_, _ = c.Writer.Write([]byte("\n\n"))
+				} else {
+					_, _ = c.Writer.Write(chunk)
+				}
+				flusher.Flush()
 			// Handle errors from the backend.
 			case err, okError := <-errChan:
 				if okError {
-					if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
+					if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
 						log.Debugf("quota exceeded, switch client")
 						continue outLoop
 					} else {
@@ -258,12 +283,12 @@ outLoop:
 	}
 }

-func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {
+func (h *GeminiAPIHandlers) geminiCountTokens(c *gin.Context, rawJSON []byte) {
 	c.Header("Content-Type", "application/json")

-	alt := h.getAlt(c)
-	// orgRawJson := rawJson
-	modelResult := gjson.GetBytes(rawJson, "model")
+	alt := h.GetAlt(c)
+	// orgrawJSON := rawJSON
+	modelResult := gjson.GetBytes(rawJSON, "model")
 	modelName := modelResult.String()
 	cliCtx, cliCancel := context.WithCancel(context.Background())
 	var cliClient *client.Client
@@ -275,7 +300,7 @@ func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {

 	for {
 		var errorResponse *client.ErrorMessage
-		cliClient, errorResponse = h.getClient(modelName, false)
+		cliClient, errorResponse = h.GetClient(modelName, false)
 		if errorResponse != nil {
 			c.Status(errorResponse.StatusCode)
 			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -289,27 +314,27 @@ func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {
 			log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())

 			template := `{"request":{}}`
-			if gjson.GetBytes(rawJson, "generateContentRequest").Exists() {
-				template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJson, "generateContentRequest").Raw)
+			if gjson.GetBytes(rawJSON, "generateContentRequest").Exists() {
+				template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJSON, "generateContentRequest").Raw)
 				template, _ = sjson.Delete(template, "generateContentRequest")
-			} else if gjson.GetBytes(rawJson, "contents").Exists() {
-				template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJson, "contents").Raw)
+			} else if gjson.GetBytes(rawJSON, "contents").Exists() {
+				template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJSON, "contents").Raw)
 				template, _ = sjson.Delete(template, "contents")
 			}
-			rawJson = []byte(template)
+			rawJSON = []byte(template)
 		}

-		resp, err := cliClient.SendRawTokenCount(cliCtx, rawJson, alt)
+		resp, err := cliClient.SendRawTokenCount(cliCtx, rawJSON, alt)
 		if err != nil {
-			if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
+			if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
 				continue
 			} else {
 				c.Status(err.StatusCode)
 				_, _ = c.Writer.Write([]byte(err.Error.Error()))
 				cliCancel()
 				// log.Debugf(err.Error.Error())
-				// log.Debugf(string(rawJson))
-				// log.Debugf(string(orgRawJson))
+				// log.Debugf(string(rawJSON))
+				// log.Debugf(string(orgrawJSON))
 			}
 			break
 		} else {
@@ -326,12 +351,12 @@ func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {
 	}
 }

-func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
+func (h *GeminiAPIHandlers) geminiGenerateContent(c *gin.Context, rawJSON []byte) {
 	c.Header("Content-Type", "application/json")

-	alt := h.getAlt(c)
+	alt := h.GetAlt(c)

-	modelResult := gjson.GetBytes(rawJson, "model")
+	modelResult := gjson.GetBytes(rawJSON, "model")
 	modelName := modelResult.String()
 	cliCtx, cliCancel := context.WithCancel(context.Background())
 	var cliClient *client.Client
@@ -343,7 +368,7 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {

 	for {
 		var errorResponse *client.ErrorMessage
-		cliClient, errorResponse = h.getClient(modelName)
+		cliClient, errorResponse = h.GetClient(modelName)
 		if errorResponse != nil {
 			c.Status(errorResponse.StatusCode)
 			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -352,21 +377,21 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
 		}

 		template := ""
-		parsed := gjson.Parse(string(rawJson))
+		parsed := gjson.Parse(string(rawJSON))
 		contents := parsed.Get("request.contents")
 		if contents.Exists() {
-			template = string(rawJson)
+			template = string(rawJSON)
 		} else {
 			template = `{"project":"","request":{},"model":""}`
-			template, _ = sjson.SetRaw(template, "request", string(rawJson))
+			template, _ = sjson.SetRaw(template, "request", string(rawJSON))
 			template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
 			template, _ = sjson.Delete(template, "request.model")
 		}

-		template, errFixCLIToolResponse := translator.FixCLIToolResponse(template)
+		template, errFixCLIToolResponse := cli.FixCLIToolResponse(template)
 		if errFixCLIToolResponse != nil {
-			c.JSON(http.StatusInternalServerError, ErrorResponse{
-				Error: ErrorDetail{
+			c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
+				Error: handlers.ErrorDetail{
 					Message: errFixCLIToolResponse.Error(),
 					Type: "server_error",
 				},
@@ -380,16 +405,16 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
 			template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
 			template, _ = sjson.Delete(template, "request.system_instruction")
 		}
-		rawJson = []byte(template)
+		rawJSON = []byte(template)

 		if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
 			log.Debugf("Request use generative language API Key: %s", glAPIKey)
 		} else {
 			log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
 		}
-		resp, err := cliClient.SendRawMessage(cliCtx, rawJson, alt)
+		resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, alt)
 		if err != nil {
-			if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
+			if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
 				continue
 			} else {
 				c.Status(err.StatusCode)
@@ -410,16 +435,3 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
 		}
 	}
 }
-
-func (h *APIHandlers) getAlt(c *gin.Context) string {
-	var alt string
-	var hasAlt bool
-	alt, hasAlt = c.GetQuery("alt")
-	if !hasAlt {
-		alt, _ = c.GetQuery("$alt")
-	}
-	if alt == "sse" {
-		return ""
-	}
-	return alt
-}

internal/api/handlers/handlers.go (new file, 122 lines)
@@ -0,0 +1,122 @@
+// Package handlers provides core API handler functionality for the CLI Proxy API server.
+// It includes common types, client management, load balancing, and error handling
+// shared across all API endpoint handlers (OpenAI, Claude, Gemini).
+package handlers
+
+import (
+	"fmt"
+	"github.com/gin-gonic/gin"
+	"github.com/luispater/CLIProxyAPI/internal/client"
+	"github.com/luispater/CLIProxyAPI/internal/config"
+	log "github.com/sirupsen/logrus"
+	"sync"
+)
+
+// ErrorResponse represents a standard error response format for the API.
+// It contains a single ErrorDetail field.
+type ErrorResponse struct {
+	Error ErrorDetail `json:"error"`
+}
+
+// ErrorDetail provides specific information about an error that occurred.
+// It includes a human-readable message, an error type, and an optional error code.
+type ErrorDetail struct {
+	// A human-readable message providing more details about the error.
+	Message string `json:"message"`
+	// The type of error that occurred (e.g., "invalid_request_error").
+	Type string `json:"type"`
+	// A short code identifying the error, if applicable.
+	Code string `json:"code,omitempty"`
+}
+
+// APIHandlers contains the handlers for API endpoints.
+// It holds a pool of clients to interact with the backend service.
+type APIHandlers struct {
+	CliClients []*client.Client
+	Cfg *config.Config
+	Mutex *sync.Mutex
+	LastUsedClientIndex int
+}
+
+// NewAPIHandlers creates a new API handlers instance.
+// It takes a slice of clients and a debug flag as input.
+func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers {
+	return &APIHandlers{
+		CliClients: cliClients,
+		Cfg: cfg,
+		Mutex: &sync.Mutex{},
+		LastUsedClientIndex: 0,
+	}
+}
+
+// UpdateClients updates the handlers' client list and configuration
+func (h *APIHandlers) UpdateClients(clients []*client.Client, cfg *config.Config) {
+	h.CliClients = clients
+	h.Cfg = cfg
+}
+
+// GetClient returns an available client from the pool using round-robin load balancing.
+// It checks for quota limits and tries to find an unlocked client for immediate use.
+// The modelName parameter is used to check quota status for specific models.
+func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) {
+	if len(h.CliClients) == 0 {
+		return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")}
+	}
+
+	var cliClient *client.Client
+
+	// Lock the mutex to update the last used client index
+	h.Mutex.Lock()
+	startIndex := h.LastUsedClientIndex
+	if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 {
+		currentIndex := (startIndex + 1) % len(h.CliClients)
+		h.LastUsedClientIndex = currentIndex
+	}
+	h.Mutex.Unlock()
+
+	// Reorder the client to start from the last used index
+	reorderedClients := make([]*client.Client, 0)
+	for i := 0; i < len(h.CliClients); i++ {
+		cliClient = h.CliClients[(startIndex+1+i)%len(h.CliClients)]
+		if cliClient.IsModelQuotaExceeded(modelName) {
+			log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
+			cliClient = nil
+			continue
+		}
+		reorderedClients = append(reorderedClients, cliClient)
+	}
+
+	if len(reorderedClients) == 0 {
+		return nil, &client.ErrorMessage{StatusCode: 429, Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)}
+	}
+
+	locked := false
+	for i := 0; i < len(reorderedClients); i++ {
+		cliClient = reorderedClients[i]
+		if cliClient.RequestMutex.TryLock() {
+			locked = true
+			break
+		}
+	}
+	if !locked {
+		cliClient = h.CliClients[0]
+		cliClient.RequestMutex.Lock()
+	}
+
+	return cliClient, nil
+}
+
+// GetAlt extracts the 'alt' parameter from the request query string.
+// It checks both 'alt' and '$alt' parameters and returns the appropriate value.
+func (h *APIHandlers) GetAlt(c *gin.Context) string {
+	var alt string
+	var hasAlt bool
+	alt, hasAlt = c.GetQuery("alt")
+	if !hasAlt {
+		alt, _ = c.GetQuery("$alt")
+	}
+	if alt == "sse" {
+		return ""
+	}
+	return alt
+}

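With `handlers.APIHandlers` now owning the client pool, round-robin index, and mutex shown above, the per-protocol handlers can all be built from one shared instance. The actual route registration is not part of this excerpt; the sketch below only illustrates how the constructors introduced in this commit fit together. The sub-package import paths and URL paths marked below are assumptions, not taken from the diff (the CLI handlers wire up the same way via `NewGeminiCLIAPIHandlers`):

```go
// Hypothetical wiring example; only handlers.NewAPIHandlers and the New*APIHandlers
// constructors are taken from this commit.
package server

import (
	"github.com/gin-gonic/gin"

	"github.com/luispater/CLIProxyAPI/internal/api/handlers"
	// The following handler package paths are assumed for illustration only.
	"github.com/luispater/CLIProxyAPI/internal/api/handlers/claude"
	"github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini"
	"github.com/luispater/CLIProxyAPI/internal/api/handlers/openai"
	"github.com/luispater/CLIProxyAPI/internal/client"
	"github.com/luispater/CLIProxyAPI/internal/config"
)

// NewRouter builds one shared APIHandlers and hands it to each protocol-specific
// wrapper, so they all draw from the same client pool and round-robin state.
func NewRouter(cliClients []*client.Client, cfg *config.Config) *gin.Engine {
	base := handlers.NewAPIHandlers(cliClients, cfg)

	openaiHandlers := openai.NewOpenAIAPIHandlers(base)
	claudeHandlers := claude.NewClaudeCodeAPIHandlers(base)
	geminiHandlers := gemini.NewGeminiAPIHandlers(base)

	r := gin.Default()
	r.GET("/v1/models", openaiHandlers.Models)                     // endpoint named in the diff
	r.POST("/v1/chat/completions", openaiHandlers.ChatCompletions) // endpoint named in the diff
	r.POST("/v1/messages", claudeHandlers.ClaudeMessages)          // URL path assumed
	r.GET("/v1beta/models", geminiHandlers.GeminiModels)           // URL path assumed
	return r
}
```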
@@ -1,50 +1,42 @@
|
|||||||
package api
|
// Package openai provides HTTP handlers for OpenAI API endpoints.
|
||||||
|
// This package implements the OpenAI-compatible API interface, including model listing
|
||||||
|
// and chat completion functionality. It supports both streaming and non-streaming responses,
|
||||||
|
// and manages a pool of clients to interact with backend services.
|
||||||
|
// The handlers translate OpenAI API requests to the appropriate backend format and
|
||||||
|
// convert responses back to OpenAI-compatible format.
|
||||||
|
package openai
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/api/translator/openai"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
// OpenAIAPIHandlers contains the handlers for OpenAI API endpoints.
|
||||||
mutex = &sync.Mutex{}
|
|
||||||
lastUsedClientIndex = 0
|
|
||||||
)
|
|
||||||
|
|
||||||
// APIHandlers contains the handlers for API endpoints.
|
|
||||||
// It holds a pool of clients to interact with the backend service.
|
// It holds a pool of clients to interact with the backend service.
|
||||||
type APIHandlers struct {
|
type OpenAIAPIHandlers struct {
|
||||||
cliClients []*client.Client
|
*handlers.APIHandlers
|
||||||
cfg *config.Config
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAPIHandlers creates a new API handlers instance.
|
// NewOpenAIAPIHandlers creates a new OpenAI API handlers instance.
|
||||||
// It takes a slice of clients and a debug flag as input.
|
// It takes an APIHandlers instance as input and returns an OpenAIAPIHandlers.
|
||||||
func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers {
|
func NewOpenAIAPIHandlers(apiHandlers *handlers.APIHandlers) *OpenAIAPIHandlers {
|
||||||
return &APIHandlers{
|
return &OpenAIAPIHandlers{
|
||||||
cliClients: cliClients,
|
APIHandlers: apiHandlers,
|
||||||
cfg: cfg,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||

-// UpdateClients updates the handlers' client list and configuration
-func (h *APIHandlers) UpdateClients(clients []*client.Client, cfg *config.Config) {
-    h.cliClients = clients
-    h.cfg = cfg
-}

 // Models handles the /v1/models endpoint.
 // It returns a hardcoded list of available AI models.
-func (h *APIHandlers) Models(c *gin.Context) {
+func (h *OpenAIAPIHandlers) Models(c *gin.Context) {
     c.JSON(http.StatusOK, gin.H{
         "data": []map[string]any{
             {
@@ -91,63 +83,15 @@ func (h *APIHandlers) Models(c *gin.Context) {
     })
 }

-func (h *APIHandlers) getClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) {
-    if len(h.cliClients) == 0 {
-        return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")}
-    }
-
-    var cliClient *client.Client
-
-    // Lock the mutex to update the last used client index
-    mutex.Lock()
-    startIndex := lastUsedClientIndex
-    if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 {
-        currentIndex := (startIndex + 1) % len(h.cliClients)
-        lastUsedClientIndex = currentIndex
-    }
-    mutex.Unlock()
-
-    // Reorder the client to start from the last used index
-    reorderedClients := make([]*client.Client, 0)
-    for i := 0; i < len(h.cliClients); i++ {
-        cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
-        if cliClient.IsModelQuotaExceeded(modelName) {
-            log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
-            cliClient = nil
-            continue
-        }
-        reorderedClients = append(reorderedClients, cliClient)
-    }
-
-    if len(reorderedClients) == 0 {
-        return nil, &client.ErrorMessage{StatusCode: 429, Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)}
-    }
-
-    locked := false
-    for i := 0; i < len(reorderedClients); i++ {
-        cliClient = reorderedClients[i]
-        if cliClient.RequestMutex.TryLock() {
-            locked = true
-            break
-        }
-    }
-    if !locked {
-        cliClient = h.cliClients[0]
-        cliClient.RequestMutex.Lock()
-    }
-
-    return cliClient, nil
-}
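The getClient method removed above (now exposed as GetClient on the shared handlers.APIHandlers, per the commit notes) rotates through the client pool, skips accounts whose quota is exhausted for the requested model, and prefers a client whose request mutex can be taken without blocking. A compressed sketch of that policy, using simplified placeholder types rather than the project's client.Client:

    // Illustrative only: the selection policy of the removed getClient.
    package example

    import "sync"

    type fakeClient struct {
        quotaExceeded bool
        requestMu     sync.Mutex
    }

    type pool struct {
        mu       sync.Mutex
        lastUsed int
        clients  []*fakeClient
    }

    // pick advances the round-robin cursor, filters out quota-exceeded clients,
    // and tries to grab a free request mutex before falling back to blocking.
    func (p *pool) pick() *fakeClient {
        p.mu.Lock()
        start := p.lastUsed
        p.lastUsed = (start + 1) % len(p.clients)
        p.mu.Unlock()

        candidates := make([]*fakeClient, 0, len(p.clients))
        for i := 0; i < len(p.clients); i++ {
            c := p.clients[(start+1+i)%len(p.clients)]
            if c.quotaExceeded {
                continue // the real code logs the account and project id here
            }
            candidates = append(candidates, c)
        }
        if len(candidates) == 0 {
            return nil // the real handler returns a 429 RESOURCE_EXHAUSTED error here
        }
        for _, c := range candidates {
            if c.requestMu.TryLock() {
                return c
            }
        }
        c := p.clients[0]
        c.requestMu.Lock()
        return c
    }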

 // ChatCompletions handles the /v1/chat/completions endpoint.
 // It determines whether the request is for a streaming or non-streaming response
 // and calls the appropriate handler.
-func (h *APIHandlers) ChatCompletions(c *gin.Context) {
+func (h *OpenAIAPIHandlers) ChatCompletions(c *gin.Context) {
-    rawJson, err := c.GetRawData()
+    rawJSON, err := c.GetRawData()
     // If data retrieval fails, return a 400 Bad Request error.
     if err != nil {
-        c.JSON(http.StatusBadRequest, ErrorResponse{
+        c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
-            Error: ErrorDetail{
+            Error: handlers.ErrorDetail{
                 Message: fmt.Sprintf("Invalid request: %v", err),
                 Type: "invalid_request_error",
             },
@@ -156,21 +100,21 @@ func (h *APIHandlers) ChatCompletions(c *gin.Context) {
     }

     // Check if the client requested a streaming response.
-    streamResult := gjson.GetBytes(rawJson, "stream")
+    streamResult := gjson.GetBytes(rawJSON, "stream")
     if streamResult.Type == gjson.True {
-        h.handleStreamingResponse(c, rawJson)
+        h.handleStreamingResponse(c, rawJSON)
     } else {
-        h.handleNonStreamingResponse(c, rawJson)
+        h.handleNonStreamingResponse(c, rawJSON)
     }
 }

 // handleNonStreamingResponse handles non-streaming chat completion responses.
 // It selects a client from the pool, sends the request, and aggregates the response
 // before sending it back to the client.
-func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) {
+func (h *OpenAIAPIHandlers) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) {
     c.Header("Content-Type", "application/json")

-    modelName, systemInstruction, contents, tools := translator.PrepareRequest(rawJson)
+    modelName, systemInstruction, contents, tools := openai.PrepareRequest(rawJSON)
     cliCtx, cliCancel := context.WithCancel(context.Background())
     var cliClient *client.Client
     defer func() {
@@ -181,7 +125,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)

     for {
         var errorResponse *client.ErrorMessage
-        cliClient, errorResponse = h.getClient(modelName)
+        cliClient, errorResponse = h.GetClient(modelName)
         if errorResponse != nil {
             c.Status(errorResponse.StatusCode)
             _, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -197,9 +141,9 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
             log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
         }

-        resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, systemInstruction, contents, tools)
+        resp, err := cliClient.SendMessage(cliCtx, rawJSON, modelName, systemInstruction, contents, tools)
         if err != nil {
-            if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
+            if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
                 continue
             } else {
                 c.Status(err.StatusCode)
@@ -208,7 +152,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
             }
             break
         } else {
-            openAIFormat := translator.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey)
+            openAIFormat := openai.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey)
             if openAIFormat != "" {
                 _, _ = c.Writer.Write([]byte(openAIFormat))
             }
@@ -219,7 +163,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
 }

 // handleStreamingResponse handles streaming responses
-func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
+func (h *OpenAIAPIHandlers) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
     c.Header("Content-Type", "text/event-stream")
     c.Header("Cache-Control", "no-cache")
     c.Header("Connection", "keep-alive")
@@ -228,8 +172,8 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
     // Get the http.Flusher interface to manually flush the response.
     flusher, ok := c.Writer.(http.Flusher)
     if !ok {
-        c.JSON(http.StatusInternalServerError, ErrorResponse{
+        c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
-            Error: ErrorDetail{
+            Error: handlers.ErrorDetail{
                 Message: "Streaming not supported",
                 Type: "server_error",
             },
@@ -238,7 +182,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
     }

     // Prepare the request for the backend client.
-    modelName, systemInstruction, contents, tools := translator.PrepareRequest(rawJson)
+    modelName, systemInstruction, contents, tools := openai.PrepareRequest(rawJSON)
     cliCtx, cliCancel := context.WithCancel(context.Background())
     var cliClient *client.Client
     defer func() {
@@ -251,7 +195,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
 outLoop:
     for {
         var errorResponse *client.ErrorMessage
-        cliClient, errorResponse = h.getClient(modelName)
+        cliClient, errorResponse = h.GetClient(modelName)
         if errorResponse != nil {
             c.Status(errorResponse.StatusCode)
             _, _ = fmt.Fprint(c.Writer, errorResponse.Error)
@@ -268,7 +212,7 @@ outLoop:
             log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
         }
         // Send the message and receive response chunks and errors via channels.
-        respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, systemInstruction, contents, tools)
+        respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools)
        hasFirstResponse := false
        for {
            select {
@@ -287,19 +231,18 @@ outLoop:
                    flusher.Flush()
                    cliCancel()
                    return
-                } else {
+                }
                // Convert the chunk to OpenAI format and send it to the client.
                hasFirstResponse = true
-                openAIFormat := translator.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey)
+                openAIFormat := openai.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey)
                if openAIFormat != "" {
                    _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
                    flusher.Flush()
-                }
            }
            // Handle errors from the backend.
            case err, okError := <-errChan:
                if okError {
-                    if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
+                    if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
                        continue outLoop
                    } else {
                        c.Status(err.StatusCode)
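The streaming branch writes each translated chunk as a "data: ..." SSE frame followed by a blank line. A minimal client-side sketch of consuming that stream; the host and port are placeholders and the API key header is omitted, so adjust both for a real deployment:

    // Illustrative consumer for the SSE stream emitted by handleStreamingResponse.
    package main

    import (
        "bufio"
        "fmt"
        "net/http"
        "strings"
    )

    func main() {
        body := strings.NewReader(`{"model":"gemini-2.5-pro","stream":true,"messages":[{"role":"user","content":"hi"}]}`)
        req, _ := http.NewRequest(http.MethodPost, "http://localhost:8080/v1/chat/completions", body)
        req.Header.Set("Content-Type", "application/json")
        resp, err := http.DefaultClient.Do(req)
        if err != nil {
            panic(err)
        }
        defer resp.Body.Close()

        scanner := bufio.NewScanner(resp.Body)
        for scanner.Scan() {
            line := scanner.Text()
            if strings.HasPrefix(line, "data: ") {
                fmt.Println(strings.TrimPrefix(line, "data: ")) // one chat.completion.chunk JSON per frame
            }
        }
    }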
@@ -1,18 +0,0 @@
-package api
-
-// ErrorResponse represents a standard error response format for the API.
-// It contains a single ErrorDetail field.
-type ErrorResponse struct {
-    Error ErrorDetail `json:"error"`
-}
-
-// ErrorDetail provides specific information about an error that occurred.
-// It includes a human-readable message, an error type, and an optional error code.
-type ErrorDetail struct {
-    // A human-readable message providing more details about the error.
-    Message string `json:"message"`
-    // The type of error that occurred (e.g., "invalid_request_error").
-    Type string `json:"type"`
-    // A short code identifying the error, if applicable.
-    Code string `json:"code,omitempty"`
-}
@@ -1,3 +1,7 @@
+// Package api provides the HTTP API server implementation for the CLI Proxy API.
+// It includes the main server struct, routing setup, middleware for CORS and authentication,
+// and integration with various AI API handlers (OpenAI, Claude, Gemini).
+// The server supports hot-reloading of clients and configuration.
 package api

 import (
@@ -5,6 +9,11 @@ import (
     "errors"
     "fmt"
     "github.com/gin-gonic/gin"
+    "github.com/luispater/CLIProxyAPI/internal/api/handlers"
+    "github.com/luispater/CLIProxyAPI/internal/api/handlers/claude"
+    "github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini"
+    "github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini/cli"
+    "github.com/luispater/CLIProxyAPI/internal/api/handlers/openai"
     "github.com/luispater/CLIProxyAPI/internal/client"
     "github.com/luispater/CLIProxyAPI/internal/config"
     log "github.com/sirupsen/logrus"
@@ -17,7 +26,7 @@ import (
 type Server struct {
     engine *gin.Engine
     server *http.Server
-    handlers *APIHandlers
+    handlers *handlers.APIHandlers
     cfg *config.Config
 }

@@ -29,9 +38,6 @@ func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
         gin.SetMode(gin.ReleaseMode)
     }

-    // Create handlers
-    handlers := NewAPIHandlers(cliClients, cfg)
-
     // Create gin engine
     engine := gin.New()

@@ -43,7 +49,7 @@ func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
     // Create server instance
     s := &Server{
         engine: engine,
-        handlers: handlers,
+        handlers: handlers.NewAPIHandlers(cliClients, cfg),
         cfg: cfg,
     }

@@ -62,22 +68,27 @@ func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
 // setupRoutes configures the API routes for the server.
 // It defines the endpoints and associates them with their respective handlers.
 func (s *Server) setupRoutes() {
+    openaiHandlers := openai.NewOpenAIAPIHandlers(s.handlers)
+    geminiHandlers := gemini.NewGeminiAPIHandlers(s.handlers)
+    geminiCLIHandlers := cli.NewGeminiCLIAPIHandlers(s.handlers)
+    claudeCodeHandlers := claude.NewClaudeCodeAPIHandlers(s.handlers)
+
     // OpenAI compatible API routes
     v1 := s.engine.Group("/v1")
     v1.Use(AuthMiddleware(s.cfg))
     {
-        v1.GET("/models", s.handlers.Models)
+        v1.GET("/models", openaiHandlers.Models)
-        v1.POST("/chat/completions", s.handlers.ChatCompletions)
+        v1.POST("/chat/completions", openaiHandlers.ChatCompletions)
-        v1.POST("/messages", s.handlers.ClaudeMessages)
+        v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
     }

     // Gemini compatible API routes
     v1beta := s.engine.Group("/v1beta")
     v1beta.Use(AuthMiddleware(s.cfg))
     {
-        v1beta.GET("/models", s.handlers.GeminiModels)
+        v1beta.GET("/models", geminiHandlers.GeminiModels)
-        v1beta.POST("/models/:action", s.handlers.GeminiHandler)
+        v1beta.POST("/models/:action", geminiHandlers.GeminiHandler)
-        v1beta.GET("/models/:action", s.handlers.GeminiGetHandler)
+        v1beta.GET("/models/:action", geminiHandlers.GeminiGetHandler)
     }

     // Root endpoint
@@ -91,7 +102,7 @@ func (s *Server) setupRoutes() {
             },
         })
     })
-    s.engine.POST("/v1internal:method", s.handlers.CLIHandler)
+    s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)

 }

@@ -150,7 +161,7 @@ func (s *Server) UpdateClients(clients []*client.Client, cfg *config.Config) {
 // using API keys. If no API keys are configured, it allows all requests.
 func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
     return func(c *gin.Context) {
-        if len(cfg.ApiKeys) == 0 {
+        if len(cfg.APIKeys) == 0 {
             c.Next()
             return
         }
@@ -181,9 +192,9 @@ func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {

         // Find the API key in the in-memory list
         var foundKey string
-        for i := range cfg.ApiKeys {
+        for i := range cfg.APIKeys {
-            if cfg.ApiKeys[i] == apiKey || cfg.ApiKeys[i] == authHeaderGoogle || cfg.ApiKeys[i] == authHeaderAnthropic || cfg.ApiKeys[i] == apiKeyQuery {
+            if cfg.APIKeys[i] == apiKey || cfg.APIKeys[i] == authHeaderGoogle || cfg.APIKeys[i] == authHeaderAnthropic || cfg.APIKeys[i] == apiKeyQuery {
-                foundKey = cfg.ApiKeys[i]
+                foundKey = cfg.APIKeys[i]
                 break
             }
         }
internal/api/translator/claude/code/request.go  (new file, 169 lines)
@@ -0,0 +1,169 @@
// Package code provides request translation functionality for Claude API.
// It handles parsing and transforming Claude API requests into the internal client format,
// extracting model information, system instructions, message contents, and tool declarations.
// The package also performs JSON data cleaning and transformation to ensure compatibility
// between Claude API format and the internal client's expected format.
package code

import (
    "bytes"
    "encoding/json"
    "github.com/luispater/CLIProxyAPI/internal/client"
    "github.com/tidwall/gjson"
    "github.com/tidwall/sjson"
    "strings"
)

// PrepareClaudeRequest parses and transforms a Claude API request into internal client format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the internal client.
func PrepareClaudeRequest(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
    var pathsToDelete []string
    root := gjson.ParseBytes(rawJSON)
    walk(root, "", "additionalProperties", &pathsToDelete)
    walk(root, "", "$schema", &pathsToDelete)

    var err error
    for _, p := range pathsToDelete {
        rawJSON, err = sjson.DeleteBytes(rawJSON, p)
        if err != nil {
            continue
        }
    }
    rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)

    // log.Debug(string(rawJSON))
    modelName := "gemini-2.5-pro"
    modelResult := gjson.GetBytes(rawJSON, "model")
    if modelResult.Type == gjson.String {
        modelName = modelResult.String()
    }

    contents := make([]client.Content, 0)

    var systemInstruction *client.Content

    systemResult := gjson.GetBytes(rawJSON, "system")
    if systemResult.IsArray() {
        systemResults := systemResult.Array()
        systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
        for i := 0; i < len(systemResults); i++ {
            systemPromptResult := systemResults[i]
            systemTypePromptResult := systemPromptResult.Get("type")
            if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
                systemPrompt := systemPromptResult.Get("text").String()
                systemPart := client.Part{Text: systemPrompt}
                systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
            }
        }
        if len(systemInstruction.Parts) == 0 {
            systemInstruction = nil
        }
    }

    messagesResult := gjson.GetBytes(rawJSON, "messages")
    if messagesResult.IsArray() {
        messageResults := messagesResult.Array()
        for i := 0; i < len(messageResults); i++ {
            messageResult := messageResults[i]
            roleResult := messageResult.Get("role")
            if roleResult.Type != gjson.String {
                continue
            }
            role := roleResult.String()
            if role == "assistant" {
                role = "model"
            }
            clientContent := client.Content{Role: role, Parts: []client.Part{}}

            contentsResult := messageResult.Get("content")
            if contentsResult.IsArray() {
                contentResults := contentsResult.Array()
                for j := 0; j < len(contentResults); j++ {
                    contentResult := contentResults[j]
                    contentTypeResult := contentResult.Get("type")
                    if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
                        prompt := contentResult.Get("text").String()
                        clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
                    } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
                        functionName := contentResult.Get("name").String()
                        functionArgs := contentResult.Get("input").String()
                        var args map[string]any
                        if err = json.Unmarshal([]byte(functionArgs), &args); err == nil {
                            clientContent.Parts = append(clientContent.Parts, client.Part{
                                FunctionCall: &client.FunctionCall{
                                    Name: functionName,
                                    Args: args,
                                },
                            })
                        }
                    } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
                        toolCallID := contentResult.Get("tool_use_id").String()
                        if toolCallID != "" {
                            funcName := toolCallID
                            toolCallIDs := strings.Split(toolCallID, "-")
                            if len(toolCallIDs) > 1 {
                                funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
                            }
                            responseData := contentResult.Get("content").String()
                            functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
                            clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
                        }
                    }
                }
                contents = append(contents, clientContent)
            } else if contentsResult.Type == gjson.String {
                prompt := contentsResult.String()
                contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
            }
        }
    }

    var tools []client.ToolDeclaration
    toolsResult := gjson.GetBytes(rawJSON, "tools")
    if toolsResult.IsArray() {
        tools = make([]client.ToolDeclaration, 1)
        tools[0].FunctionDeclarations = make([]any, 0)
        toolsResults := toolsResult.Array()
        for i := 0; i < len(toolsResults); i++ {
            toolResult := toolsResults[i]
            inputSchemaResult := toolResult.Get("input_schema")
            if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
                inputSchema := inputSchemaResult.Raw
                inputSchema, _ = sjson.Delete(inputSchema, "additionalProperties")
                inputSchema, _ = sjson.Delete(inputSchema, "$schema")

                tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
                tool, _ = sjson.SetRaw(tool, "parameters", inputSchema)
                var toolDeclaration any
                if err = json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
                    tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
                }
            }
        }
    } else {
        tools = make([]client.ToolDeclaration, 0)
    }

    return modelName, systemInstruction, contents, tools
}

func walk(value gjson.Result, path, field string, pathsToDelete *[]string) {
    switch value.Type {
    case gjson.JSON:
        value.ForEach(func(key, val gjson.Result) bool {
            var childPath string
            if path == "" {
                childPath = key.String()
            } else {
                childPath = path + "." + key.String()
            }
            if key.String() == field {
                *pathsToDelete = append(*pathsToDelete, childPath)
            }
            walk(val, childPath, field, pathsToDelete)
            return true
        })
    case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null:
    }
}
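A hedged usage sketch of PrepareClaudeRequest with a minimal Claude-style payload (the request body below is invented for illustration and is not part of the commit):

    package main

    import (
        "fmt"

        code "github.com/luispater/CLIProxyAPI/internal/api/translator/claude/code"
    )

    func main() {
        raw := []byte(`{
            "model": "claude-3-5-sonnet-20241022",
            "system": [{"type": "text", "text": "You are terse."}],
            "messages": [{"role": "user", "content": "List three Go keywords."}]
        }`)
        model, system, contents, tools := code.PrepareClaudeRequest(raw)
        fmt.Println(model, system != nil, len(contents), len(tools))
        // The string-valued "content" becomes a single user Content with one text Part,
        // the system array becomes the separate system instruction, and with no "tools"
        // key the returned tool slice is empty.
    }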
internal/api/translator/claude/code/response.go  (new file, 206 lines)
@@ -0,0 +1,206 @@
// Package code provides response translation functionality for Claude API.
// This package handles the conversion of backend client responses into Claude-compatible
// Server-Sent Events (SSE) format, implementing a sophisticated state machine that manages
// different response types including text content, thinking processes, and function calls.
// The translation ensures proper sequencing of SSE events and maintains state across
// multiple response chunks to provide a seamless streaming experience.
package code

import (
    "bytes"
    "fmt"
    "github.com/tidwall/gjson"
    "github.com/tidwall/sjson"
    "time"
)

// ConvertCliToClaude performs sophisticated streaming response format conversion.
// This function implements a complex state machine that translates backend client responses
// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types
// and handles state transitions between content blocks, thinking processes, and function calls.
//
// Response type states: 0=none, 1=content, 2=thinking, 3=function
// The function maintains state across multiple calls to ensure proper SSE event sequencing.
func ConvertCliToClaude(rawJSON []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string {
    // Normalize the response format for different API key types
    // Generative Language API keys have a different response structure
    if isGlAPIKey {
        rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
    }

    // Track whether tools are being used in this response chunk
    usedTool := false
    output := ""

    // Initialize the streaming session with a message_start event
    // This is only sent for the very first response chunk
    if !hasFirstResponse {
        output = "event: message_start\n"

        // Create the initial message structure with default values
        // This follows the Claude API specification for streaming message initialization
        messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`

        // Override default values with actual response metadata if available
        if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
            messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
        }
        if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
            messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
        }
        output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
    }

    // Process the response parts array from the backend client
    // Each part can contain text content, thinking content, or function calls
    partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
    if partsResult.IsArray() {
        partResults := partsResult.Array()
        for i := 0; i < len(partResults); i++ {
            partResult := partResults[i]

            // Extract the different types of content from each part
            partTextResult := partResult.Get("text")
            functionCallResult := partResult.Get("functionCall")

            // Handle text content (both regular content and thinking)
            if partTextResult.Exists() {
                // Process thinking content (internal reasoning)
                if partResult.Get("thought").Bool() {
                    // Continue existing thinking block
                    if *responseType == 2 {
                        output = output + "event: content_block_delta\n"
                        data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
                        output = output + fmt.Sprintf("data: %s\n\n\n", data)
                    } else {
                        // Transition from another state to thinking
                        // First, close any existing content block
                        if *responseType != 0 {
                            if *responseType == 2 {
                                output = output + "event: content_block_delta\n"
                                output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
                                output = output + "\n\n\n"
                            }
                            output = output + "event: content_block_stop\n"
                            output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
                            output = output + "\n\n\n"
                            *responseIndex++
                        }

                        // Start a new thinking content block
                        output = output + "event: content_block_start\n"
                        output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, *responseIndex)
                        output = output + "\n\n\n"
                        output = output + "event: content_block_delta\n"
                        data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
                        output = output + fmt.Sprintf("data: %s\n\n\n", data)
                        *responseType = 2 // Set state to thinking
                    }
                } else {
                    // Process regular text content (user-visible output)
                    // Continue existing text block
                    if *responseType == 1 {
                        output = output + "event: content_block_delta\n"
                        data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
                        output = output + fmt.Sprintf("data: %s\n\n\n", data)
                    } else {
                        // Transition from another state to text content
                        // First, close any existing content block
                        if *responseType != 0 {
                            if *responseType == 2 {
                                output = output + "event: content_block_delta\n"
                                output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
                                output = output + "\n\n\n"
                            }
                            output = output + "event: content_block_stop\n"
                            output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
                            output = output + "\n\n\n"
                            *responseIndex++
                        }

                        // Start a new text content block
                        output = output + "event: content_block_start\n"
                        output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, *responseIndex)
                        output = output + "\n\n\n"
                        output = output + "event: content_block_delta\n"
                        data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
                        output = output + fmt.Sprintf("data: %s\n\n\n", data)
                        *responseType = 1 // Set state to content
                    }
                }
            } else if functionCallResult.Exists() {
                // Handle function/tool calls from the AI model
                // This processes tool usage requests and formats them for Claude API compatibility
                usedTool = true
                fcName := functionCallResult.Get("name").String()

                // Handle state transitions when switching to function calls
                // Close any existing function call block first
                if *responseType == 3 {
                    output = output + "event: content_block_stop\n"
                    output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
                    output = output + "\n\n\n"
                    *responseIndex++
                    *responseType = 0
                }

                // Special handling for thinking state transition
                if *responseType == 2 {
                    output = output + "event: content_block_delta\n"
                    output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
                    output = output + "\n\n\n"
                }

                // Close any other existing content block
                if *responseType != 0 {
                    output = output + "event: content_block_stop\n"
                    output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
                    output = output + "\n\n\n"
                    *responseIndex++
                }

                // Start a new tool use content block
                // This creates the structure for a function call in Claude format
                output = output + "event: content_block_start\n"

                // Create the tool use block with unique ID and function details
                data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, *responseIndex)
                data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
                data, _ = sjson.Set(data, "content_block.name", fcName)
                output = output + fmt.Sprintf("data: %s\n\n\n", data)

                if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
                    output = output + "event: content_block_delta\n"
                    data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, *responseIndex), "delta.partial_json", fcArgsResult.Raw)
                    output = output + fmt.Sprintf("data: %s\n\n\n", data)
                }
                *responseType = 3
            }
        }
    }

    usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata")
    if usageResult.Exists() && bytes.Contains(rawJSON, []byte(`"finishReason"`)) {
        if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
            output = output + "event: content_block_stop\n"
            output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
            output = output + "\n\n\n"

            output = output + "event: message_delta\n"
            output = output + `data: `

            template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
            if usedTool {
                template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
            }

            thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
            template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
            template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())

            output = output + template + "\n\n\n"
        }
    }

    return output
}
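A small sketch of how the state carried in responseType and responseIndex threads through successive chunks; the chunk JSON below is invented for illustration, not taken from the commit:

    package main

    import (
        "fmt"

        code "github.com/luispater/CLIProxyAPI/internal/api/translator/claude/code"
    )

    func main() {
        chunks := [][]byte{
            []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"Hel"}]}}]}}`),
            []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"lo"}]},"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":3,"candidatesTokenCount":2}}}`),
        }
        responseType, responseIndex := 0, 0
        hasFirst := false
        for _, c := range chunks {
            sse := code.ConvertCliToClaude(c, false, hasFirst, &responseType, &responseIndex)
            hasFirst = true
            fmt.Print(sse)
        }
        // The first call emits message_start plus content_block_start/content_block_delta
        // for a text block; the second call emits another text delta and, because the
        // chunk carries usage metadata and a finishReason, a closing content_block_stop
        // followed by message_delta.
    }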
internal/api/translator/gemini/cli/request.go  (new file, 185 lines)
@@ -0,0 +1,185 @@
// Package cli provides request translation functionality for Gemini CLI API.
// It handles the conversion and formatting of CLI tool responses, specifically
// transforming between different JSON formats to ensure proper conversation flow
// and API compatibility. The package focuses on intelligently grouping function
// calls with their corresponding responses, converting from linear format to
// grouped format where function calls and responses are properly associated.
package cli

import (
    "encoding/json"
    "fmt"
    log "github.com/sirupsen/logrus"
    "github.com/tidwall/gjson"
    "github.com/tidwall/sjson"
)

// FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct {
    ModelContent map[string]interface{}
    FunctionCalls []gjson.Result
    ResponsesNeeded int
}

// FixCLIToolResponse performs sophisticated tool response format conversion and grouping.
// This function transforms the CLI tool response format by intelligently grouping function calls
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
// It converts from a linear format (1.json) to a grouped format (2.json) where function calls
// and their responses are properly associated and structured.
func FixCLIToolResponse(input string) (string, error) {
    // Parse the input JSON to extract the conversation structure
    parsed := gjson.Parse(input)

    // Extract the contents array which contains the conversation messages
    contents := parsed.Get("request.contents")
    if !contents.Exists() {
        // log.Debugf(input)
        return input, fmt.Errorf("contents not found in input")
    }

    // Initialize data structures for processing and grouping
    var newContents []interface{} // Final processed contents array
    var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
    var collectedResponses []gjson.Result // Standalone responses to be matched

    // Process each content object in the conversation
    // This iterates through messages and groups function calls with their responses
    contents.ForEach(func(key, value gjson.Result) bool {
        role := value.Get("role").String()
        parts := value.Get("parts")

        // Check if this content has function responses
        var responsePartsInThisContent []gjson.Result
        parts.ForEach(func(_, part gjson.Result) bool {
            if part.Get("functionResponse").Exists() {
                responsePartsInThisContent = append(responsePartsInThisContent, part)
            }
            return true
        })

        // If this content has function responses, collect them
        if len(responsePartsInThisContent) > 0 {
            collectedResponses = append(collectedResponses, responsePartsInThisContent...)

            // Check if any pending groups can be satisfied
            for i := len(pendingGroups) - 1; i >= 0; i-- {
                group := pendingGroups[i]
                if len(collectedResponses) >= group.ResponsesNeeded {
                    // Take the needed responses for this group
                    groupResponses := collectedResponses[:group.ResponsesNeeded]
                    collectedResponses = collectedResponses[group.ResponsesNeeded:]

                    // Create merged function response content
                    var responseParts []interface{}
                    for _, response := range groupResponses {
                        var responseMap map[string]interface{}
                        errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
                        if errUnmarshal != nil {
                            log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
                            continue
                        }
                        responseParts = append(responseParts, responseMap)
                    }

                    if len(responseParts) > 0 {
                        functionResponseContent := map[string]interface{}{
                            "parts": responseParts,
                            "role": "function",
                        }
                        newContents = append(newContents, functionResponseContent)
                    }

                    // Remove this group as it's been satisfied
                    pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
                    break
                }
            }

            return true // Skip adding this content, responses are merged
        }

        // If this is a model with function calls, create a new group
        if role == "model" {
            var functionCallsInThisModel []gjson.Result
            parts.ForEach(func(_, part gjson.Result) bool {
                if part.Get("functionCall").Exists() {
                    functionCallsInThisModel = append(functionCallsInThisModel, part)
                }
                return true
            })

            if len(functionCallsInThisModel) > 0 {
                // Add the model content
                var contentMap map[string]interface{}
                errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
                if errUnmarshal != nil {
                    log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
                    return true
                }
                newContents = append(newContents, contentMap)

                // Create a new group for tracking responses
                group := &FunctionCallGroup{
                    ModelContent: contentMap,
                    FunctionCalls: functionCallsInThisModel,
                    ResponsesNeeded: len(functionCallsInThisModel),
                }
                pendingGroups = append(pendingGroups, group)
            } else {
                // Regular model content without function calls
                var contentMap map[string]interface{}
                errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
                if errUnmarshal != nil {
                    log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
                    return true
                }
                newContents = append(newContents, contentMap)
            }
        } else {
            // Non-model content (user, etc.)
            var contentMap map[string]interface{}
            errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
            if errUnmarshal != nil {
                log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
                return true
            }
            newContents = append(newContents, contentMap)
        }

        return true
    })

    // Handle any remaining pending groups with remaining responses
    for _, group := range pendingGroups {
        if len(collectedResponses) >= group.ResponsesNeeded {
            groupResponses := collectedResponses[:group.ResponsesNeeded]
            collectedResponses = collectedResponses[group.ResponsesNeeded:]

            var responseParts []interface{}
            for _, response := range groupResponses {
                var responseMap map[string]interface{}
                errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
                if errUnmarshal != nil {
                    log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
                    continue
                }
                responseParts = append(responseParts, responseMap)
            }

            if len(responseParts) > 0 {
                functionResponseContent := map[string]interface{}{
                    "parts": responseParts,
                    "role": "function",
                }
                newContents = append(newContents, functionResponseContent)
            }
        }
    }

    // Update the original JSON with the new contents
    result := input
    newContentsJSON, _ := json.Marshal(newContents)
    result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))

    return result, nil
}
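A hedged before/after sketch of the grouping FixCLIToolResponse performs; the payload is invented for illustration:

    package main

    import (
        "fmt"

        cli "github.com/luispater/CLIProxyAPI/internal/api/translator/gemini/cli"
    )

    func main() {
        // Linear form: the model issues a functionCall, then a separate user turn
        // carries the functionResponse.
        in := `{"request":{"contents":[
            {"role":"user","parts":[{"text":"What is the weather?"}]},
            {"role":"model","parts":[{"functionCall":{"name":"get_weather","args":{"city":"Berlin"}}}]},
            {"role":"user","parts":[{"functionResponse":{"name":"get_weather","response":{"result":"sunny"}}}]}
        ]}}`
        out, err := cli.FixCLIToolResponse(in)
        if err != nil {
            panic(err)
        }
        fmt.Println(out)
        // The functionResponse turn is re-emitted as {"role":"function","parts":[...]}
        // immediately after the model turn that issued the matching functionCall.
    }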
@@ -1,3 +1,6 @@
+// Package translator provides data translation and format conversion utilities
+// for the CLI Proxy API. It includes MIME type mappings and other translation
+// functions used across different API endpoints.
 package translator

 // MimeTypes is a comprehensive map of file extensions to their corresponding MIME types.
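A short lookup sketch for the MimeTypes map; the presence of a "png" entry is an assumption, since this diff does not show the map's contents:

    package main

    import (
        "fmt"

        "github.com/luispater/CLIProxyAPI/internal/api/translator"
    )

    func main() {
        // Keys are bare file extensions, as used by the OpenAI request translator
        // when it attaches "file" content parts.
        if mime, ok := translator.MimeTypes["png"]; ok {
            fmt.Println(mime) // expected to be an image MIME type such as image/png
        }
    }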
226
internal/api/translator/openai/request.go
Normal file
226
internal/api/translator/openai/request.go
Normal file
@@ -0,0 +1,226 @@
|
|||||||
|
// Package openai provides request translation functionality for OpenAI API.
|
||||||
|
// It handles the conversion of OpenAI-compatible request formats to the internal
|
||||||
|
// format expected by the backend client, including parsing messages, roles,
|
||||||
|
// content types (text, image, file), and tool calls.
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PrepareRequest translates a raw JSON request from an OpenAI-compatible format
|
||||||
|
// to the internal format expected by the backend client. It parses messages,
|
||||||
|
// roles, content types (text, image, file), and tool calls.
|
||||||
|
func PrepareRequest(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
|
||||||
|
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
|
||||||
|
modelName := "gemini-2.5-pro"
|
||||||
|
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||||
|
if modelResult.Type == gjson.String {
|
||||||
|
modelName = modelResult.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize data structures for processing conversation messages
|
||||||
|
// contents: stores the processed conversation history
|
||||||
|
// systemInstruction: stores system-level instructions separate from conversation
|
||||||
|
contents := make([]client.Content, 0)
|
||||||
|
var systemInstruction *client.Content
|
||||||
|
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
||||||
|
|
||||||
|
// Pre-process tool responses to create a lookup map
|
||||||
|
// This first pass collects all tool responses so they can be matched with their corresponding calls
|
||||||
|
toolItems := make(map[string]*client.FunctionResponse)
|
||||||
|
if messagesResult.IsArray() {
|
||||||
|
messagesResults := messagesResult.Array()
|
||||||
|
for i := 0; i < len(messagesResults); i++ {
|
||||||
|
messageResult := messagesResults[i]
|
||||||
|
roleResult := messageResult.Get("role")
|
||||||
|
if roleResult.Type != gjson.String {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
contentResult := messageResult.Get("content")
|
||||||
|
|
||||||
|
// Extract tool responses for later matching with function calls
|
||||||
|
if roleResult.String() == "tool" {
|
||||||
|
toolCallID := messageResult.Get("tool_call_id").String()
|
||||||
|
if toolCallID != "" {
|
||||||
|
var responseData string
|
||||||
|
// Handle both string and object-based tool response formats
|
||||||
|
if contentResult.Type == gjson.String {
|
||||||
|
responseData = contentResult.String()
|
||||||
|
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
|
||||||
|
responseData = contentResult.Get("text").String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up tool call ID by removing timestamp suffix
|
||||||
|
// This normalizes IDs for consistent matching between calls and responses
|
||||||
|
toolCallIDs := strings.Split(toolCallID, "-")
|
||||||
|
strings.Join(toolCallIDs, "-")
|
||||||
|
newToolCallID := strings.Join(toolCallIDs[:len(toolCallIDs)-1], "-")
|
||||||
|
|
||||||
|
// Create function response object with normalized ID and response data
|
||||||
|
functionResponse := client.FunctionResponse{Name: newToolCallID, Response: map[string]interface{}{"result": responseData}}
|
||||||
|
toolItems[toolCallID] = &functionResponse
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if messagesResult.IsArray() {
|
||||||
|
messagesResults := messagesResult.Array()
|
||||||
|
for i := 0; i < len(messagesResults); i++ {
|
||||||
|
messageResult := messagesResults[i]
|
||||||
|
roleResult := messageResult.Get("role")
|
||||||
|
contentResult := messageResult.Get("content")
|
||||||
|
if roleResult.Type != gjson.String {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch roleResult.String() {
|
||||||
|
// System messages are converted to a user message followed by a model's acknowledgment.
|
||||||
|
case "system":
|
||||||
|
if contentResult.Type == gjson.String {
|
||||||
|
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}
|
||||||
|
} else if contentResult.IsObject() {
|
||||||
|
// Handle object-based system messages.
|
||||||
|
if contentResult.Get("type").String() == "text" {
|
||||||
|
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// User messages can contain simple text or a multi-part body.
|
||||||
|
case "user":
|
||||||
|
if contentResult.Type == gjson.String {
|
||||||
|
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||||
|
} else if contentResult.IsArray() {
|
||||||
|
// Handle multi-part user messages (text, images, files).
|
||||||
|
contentItemResults := contentResult.Array()
|
||||||
|
parts := make([]client.Part, 0)
|
||||||
|
for j := 0; j < len(contentItemResults); j++ {
|
||||||
|
contentItemResult := contentItemResults[j]
|
||||||
|
contentTypeResult := contentItemResult.Get("type")
|
||||||
|
switch contentTypeResult.String() {
|
||||||
|
case "text":
|
||||||
|
parts = append(parts, client.Part{Text: contentItemResult.Get("text").String()})
|
||||||
|
case "image_url":
|
||||||
|
// Parse data URI for images.
|
||||||
|
imageURL := contentItemResult.Get("image_url.url").String()
|
||||||
|
if len(imageURL) > 5 {
|
||||||
|
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
|
||||||
|
if len(imageURLs) == 2 && len(imageURLs[1]) > 7 {
|
||||||
|
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
||||||
|
MimeType: imageURLs[0],
|
||||||
|
Data: imageURLs[1][7:],
|
||||||
|
}})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "file":
|
||||||
|
// Handle file attachments by determining MIME type from extension.
|
||||||
|
filename := contentItemResult.Get("file.filename").String()
|
||||||
|
fileData := contentItemResult.Get("file.file_data").String()
|
||||||
|
ext := ""
|
||||||
|
if split := strings.Split(filename, "."); len(split) > 1 {
|
||||||
|
ext = split[len(split)-1]
|
||||||
|
}
|
||||||
|
if mimeType, ok := translator.MimeTypes[ext]; ok {
|
||||||
|
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
||||||
|
MimeType: mimeType,
|
||||||
|
Data: fileData,
|
||||||
|
}})
|
||||||
|
} else {
|
||||||
|
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
contents = append(contents, client.Content{Role: "user", Parts: parts})
|
||||||
|
}
|
||||||
|
// Assistant messages can contain text responses or tool calls
|
||||||
|
// In the internal format, assistant messages are converted to "model" role
|
||||||
|
case "assistant":
|
||||||
|
if contentResult.Type == gjson.String {
// Simple text response from the assistant
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
// Handle complex tool calls made by the assistant
// This processes function calls and matches them with their responses
functionIDs := make([]string, 0)
toolCallsResult := messageResult.Get("tool_calls")
if toolCallsResult.IsArray() {
parts := make([]client.Part, 0)
tcsResult := toolCallsResult.Array()

// Process each tool call in the assistant's message
for j := 0; j < len(tcsResult); j++ {
tcResult := tcsResult[j]

// Extract function call details
functionID := tcResult.Get("id").String()
functionIDs = append(functionIDs, functionID)

functionName := tcResult.Get("function.name").String()
functionArgs := tcResult.Get("function.arguments").String()

// Parse function arguments from JSON string to map
var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
parts = append(parts, client.Part{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
})
}
}

// Add the model's function calls to the conversation
if len(parts) > 0 {
contents = append(contents, client.Content{
Role: "model", Parts: parts,
})

// Create a separate tool response message with the collected responses
// This matches function calls with their corresponding responses
toolParts := make([]client.Part, 0)
for _, functionID := range functionIDs {
if functionResponse, ok := toolItems[functionID]; ok {
toolParts = append(toolParts, client.Part{FunctionResponse: functionResponse})
}
}
// Add the tool responses as a separate message in the conversation
contents = append(contents, client.Content{Role: "tool", Parts: toolParts})
}
}
}
}
}
}

// Translate the tool declarations from the request.
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
if toolResult.Get("type").String() == "function" {
functionTypeResult := toolResult.Get("function")
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
var functionDeclaration any
if err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
}
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}

return modelName, systemInstruction, contents, tools
}
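For orientation, the mapping this branch performs can be shown in isolation: an OpenAI assistant message carrying `tool_calls` becomes a "model" content with a `FunctionCall` part, and the matching tool reply becomes a separate "tool" content with a `FunctionResponse` part. The sketch below is a minimal, self-contained approximation; the struct types are simplified stand-ins for the `client` package's `Content`/`Part` types used above, and the function name and arguments are hypothetical.

```go
package main

import (
	"encoding/json"
	"fmt"
)

// Simplified stand-ins for the client package types referenced in the diff.
type FunctionCall struct {
	Name string
	Args map[string]any
}

type FunctionResponse struct {
	Name     string
	Response map[string]any
}

type Part struct {
	Text             string
	FunctionCall     *FunctionCall
	FunctionResponse *FunctionResponse
}

type Content struct {
	Role  string
	Parts []Part
}

func main() {
	// Hypothetical OpenAI-style tool call arguments, parsed the same way the translator does.
	var args map[string]any
	_ = json.Unmarshal([]byte(`{"city":"Berlin"}`), &args)

	// One assistant tool call plus its tool result, expressed in the internal format.
	contents := []Content{
		{Role: "model", Parts: []Part{{FunctionCall: &FunctionCall{Name: "get_weather", Args: args}}}},
		{Role: "tool", Parts: []Part{{FunctionResponse: &FunctionResponse{
			Name:     "get_weather",
			Response: map[string]any{"result": "sunny"},
		}}}},
	}
	fmt.Printf("%+v\n", contents)
}
```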
197 internal/api/translator/openai/response.go Normal file
@@ -0,0 +1,197 @@
// Package openai provides response translation functionality for converting between
// different API response formats and OpenAI-compatible formats. It handles both
// streaming and non-streaming responses, transforming backend client responses
// into OpenAI Server-Sent Events (SSE) format and standard JSON response formats.
// The package supports content translation, function calls, usage metadata,
// and various response attributes while maintaining compatibility with OpenAI API
// specifications.
package openai

import (
"fmt"
"time"

"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)

// ConvertCliToOpenAI translates a single chunk of a streaming response from the
// backend client format to the OpenAI Server-Sent Events (SSE) format.
// It returns an empty string if the chunk contains no useful data.
func ConvertCliToOpenAI(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string {
if isGlAPIKey {
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
}

// Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`

// Extract and set the model version.
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}

// Extract and set the creation timestamp.
if createTimeResult := gjson.GetBytes(rawJSON, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
if err == nil {
unixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
} else {
template, _ = sjson.Set(template, "created", unixTimestamp)
}

// Extract and set the response ID.
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
template, _ = sjson.Set(template, "id", responseIDResult.String())
}

// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}

// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
}

// Process the main content part of the response.
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partResults := partsResult.Array()
for i := 0; i < len(partResults); i++ {
partResult := partResults[i]
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")

if partTextResult.Exists() {
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
}

functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallTemplate)
}
}
}

return template
}

// ConvertCliToOpenAINonStream aggregates responses from the backend client and
// converts them into a single, non-streaming OpenAI-compatible JSON response.
func ConvertCliToOpenAINonStream(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string {
if isGlAPIKey {
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
}
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}

if createTimeResult := gjson.GetBytes(rawJSON, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
if err == nil {
unixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
} else {
template, _ = sjson.Set(template, "created", unixTimestamp)
}

if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
template, _ = sjson.Set(template, "id", responseIDResult.String())
}

if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}

if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
}

// Process the main content part of the response.
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partsResults := partsResult.Array()
for i := 0; i < len(partsResults); i++ {
partResult := partsResults[i]
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")

if partTextResult.Exists() {
// Append text content, distinguishing between regular content and reasoning.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
} else if functionCallResult.Exists() {
// Append function call content to the tool_calls array.
toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls")
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
}
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
} else {
// If no usable content is found, return an empty string.
return ""
}
}
}

return template
}
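A rough sketch of how a streaming handler might frame the converter's output as SSE. The converter is passed in as a function value matching the `ConvertCliToOpenAI` signature above; the sample chunk, the stand-in converter, and the writer are illustrative assumptions, not code from the repository.

```go
package main

import (
	"bufio"
	"os"
	"time"
)

// writeSSE frames each non-empty converted chunk as an OpenAI-style SSE event.
// convert is expected to have the same signature as ConvertCliToOpenAI above.
func writeSSE(w *bufio.Writer, chunks [][]byte, convert func([]byte, int64, bool) string) {
	now := time.Now().Unix()
	for _, chunk := range chunks {
		if data := convert(chunk, now, false); data != "" {
			w.WriteString("data: " + data + "\n\n")
		}
	}
	w.WriteString("data: [DONE]\n\n")
	w.Flush()
}

func main() {
	// A hypothetical backend chunk shaped like what the converter reads ("response.candidates...").
	chunk := []byte(`{"response":{"modelVersion":"gemini-2.5-pro","candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}}`)
	out := bufio.NewWriter(os.Stdout)
	writeSSE(out, [][]byte{chunk}, func(b []byte, ts int64, gl bool) string {
		// Stand-in converter for this sketch; in the server this would be openai.ConvertCliToOpenAI.
		return string(b)
	})
}
```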
@@ -1,545 +0,0 @@
package translator

import (
"bytes"
"encoding/json"
"fmt"
"github.com/tidwall/sjson"
"strings"

"github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)

// PrepareRequest translates a raw JSON request from an OpenAI-compatible format
// to the internal format expected by the backend client. It parses messages,
// roles, content types (text, image, file), and tool calls.
func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
modelName := "gemini-2.5-pro"
modelResult := gjson.GetBytes(rawJson, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}

// Initialize data structures for processing conversation messages
// contents: stores the processed conversation history
// systemInstruction: stores system-level instructions separate from conversation
contents := make([]client.Content, 0)
var systemInstruction *client.Content
messagesResult := gjson.GetBytes(rawJson, "messages")

// Pre-process tool responses to create a lookup map
// This first pass collects all tool responses so they can be matched with their corresponding calls
toolItems := make(map[string]*client.FunctionResponse)
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
}
contentResult := messageResult.Get("content")

// Extract tool responses for later matching with function calls
if roleResult.String() == "tool" {
toolCallID := messageResult.Get("tool_call_id").String()
if toolCallID != "" {
var responseData string
// Handle both string and object-based tool response formats
if contentResult.Type == gjson.String {
responseData = contentResult.String()
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
responseData = contentResult.Get("text").String()
}

// Clean up tool call ID by removing timestamp suffix
// This normalizes IDs for consistent matching between calls and responses
toolCallIDs := strings.Split(toolCallID, "-")
strings.Join(toolCallIDs, "-")
newToolCallID := strings.Join(toolCallIDs[:len(toolCallIDs)-1], "-")

// Create function response object with normalized ID and response data
functionResponse := client.FunctionResponse{Name: newToolCallID, Response: map[string]interface{}{"result": responseData}}
toolItems[toolCallID] = &functionResponse
}
}
}
}

if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
contentResult := messageResult.Get("content")
if roleResult.Type != gjson.String {
continue
}

switch roleResult.String() {
// System messages are converted to a user message followed by a model's acknowledgment.
case "system":
if contentResult.Type == gjson.String {
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}
} else if contentResult.IsObject() {
// Handle object-based system messages.
if contentResult.Get("type").String() == "text" {
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}}
}
}
// User messages can contain simple text or a multi-part body.
case "user":
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsArray() {
// Handle multi-part user messages (text, images, files).
contentItemResults := contentResult.Array()
parts := make([]client.Part, 0)
for j := 0; j < len(contentItemResults); j++ {
contentItemResult := contentItemResults[j]
contentTypeResult := contentItemResult.Get("type")
switch contentTypeResult.String() {
case "text":
parts = append(parts, client.Part{Text: contentItemResult.Get("text").String()})
case "image_url":
// Parse data URI for images.
imageURL := contentItemResult.Get("image_url.url").String()
if len(imageURL) > 5 {
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
if len(imageURLs) == 2 && len(imageURLs[1]) > 7 {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: imageURLs[0],
Data: imageURLs[1][7:],
}})
}
}
case "file":
// Handle file attachments by determining MIME type from extension.
filename := contentItemResult.Get("file.filename").String()
fileData := contentItemResult.Get("file.file_data").String()
ext := ""
if split := strings.Split(filename, "."); len(split) > 1 {
ext = split[len(split)-1]
}
if mimeType, ok := MimeTypes[ext]; ok {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: mimeType,
Data: fileData,
}})
} else {
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
}
}
}
contents = append(contents, client.Content{Role: "user", Parts: parts})
}
// Assistant messages can contain text responses or tool calls
// In the internal format, assistant messages are converted to "model" role
case "assistant":
if contentResult.Type == gjson.String {
// Simple text response from the assistant
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
// Handle complex tool calls made by the assistant
// This processes function calls and matches them with their responses
functionIDs := make([]string, 0)
toolCallsResult := messageResult.Get("tool_calls")
if toolCallsResult.IsArray() {
parts := make([]client.Part, 0)
tcsResult := toolCallsResult.Array()

// Process each tool call in the assistant's message
for j := 0; j < len(tcsResult); j++ {
tcResult := tcsResult[j]

// Extract function call details
functionID := tcResult.Get("id").String()
functionIDs = append(functionIDs, functionID)

functionName := tcResult.Get("function.name").String()
functionArgs := tcResult.Get("function.arguments").String()

// Parse function arguments from JSON string to map
var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
parts = append(parts, client.Part{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
})
}
}

// Add the model's function calls to the conversation
if len(parts) > 0 {
contents = append(contents, client.Content{
Role: "model", Parts: parts,
})

// Create a separate tool response message with the collected responses
// This matches function calls with their corresponding responses
toolParts := make([]client.Part, 0)
for _, functionID := range functionIDs {
if functionResponse, ok := toolItems[functionID]; ok {
toolParts = append(toolParts, client.Part{FunctionResponse: functionResponse})
}
}
// Add the tool responses as a separate message in the conversation
contents = append(contents, client.Content{Role: "tool", Parts: toolParts})
}
}
}
}
}
}

// Translate the tool declarations from the request.
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJson, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
if toolResult.Get("type").String() == "function" {
functionTypeResult := toolResult.Get("function")
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
var functionDeclaration any
if err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
}
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}

return modelName, systemInstruction, contents, tools
}

// FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct {
ModelContent map[string]interface{}
FunctionCalls []gjson.Result
ResponsesNeeded int
}

// FixCLIToolResponse performs sophisticated tool response format conversion and grouping.
// This function transforms the CLI tool response format by intelligently grouping function calls
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
// It converts from a linear format (1.json) to a grouped format (2.json) where function calls
// and their responses are properly associated and structured.
func FixCLIToolResponse(input string) (string, error) {
// Parse the input JSON to extract the conversation structure
parsed := gjson.Parse(input)

// Extract the contents array which contains the conversation messages
contents := parsed.Get("request.contents")
if !contents.Exists() {
// log.Debugf(input)
return input, fmt.Errorf("contents not found in input")
}

// Initialize data structures for processing and grouping
var newContents []interface{} // Final processed contents array
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched

// Process each content object in the conversation
// This iterates through messages and groups function calls with their responses
contents.ForEach(func(key, value gjson.Result) bool {
role := value.Get("role").String()
parts := value.Get("parts")

// Check if this content has function responses
var responsePartsInThisContent []gjson.Result
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionResponse").Exists() {
responsePartsInThisContent = append(responsePartsInThisContent, part)
}
return true
})

// If this content has function responses, collect them
if len(responsePartsInThisContent) > 0 {
collectedResponses = append(collectedResponses, responsePartsInThisContent...)

// Check if any pending groups can be satisfied
for i := len(pendingGroups) - 1; i >= 0; i-- {
group := pendingGroups[i]
if len(collectedResponses) >= group.ResponsesNeeded {
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]

// Create merged function response content
var responseParts []interface{}
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue
}
responseParts = append(responseParts, responseMap)
}

if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
}

// Remove this group as it's been satisfied
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
break
}
}

return true // Skip adding this content, responses are merged
}

// If this is a model with function calls, create a new group
if role == "model" {
var functionCallsInThisModel []gjson.Result
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
functionCallsInThisModel = append(functionCallsInThisModel, part)
}
return true
})

if len(functionCallsInThisModel) > 0 {
// Add the model content
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)

// Create a new group for tracking responses
group := &FunctionCallGroup{
ModelContent: contentMap,
FunctionCalls: functionCallsInThisModel,
ResponsesNeeded: len(functionCallsInThisModel),
}
pendingGroups = append(pendingGroups, group)
} else {
// Regular model content without function calls
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
}
} else {
// Non-model content (user, etc.)
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
}

return true
})

// Handle any remaining pending groups with remaining responses
for _, group := range pendingGroups {
if len(collectedResponses) >= group.ResponsesNeeded {
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]

var responseParts []interface{}
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue
}
responseParts = append(responseParts, responseMap)
}

if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
}
}
}

// Update the original JSON with the new contents
result := input
newContentsJSON, _ := json.Marshal(newContents)
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))

return result, nil
}

func PrepareClaudeRequest(rawJson []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
var pathsToDelete []string
root := gjson.ParseBytes(rawJson)
walk(root, "", "additionalProperties", &pathsToDelete)
walk(root, "", "$schema", &pathsToDelete)

var err error
for _, p := range pathsToDelete {
rawJson, err = sjson.DeleteBytes(rawJson, p)
if err != nil {
continue
}
}
rawJson = bytes.Replace(rawJson, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)

// log.Debug(string(rawJson))
modelName := "gemini-2.5-pro"
modelResult := gjson.GetBytes(rawJson, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}

contents := make([]client.Content, 0)

var systemInstruction *client.Content

systemResult := gjson.GetBytes(rawJson, "system")
if systemResult.IsArray() {
systemResults := systemResult.Array()
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
for i := 0; i < len(systemResults); i++ {
systemPromptResult := systemResults[i]
systemTypePromptResult := systemPromptResult.Get("type")
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
systemPrompt := systemPromptResult.Get("text").String()
systemPart := client.Part{Text: systemPrompt}
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
}
}
if len(systemInstruction.Parts) == 0 {
systemInstruction = nil
}
}

messagesResult := gjson.GetBytes(rawJson, "messages")
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
for i := 0; i < len(messageResults); i++ {
messageResult := messageResults[i]
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
}
role := roleResult.String()
if role == "assistant" {
role = "model"
}
clientContent := client.Content{Role: role, Parts: []client.Part{}}

contentsResult := messageResult.Get("content")
if contentsResult.IsArray() {
contentResults := contentsResult.Array()
for j := 0; j < len(contentResults); j++ {
contentResult := contentResults[j]
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String()
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String()
var args map[string]any
if err = json.Unmarshal([]byte(functionArgs), &args); err == nil {
clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
})
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" {
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").String()
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
}
}
}
contents = append(contents, clientContent)
} else if contentsResult.Type == gjson.String {
prompt := contentsResult.String()
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
}
}
}

var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJson, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw
inputSchema, _ = sjson.Delete(inputSchema, "additionalProperties")
inputSchema, _ = sjson.Delete(inputSchema, "$schema")

tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parameters", inputSchema)
var toolDeclaration any
if err = json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}

return modelName, systemInstruction, contents, tools
}

func walk(value gjson.Result, path, field string, pathsToDelete *[]string) {
switch value.Type {
case gjson.JSON:
value.ForEach(func(key, val gjson.Result) bool {
var childPath string
if path == "" {
childPath = key.String()
} else {
childPath = path + "." + key.String()
}
if key.String() == field {
*pathsToDelete = append(*pathsToDelete, childPath)
}
walk(val, childPath, field, pathsToDelete)
return true
})
case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null:
}
}
@@ -1,382 +0,0 @@
|
|||||||
package translator
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
"github.com/tidwall/sjson"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ConvertCliToOpenAI translates a single chunk of a streaming response from the
|
|
||||||
// backend client format to the OpenAI Server-Sent Events (SSE) format.
|
|
||||||
// It returns an empty string if the chunk contains no useful data.
|
|
||||||
func ConvertCliToOpenAI(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string {
|
|
||||||
if isGlAPIKey {
|
|
||||||
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize the OpenAI SSE template.
|
|
||||||
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
|
||||||
|
|
||||||
// Extract and set the model version.
|
|
||||||
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
|
|
||||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract and set the creation timestamp.
|
|
||||||
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
|
|
||||||
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
|
||||||
if err == nil {
|
|
||||||
unixTimestamp = t.Unix()
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
|
||||||
} else {
|
|
||||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract and set the response ID.
|
|
||||||
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
|
|
||||||
template, _ = sjson.Set(template, "id", responseIdResult.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract and set the finish reason.
|
|
||||||
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract and set usage metadata (token counts).
|
|
||||||
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
|
|
||||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
|
||||||
}
|
|
||||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
|
||||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
|
||||||
}
|
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
|
||||||
if thoughtsTokenCount > 0 {
|
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process the main content part of the response.
|
|
||||||
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
|
|
||||||
if partsResult.IsArray() {
|
|
||||||
partResults := partsResult.Array()
|
|
||||||
for i := 0; i < len(partResults); i++ {
|
|
||||||
partResult := partResults[i]
|
|
||||||
partTextResult := partResult.Get("text")
|
|
||||||
functionCallResult := partResult.Get("functionCall")
|
|
||||||
|
|
||||||
if partTextResult.Exists() {
|
|
||||||
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
|
||||||
if partResult.Get("thought").Bool() {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
|
|
||||||
} else {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
|
||||||
} else if functionCallResult.Exists() {
|
|
||||||
// Handle function call content.
|
|
||||||
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
|
||||||
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
|
||||||
}
|
|
||||||
|
|
||||||
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
|
||||||
fcName := functionCallResult.Get("name").String()
|
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
|
||||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallTemplate)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return template
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConvertCliToOpenAINonStream aggregates response from the backend client
|
|
||||||
// convert a single, non-streaming OpenAI-compatible JSON response.
|
|
||||||
func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string {
|
|
||||||
if isGlAPIKey {
|
|
||||||
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
|
|
||||||
}
|
|
||||||
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
|
||||||
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
|
|
||||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
|
|
||||||
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
|
||||||
if err == nil {
|
|
||||||
unixTimestamp = t.Unix()
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
|
||||||
} else {
|
|
||||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
|
||||||
}
|
|
||||||
|
|
||||||
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
|
|
||||||
template, _ = sjson.Set(template, "id", responseIdResult.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
|
|
||||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
|
||||||
}
|
|
||||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
|
||||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
|
||||||
}
|
|
||||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
|
||||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
|
||||||
if thoughtsTokenCount > 0 {
|
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process the main content part of the response.
|
|
||||||
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
|
|
||||||
if partsResult.IsArray() {
|
|
||||||
partsResults := partsResult.Array()
|
|
||||||
for i := 0; i < len(partsResults); i++ {
|
|
||||||
partResult := partsResults[i]
|
|
||||||
partTextResult := partResult.Get("text")
|
|
||||||
functionCallResult := partResult.Get("functionCall")
|
|
||||||
|
|
||||||
if partTextResult.Exists() {
|
|
||||||
// Append text content, distinguishing between regular content and reasoning.
|
|
||||||
if partResult.Get("thought").Bool() {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String())
|
|
||||||
} else {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String())
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
|
||||||
} else if functionCallResult.Exists() {
|
|
||||||
// Append function call content to the tool_calls array.
|
|
||||||
toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls")
|
|
||||||
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
|
|
||||||
}
|
|
||||||
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
|
||||||
fcName := functionCallResult.Get("name").String()
|
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
|
|
||||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
|
|
||||||
} else {
|
|
||||||
// If no usable content is found, return an empty string.
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return template
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConvertCliToClaude performs sophisticated streaming response format conversion.
|
|
||||||
// This function implements a complex state machine that translates backend client responses
|
|
||||||
// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types
|
|
||||||
// and handles state transitions between content blocks, thinking processes, and function calls.
|
|
||||||
//
|
|
||||||
// Response type states: 0=none, 1=content, 2=thinking, 3=function
|
|
||||||
// The function maintains state across multiple calls to ensure proper SSE event sequencing.
|
|
||||||
func ConvertCliToClaude(rawJson []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string {
|
|
||||||
// Normalize the response format for different API key types
|
|
||||||
// Generative Language API keys have a different response structure
|
|
||||||
if isGlAPIKey {
|
|
||||||
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Track whether tools are being used in this response chunk
|
|
||||||
usedTool := false
|
|
||||||
output := ""
|
|
||||||
|
|
||||||
// Initialize the streaming session with a message_start event
|
|
||||||
// This is only sent for the very first response chunk
|
|
||||||
if !hasFirstResponse {
|
|
||||||
output = "event: message_start\n"
|
|
||||||
|
|
||||||
// Create the initial message structure with default values
|
|
||||||
// This follows the Claude API specification for streaming message initialization
|
|
||||||
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
|
|
||||||
|
|
||||||
// Override default values with actual response metadata if available
|
|
||||||
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
|
|
||||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
|
|
||||||
}
|
|
||||||
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
|
|
||||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIdResult.String())
|
|
||||||
}
|
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Process the response parts array from the backend client
|
|
||||||
// Each part can contain text content, thinking content, or function calls
|
|
||||||
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
|
|
||||||
if partsResult.IsArray() {
|
|
||||||
partResults := partsResult.Array()
|
|
||||||
for i := 0; i < len(partResults); i++ {
|
|
||||||
partResult := partResults[i]
|
|
||||||
|
|
||||||
// Extract the different types of content from each part
|
|
||||||
partTextResult := partResult.Get("text")
|
|
||||||
functionCallResult := partResult.Get("functionCall")
|
|
||||||
|
|
||||||
// Handle text content (both regular content and thinking)
|
|
||||||
if partTextResult.Exists() {
|
|
||||||
// Process thinking content (internal reasoning)
|
|
||||||
if partResult.Get("thought").Bool() {
|
|
||||||
// Continue existing thinking block
|
|
||||||
if *responseType == 2 {
|
|
||||||
output = output + "event: content_block_delta\n"
|
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
|
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
|
||||||
} else {
|
|
||||||
// Transition from another state to thinking
|
|
||||||
// First, close any existing content block
|
|
||||||
if *responseType != 0 {
|
|
||||||
if *responseType == 2 {
|
|
||||||
output = output + "event: content_block_delta\n"
|
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
|
|
||||||
output = output + "\n\n\n"
|
|
||||||
}
|
|
||||||
output = output + "event: content_block_stop\n"
|
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
|
||||||
output = output + "\n\n\n"
|
|
||||||
*responseIndex++
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start a new thinking content block
|
|
||||||
output = output + "event: content_block_start\n"
|
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, *responseIndex)
|
|
||||||
output = output + "\n\n\n"
|
|
||||||
output = output + "event: content_block_delta\n"
|
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
|
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
|
||||||
*responseType = 2 // Set state to thinking
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Process regular text content (user-visible output)
|
|
||||||
// Continue existing text block
|
|
||||||
if *responseType == 1 {
|
|
||||||
output = output + "event: content_block_delta\n"
|
|
||||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
|
|
||||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
|
||||||
} else {
|
|
||||||
// Transition from another state to text content
|
|
||||||
// First, close any existing content block
|
|
||||||
if *responseType != 0 {
|
|
||||||
if *responseType == 2 {
|
|
||||||
output = output + "event: content_block_delta\n"
|
|
||||||
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
|
|
								output = output + "\n\n\n"
							}
							output = output + "event: content_block_stop\n"
							output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
							output = output + "\n\n\n"
							*responseIndex++
						}

						// Start a new text content block
						output = output + "event: content_block_start\n"
						output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, *responseIndex)
						output = output + "\n\n\n"
						output = output + "event: content_block_delta\n"
						data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
						output = output + fmt.Sprintf("data: %s\n\n\n", data)
						*responseType = 1 // Set state to content
					}
				}
			} else if functionCallResult.Exists() {
				// Handle function/tool calls from the AI model
				// This processes tool usage requests and formats them for Claude API compatibility
				usedTool = true
				fcName := functionCallResult.Get("name").String()

				// Handle state transitions when switching to function calls
				// Close any existing function call block first
				if *responseType == 3 {
					output = output + "event: content_block_stop\n"
					output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
					output = output + "\n\n\n"
					*responseIndex++
					*responseType = 0
				}

				// Special handling for thinking state transition
				if *responseType == 2 {
					output = output + "event: content_block_delta\n"
					output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
					output = output + "\n\n\n"
				}

				// Close any other existing content block
				if *responseType != 0 {
					output = output + "event: content_block_stop\n"
					output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
					output = output + "\n\n\n"
					*responseIndex++
				}

				// Start a new tool use content block
				// This creates the structure for a function call in Claude format
				output = output + "event: content_block_start\n"

				// Create the tool use block with unique ID and function details
				data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, *responseIndex)
				data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
				data, _ = sjson.Set(data, "content_block.name", fcName)
				output = output + fmt.Sprintf("data: %s\n\n\n", data)

				if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
					output = output + "event: content_block_delta\n"
					data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, *responseIndex), "delta.partial_json", fcArgsResult.Raw)
					output = output + fmt.Sprintf("data: %s\n\n\n", data)
				}
				*responseType = 3
			}
		}
	}

	usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
	if usageResult.Exists() && bytes.Contains(rawJson, []byte(`"finishReason"`)) {
		if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
			output = output + "event: content_block_stop\n"
			output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
			output = output + "\n\n\n"

			output = output + "event: message_delta\n"
			output = output + `data: `

			template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
			if usedTool {
				template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
			}

			thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
			template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
			template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())

			output = output + template + "\n\n\n"
		}
	}

	return output
}
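For orientation, the converter above emits Anthropic-style SSE framing: each content block is opened with content_block_start, streamed via content_block_delta, closed with content_block_stop, and the response ends with a message_delta carrying the stop reason and token usage. The following is an illustrative, self-contained sketch that reproduces the framing for a single text part at block index 0; the payload templates mirror the format strings above, while the surrounding program is not from the repository.

package main

import (
	"fmt"

	"github.com/tidwall/sjson"
)

func main() {
	index := 0
	text := "Hello" // example delta text

	out := "event: content_block_start\n"
	out += fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, index)
	out += "\n\n\n"

	out += "event: content_block_delta\n"
	data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, index), "delta.text", text)
	out += fmt.Sprintf("data: %s\n\n\n", data)

	out += "event: content_block_stop\n"
	out += fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, index)
	out += "\n\n\n"

	fmt.Print(out)
}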
@@ -1,3 +1,6 @@
+// Package auth provides OAuth2 authentication functionality for Google Cloud APIs.
+// It handles the complete OAuth2 flow including token storage, web-based authentication,
+// proxy support, and automatic token refresh. The package supports both SOCKS5 and HTTP/HTTPS proxies.
 package auth
 
 import (
@@ -39,7 +42,7 @@ var (
 // initiating a new web-based OAuth flow if necessary, and refreshing tokens.
 func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) {
 	// Configure proxy settings for the HTTP client if a proxy URL is provided.
-	proxyURL, err := url.Parse(cfg.ProxyUrl)
+	proxyURL, err := url.Parse(cfg.ProxyURL)
 	if err == nil {
 		var transport *http.Transport
 		if proxyURL.Scheme == "socks5" {
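Since GetAuthenticatedClient both builds the proxy-aware transport and drives the OAuth2 flow, the common pattern with golang.org/x/oauth2 is to hand the custom *http.Client to the library through the oauth2.HTTPClient context key so that the token exchange and refreshes reuse the same transport. A minimal sketch of that pattern under placeholder endpoints and credentials, not the repository's exact code:

package main

import (
	"context"
	"fmt"
	"log"
	"net/http"
	"net/url"

	"golang.org/x/oauth2"
)

func main() {
	// Placeholder proxy; in the server this would come from cfg.ProxyURL.
	proxyURL, _ := url.Parse("http://127.0.0.1:8080")
	httpClient := &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)}}

	// oauth2 picks the *http.Client out of the context for Exchange and token refreshes.
	ctx := context.WithValue(context.Background(), oauth2.HTTPClient, httpClient)

	conf := &oauth2.Config{
		ClientID:     "client-id",     // placeholder
		ClientSecret: "client-secret", // placeholder
		RedirectURL:  "http://localhost:8085/oauth2callback",
		Scopes:       []string{"example-scope"},
		Endpoint: oauth2.Endpoint{
			AuthURL:  "https://example.com/auth",
			TokenURL: "https://example.com/token",
		},
	}

	fmt.Println("visit:", conf.AuthCodeURL("state", oauth2.AccessTypeOffline))

	// A real flow would pass the code received on the redirect; with the
	// placeholder endpoints this simply fails, but the request goes through httpClient.
	if _, err := conf.Exchange(ctx, "authorization-code"); err != nil {
		log.Printf("exchange failed (expected here): %v", err)
	}
}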
@@ -1,3 +1,7 @@
+// Package client provides HTTP client functionality for interacting with Google Cloud AI APIs.
+// It handles OAuth2 authentication, token management, request/response processing,
+// streaming communication, quota management, and automatic model fallback.
+// The package supports both direct API key authentication and OAuth2 flows.
 package client
 
 import (
@@ -29,7 +33,7 @@ const (
 	pluginVersion = "0.1.9"
 
 	glEndPoint = "https://generativelanguage.googleapis.com"
-	glApiVersion = "v1beta"
+	glAPIVersion = "v1beta"
 )
 
 var (
@@ -64,30 +68,37 @@ func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Confi
 	}
 }
 
+// SetProjectID updates the project ID for the client's token storage.
 func (c *Client) SetProjectID(projectID string) {
 	c.tokenStorage.ProjectID = projectID
 }
 
+// SetIsAuto configures whether the client should operate in automatic mode.
 func (c *Client) SetIsAuto(auto bool) {
 	c.tokenStorage.Auto = auto
 }
 
+// SetIsChecked sets the checked status for the client's token storage.
 func (c *Client) SetIsChecked(checked bool) {
 	c.tokenStorage.Checked = checked
 }
 
+// IsChecked returns whether the client's token storage has been checked.
 func (c *Client) IsChecked() bool {
 	return c.tokenStorage.Checked
 }
 
+// IsAuto returns whether the client is operating in automatic mode.
 func (c *Client) IsAuto() bool {
 	return c.tokenStorage.Auto
 }
 
+// GetEmail returns the email address associated with the client's token storage.
 func (c *Client) GetEmail() string {
 	return c.tokenStorage.Email
 }
 
+// GetProjectID returns the Google Cloud project ID from the client's token storage.
 func (c *Client) GetProjectID() string {
 	if c.tokenStorage != nil {
 		return c.tokenStorage.ProjectID
@@ -95,6 +106,7 @@ func (c *Client) GetProjectID() string {
 	return ""
 }
 
+// GetGenerativeLanguageAPIKey returns the generative language API key if configured.
 func (c *Client) GetGenerativeLanguageAPIKey() string {
 	return c.glAPIKey
 }
@@ -267,10 +279,10 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
 	} else {
 		if endpoint == "countTokens" {
 			modelResult := gjson.GetBytes(jsonBody, "model")
-			url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
+			url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint)
 		} else {
 			modelResult := gjson.GetBytes(jsonBody, "model")
-			url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
+			url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint)
 			if alt == "" && stream {
 				url = url + "?alt=sse"
 			} else {
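As the hunk above shows, requests that carry a Generative Language API key are addressed directly to generativelanguage.googleapis.com, with the URL assembled from the endpoint constants, the model name, and the RPC verb, plus ?alt=sse for streaming. A small illustrative sketch of that composition; the constant names mirror the diff and the model name is just an example:

package main

import "fmt"

const (
	glEndPoint   = "https://generativelanguage.googleapis.com"
	glAPIVersion = "v1beta"
)

func main() {
	model := "gemini-2.5-pro" // example model name
	endpoint := "streamGenerateContent"

	url := fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, model, endpoint)
	url = url + "?alt=sse" // streaming requests ask for Server-Sent Events

	// Prints: https://generativelanguage.googleapis.com/v1beta/models/gemini-2.5-pro:streamGenerateContent?alt=sse
	fmt.Println(url)
}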
@@ -333,7 +345,7 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
 }
 
 // SendMessage handles a single conversational turn, including tool calls.
-func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
+func (c *Client) SendMessage(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
 	request := GenerateContentRequest{
 		Contents: contents,
 		GenerationConfig: GenerationConfig{
@@ -357,7 +369,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
 
 	// log.Debug(string(byteRequestBody))
 
-	reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
+	reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort")
 	if reasoningEffortResult.String() == "none" {
 		byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
 		byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
@@ -373,17 +385,17 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
 		byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
 	}
 
-	temperatureResult := gjson.GetBytes(rawJson, "temperature")
+	temperatureResult := gjson.GetBytes(rawJSON, "temperature")
 	if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
 		byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
 	}
 
-	topPResult := gjson.GetBytes(rawJson, "top_p")
+	topPResult := gjson.GetBytes(rawJSON, "top_p")
 	if topPResult.Exists() && topPResult.Type == gjson.Number {
 		byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
 	}
 
-	topKResult := gjson.GetBytes(rawJson, "top_k")
+	topKResult := gjson.GetBytes(rawJSON, "top_k")
 	if topKResult.Exists() && topKResult.Type == gjson.Number {
 		byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
 	}
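The reasoning_effort handling above maps an OpenAI/Claude-style effort setting onto Gemini's thinkingConfig: "none" removes include_thoughts and pins thinkingBudget to 0, while the branch continued in the hunk sets a budget of -1 (dynamic thinking). A hedged, self-contained sketch of that translation step using gjson/sjson; the helper name and the choice of -1 for every non-"none" value are illustrative, not the repository's exact mapping:

package main

import (
	"fmt"

	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

// applyReasoningEffort is an illustrative helper (not from the repository):
// it copies the "none" handling shown in the diff and otherwise leaves the
// thinking budget dynamic (-1).
func applyReasoningEffort(requestBody, incoming []byte) []byte {
	effort := gjson.GetBytes(incoming, "reasoning_effort").String()
	if effort == "none" {
		requestBody, _ = sjson.DeleteBytes(requestBody, "request.generationConfig.thinkingConfig.include_thoughts")
		requestBody, _ = sjson.SetBytes(requestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
		return requestBody
	}
	requestBody, _ = sjson.SetBytes(requestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
	return requestBody
}

func main() {
	body := []byte(`{"request":{"generationConfig":{"thinkingConfig":{"include_thoughts":true,"thinkingBudget":1024}}}}`)
	incoming := []byte(`{"reasoning_effort":"none"}`)
	fmt.Println(string(applyReasoningEffort(body, incoming)))
}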
@@ -430,7 +442,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
 // This function implements a sophisticated streaming system that supports tool calls, reasoning modes,
 // quota management, and automatic model fallback. It returns two channels for asynchronous communication:
 // one for streaming response data and another for error handling.
-func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) {
+func (c *Client) SendMessageStream(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) {
 	// Define the data prefix used in Server-Sent Events streaming format
 	dataTag := []byte("data: ")
 
@@ -486,7 +498,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
 
 	// Parse and configure reasoning effort levels from the original request
 	// This maps Claude-style reasoning effort parameters to Gemini's thinking budget system
-	reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
+	reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort")
 	if reasoningEffortResult.String() == "none" {
 		// Disable thinking entirely for fastest responses
 		byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
@@ -510,21 +522,21 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
 
 	// Configure temperature parameter for response randomness control
 	// Temperature affects the creativity vs consistency trade-off in responses
-	temperatureResult := gjson.GetBytes(rawJson, "temperature")
+	temperatureResult := gjson.GetBytes(rawJSON, "temperature")
 	if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
 		byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
 	}
 
 	// Configure top-p parameter for nucleus sampling
 	// Controls the cumulative probability threshold for token selection
-	topPResult := gjson.GetBytes(rawJson, "top_p")
+	topPResult := gjson.GetBytes(rawJSON, "top_p")
 	if topPResult.Exists() && topPResult.Type == gjson.Number {
 		byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
 	}
 
 	// Configure top-k parameter for limiting token candidates
 	// Restricts the model to consider only the top K most likely tokens
-	topKResult := gjson.GetBytes(rawJson, "top_k")
+	topKResult := gjson.GetBytes(rawJSON, "top_k")
 	if topKResult.Exists() && topKResult.Type == gjson.Number {
 		byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
 	}
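Temperature, top_p, and top_k are simply forwarded from the incoming request into Gemini's generationConfig whenever they are numeric. A minimal illustration of the resulting JSON shape; the paths match the diff, the values are made up:

package main

import (
	"fmt"

	"github.com/tidwall/sjson"
)

func main() {
	// Start from an empty wrapper and set the same paths the client uses.
	body := []byte(`{"request":{"generationConfig":{}}}`)
	body, _ = sjson.SetBytes(body, "request.generationConfig.temperature", 0.7)
	body, _ = sjson.SetBytes(body, "request.generationConfig.topP", 0.95)
	body, _ = sjson.SetBytes(body, "request.generationConfig.topK", 40)

	// {"request":{"generationConfig":{"temperature":0.7,"topP":0.95,"topK":40}}}
	fmt.Println(string(body))
}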
@@ -608,8 +620,8 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
 }
 
 // SendRawTokenCount handles a token count.
-func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) {
-	modelResult := gjson.GetBytes(rawJson, "model")
+func (c *Client) SendRawTokenCount(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) {
+	modelResult := gjson.GetBytes(rawJSON, "model")
 	model := modelResult.String()
 	modelName := model
 	for {
@@ -618,7 +630,7 @@ func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt stri
 			modelName = c.getPreviewModel(model)
 			if modelName != "" {
 				log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
-				rawJson, _ = sjson.SetBytes(rawJson, "model", modelName)
+				rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
 				continue
 			}
 		}
@@ -628,7 +640,7 @@ func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt stri
 		}
 	}
 
-	respBody, err := c.APIRequest(ctx, "countTokens", rawJson, alt, false)
+	respBody, err := c.APIRequest(ctx, "countTokens", rawJSON, alt, false)
 	if err != nil {
 		if err.StatusCode == 429 {
 			now := time.Now()
@@ -649,12 +661,12 @@ func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt stri
 }
 
 // SendRawMessage handles a single conversational turn, including tool calls.
-func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) {
+func (c *Client) SendRawMessage(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) {
 	if c.glAPIKey == "" {
-		rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
+		rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID())
 	}
 
-	modelResult := gjson.GetBytes(rawJson, "model")
+	modelResult := gjson.GetBytes(rawJSON, "model")
 	model := modelResult.String()
 	modelName := model
 	for {
@@ -663,7 +675,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string)
 			modelName = c.getPreviewModel(model)
 			if modelName != "" {
 				log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
-				rawJson, _ = sjson.SetBytes(rawJson, "model", modelName)
+				rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
 				continue
 			}
 		}
@@ -673,7 +685,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string)
 		}
 	}
 
-	respBody, err := c.APIRequest(ctx, "generateContent", rawJson, alt, false)
+	respBody, err := c.APIRequest(ctx, "generateContent", rawJSON, alt, false)
 	if err != nil {
 		if err.StatusCode == 429 {
 			now := time.Now()
@@ -694,7 +706,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string)
 }
 
 // SendRawMessageStream handles a single conversational turn, including tool calls.
-func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) {
+func (c *Client) SendRawMessageStream(ctx context.Context, rawJSON []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) {
 	dataTag := []byte("data: ")
 	errChan := make(chan *ErrorMessage)
 	dataChan := make(chan []byte)
@@ -703,10 +715,10 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt s
 		defer close(dataChan)
 
 		if c.glAPIKey == "" {
-			rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
+			rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID())
 		}
 
-		modelResult := gjson.GetBytes(rawJson, "model")
+		modelResult := gjson.GetBytes(rawJSON, "model")
 		model := modelResult.String()
 		modelName := model
 		var stream io.ReadCloser
@@ -716,7 +728,7 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt s
 			modelName = c.getPreviewModel(model)
 			if modelName != "" {
 				log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
-				rawJson, _ = sjson.SetBytes(rawJson, "model", modelName)
+				rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
 				continue
 			}
 		}
@@ -727,7 +739,7 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt s
 			return
 		}
 		var err *ErrorMessage
-		stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, alt, true)
+		stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJSON, alt, true)
 		if err != nil {
 			if err.StatusCode == 429 {
 				now := time.Now()
@@ -774,6 +786,8 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt s
 	return dataChan, errChan
 }
 
+// isModelQuotaExceeded checks if the specified model has exceeded its quota
+// within the last 30 minutes.
 func (c *Client) isModelQuotaExceeded(model string) bool {
 	if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
 		duration := time.Now().Sub(*lastExceededTime)
@@ -785,6 +799,8 @@ func (c *Client) isModelQuotaExceeded(model string) bool {
 	return false
 }
 
+// getPreviewModel returns an available preview model for the given base model,
+// or an empty string if no preview models are available or all are quota exceeded.
 func (c *Client) getPreviewModel(model string) string {
 	if models, hasKey := previewModels[model]; hasKey {
 		for i := 0; i < len(models); i++ {
@@ -796,6 +812,8 @@ func (c *Client) getPreviewModel(model string) string {
 	return ""
 }
 
+// IsModelQuotaExceeded returns true if the specified model has exceeded its quota
+// and no fallback options are available.
 func (c *Client) IsModelQuotaExceeded(model string) bool {
 	if c.isModelQuotaExceeded(model) {
 		if c.cfg.QuotaExceeded.SwitchPreviewModel {
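The quota helpers above implement a simple fallback chain: a model counts as exhausted if a 429 was recorded for it within the last 30 minutes, and getPreviewModel then looks for a sibling preview model that is not itself exhausted. A small self-contained sketch of that check; the 30-minute window comes from the diff, while the model names and the in-memory maps here are purely illustrative stand-ins for the client's internal state:

package main

import (
	"fmt"
	"time"
)

// Illustrative stand-ins; the real client keeps equivalent state on the struct.
var previewModels = map[string][]string{
	"example-base-model": {"example-preview-a", "example-preview-b"},
}

var quotaExceededAt = map[string]time.Time{}

func isModelQuotaExceeded(model string) bool {
	if t, ok := quotaExceededAt[model]; ok {
		return time.Since(t) <= 30*time.Minute
	}
	return false
}

func getPreviewModel(model string) string {
	for _, m := range previewModels[model] {
		if !isModelQuotaExceeded(m) {
			return m
		}
	}
	return ""
}

func main() {
	quotaExceededAt["example-base-model"] = time.Now() // pretend we just got a 429
	if isModelQuotaExceeded("example-base-model") {
		fmt.Println("falling back to:", getPreviewModel("example-base-model"))
	}
}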
@@ -824,20 +842,20 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
 	if err != nil {
 		// If a 403 Forbidden error occurs, it likely means the API is not enabled.
 		if err.StatusCode == 403 {
-			errJson := err.Error.Error()
+			errJSON := err.Error.Error()
 			// Check for a specific error code and extract the activation URL.
-			if gjson.Get(errJson, "error.code").Int() == 403 {
+			if gjson.Get(errJSON, "error.code").Int() == 403 {
-				activationUrl := gjson.Get(errJson, "error.details.0.metadata.activationUrl").String()
+				activationURL := gjson.Get(errJSON, "error.details.0.metadata.activationUrl").String()
-				if activationUrl != "" {
+				if activationURL != "" {
 					log.Warnf(
 						"\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s",
-						activationUrl,
+						activationURL,
 						os.Args[0],
 						c.tokenStorage.ProjectID,
 					)
 				}
 			}
-			log.Warnf("\n\nPlease copy this message and create an issue.\n\n%s\n\n", errJson)
+			log.Warnf("\n\nPlease copy this message and create an issue.\n\n%s\n\n", errJSON)
 			return false, nil
 		}
 		return false, err.Error
@@ -1,3 +1,6 @@
+// Package cmd provides command-line interface functionality for the CLI Proxy API.
+// It implements the main application commands including login/authentication
+// and server startup, handling the complete user onboarding and service lifecycle.
 package cmd
 
 import (
@@ -1,3 +1,8 @@
+// Package cmd provides the main service execution functionality for the CLIProxyAPI.
+// It contains the core logic for starting and managing the API proxy service,
+// including authentication client management, server initialization, and graceful shutdown handling.
+// The package handles loading authentication tokens, creating client pools, starting the API server,
+// and monitoring configuration changes through file watchers.
 package cmd
 
 import (
@@ -1,3 +1,7 @@
+// Package config provides configuration management for the CLI Proxy API server.
+// It handles loading and parsing YAML configuration files, and provides structured
+// access to application settings including server port, authentication directory,
+// debug settings, proxy configuration, and API keys.
 package config
 
 import (
@@ -14,17 +18,19 @@ type Config struct {
 	AuthDir string `yaml:"auth-dir"`
 	// Debug enables or disables debug-level logging and other debug features.
 	Debug bool `yaml:"debug"`
-	// ProxyUrl is the URL of an optional proxy server to use for outbound requests.
+	// ProxyURL is the URL of an optional proxy server to use for outbound requests.
-	ProxyUrl string `yaml:"proxy-url"`
+	ProxyURL string `yaml:"proxy-url"`
-	// ApiKeys is a list of keys for authenticating clients to this proxy server.
+	// APIKeys is a list of keys for authenticating clients to this proxy server.
-	ApiKeys []string `yaml:"api-keys"`
+	APIKeys []string `yaml:"api-keys"`
 	// QuotaExceeded defines the behavior when a quota is exceeded.
-	QuotaExceeded ConfigQuotaExceeded `yaml:"quota-exceeded"`
+	QuotaExceeded QuotaExceeded `yaml:"quota-exceeded"`
 	// GlAPIKey is the API key for the generative language API.
 	GlAPIKey []string `yaml:"generative-language-api-key"`
 }
 
-type ConfigQuotaExceeded struct {
+// QuotaExceeded defines the behavior when API quota limits are exceeded.
+// It provides configuration options for automatic failover mechanisms.
+type QuotaExceeded struct {
 	// SwitchProject indicates whether to automatically switch to another project when a quota is exceeded.
 	SwitchProject bool `yaml:"switch-project"`
 	// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
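Putting the renamed fields together: only the Go field names changed, the YAML keys stay the same. Below is a hedged sketch that unmarshals an example file into a locally defined mirror of the struct. The mirror is trimmed (it is not the project's full Config), the values are placeholders, and the switch-preview-model key name is inferred from the field name since its tag is not shown in the hunk:

package main

import (
	"fmt"

	"gopkg.in/yaml.v3"
)

// Trimmed, illustrative mirror of the configuration shown in the diff.
type QuotaExceeded struct {
	SwitchProject      bool `yaml:"switch-project"`
	SwitchPreviewModel bool `yaml:"switch-preview-model"` // key name assumed
}

type Config struct {
	AuthDir       string        `yaml:"auth-dir"`
	Debug         bool          `yaml:"debug"`
	ProxyURL      string        `yaml:"proxy-url"`
	APIKeys       []string      `yaml:"api-keys"`
	QuotaExceeded QuotaExceeded `yaml:"quota-exceeded"`
	GlAPIKey      []string      `yaml:"generative-language-api-key"`
}

const example = `
auth-dir: ~/.cli-proxy-api
debug: false
proxy-url: socks5://user:pass@127.0.0.1:1080
api-keys:
  - my-proxy-key
quota-exceeded:
  switch-project: true
  switch-preview-model: true
generative-language-api-key:
  - example-api-key
`

func main() {
	var cfg Config
	if err := yaml.Unmarshal([]byte(example), &cfg); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", cfg)
}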
@@ -1,3 +1,6 @@
+// Package util provides utility functions for the CLI Proxy API server.
+// It includes helper functions for proxy configuration, HTTP client setup,
+// and other common operations used across the application.
 package util
 
 import (
@@ -9,9 +12,12 @@ import (
 	"net/url"
 )
 
+// SetProxy configures the provided HTTP client with proxy settings from the configuration.
+// It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport
+// to route requests through the configured proxy server.
 func SetProxy(cfg *config.Config, httpClient *http.Client) (*http.Client, error) {
 	var transport *http.Transport
-	proxyURL, errParse := url.Parse(cfg.ProxyUrl)
+	proxyURL, errParse := url.Parse(cfg.ProxyURL)
 	if errParse == nil {
 		if proxyURL.Scheme == "socks5" {
 			username := proxyURL.User.Username()
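For the SOCKS5 branch of SetProxy, the usual Go approach is golang.org/x/net/proxy: build a SOCKS5 dialer from the parsed URL (including any user info) and wire it into the transport, while plain http/https proxies can go through http.ProxyURL. A sketch of that general pattern under those assumptions, not the repository's exact implementation; the helper name and proxy address are placeholders:

package main

import (
	"context"
	"net"
	"net/http"
	"net/url"

	"golang.org/x/net/proxy"
)

// buildProxyTransport is an illustrative helper: it returns a transport that
// routes traffic through the proxy described by rawURL.
func buildProxyTransport(rawURL string) (*http.Transport, error) {
	u, err := url.Parse(rawURL)
	if err != nil {
		return nil, err
	}

	if u.Scheme == "socks5" {
		var auth *proxy.Auth
		if u.User != nil {
			password, _ := u.User.Password()
			auth = &proxy.Auth{User: u.User.Username(), Password: password}
		}
		dialer, err := proxy.SOCKS5("tcp", u.Host, auth, proxy.Direct)
		if err != nil {
			return nil, err
		}
		return &http.Transport{
			// Route every connection through the SOCKS5 dialer.
			DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
				return dialer.Dial(network, addr)
			},
		}, nil
	}

	// http / https proxies use the standard Proxy hook.
	return &http.Transport{Proxy: http.ProxyURL(u)}, nil
}

func main() {
	transport, err := buildProxyTransport("socks5://user:pass@127.0.0.1:1080")
	if err != nil {
		panic(err)
	}
	_ = &http.Client{Transport: transport}
}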
@@ -1,3 +1,7 @@
+// Package watcher provides file system monitoring functionality for the CLI Proxy API.
+// It watches configuration files and authentication directories for changes,
+// automatically reloading clients and configuration when files are modified.
+// The package handles cross-platform file system events and supports hot-reloading.
 package watcher
 
 import (
@@ -156,11 +160,11 @@ func (w *Watcher) reloadConfig() {
 	if oldConfig.Debug != newConfig.Debug {
 		log.Debugf(" debug: %t -> %t", oldConfig.Debug, newConfig.Debug)
 	}
-	if oldConfig.ProxyUrl != newConfig.ProxyUrl {
+	if oldConfig.ProxyURL != newConfig.ProxyURL {
-		log.Debugf(" proxy-url: %s -> %s", oldConfig.ProxyUrl, newConfig.ProxyUrl)
+		log.Debugf(" proxy-url: %s -> %s", oldConfig.ProxyURL, newConfig.ProxyURL)
 	}
-	if len(oldConfig.ApiKeys) != len(newConfig.ApiKeys) {
+	if len(oldConfig.APIKeys) != len(newConfig.APIKeys) {
-		log.Debugf(" api-keys count: %d -> %d", len(oldConfig.ApiKeys), len(newConfig.ApiKeys))
+		log.Debugf(" api-keys count: %d -> %d", len(oldConfig.APIKeys), len(newConfig.APIKeys))
 	}
 	if len(oldConfig.GlAPIKey) != len(newConfig.GlAPIKey) {
 		log.Debugf(" generative-language-api-key count: %d -> %d", len(oldConfig.GlAPIKey), len(newConfig.GlAPIKey))
@@ -248,7 +252,7 @@ func (w *Watcher) reloadClients() {
 	log.Debugf("auth directory scan complete - found %d .json files, %d successful authentications", authFileCount, successfulAuthCount)
 
 	// Add clients for Generative Language API keys if configured
-	glApiKeyCount := 0
+	glAPIKeyCount := 0
 	if len(cfg.GlAPIKey) > 0 {
 		log.Debugf("processing %d Generative Language API keys", len(cfg.GlAPIKey))
 		for i := 0; i < len(cfg.GlAPIKey); i++ {
@@ -261,9 +265,9 @@ func (w *Watcher) reloadClients() {
 			log.Debugf(" initializing with Generative Language API key %d...", i+1)
 			cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
 			newClients = append(newClients, cliClient)
-			glApiKeyCount++
+			glAPIKeyCount++
 		}
-		log.Debugf("successfully initialized %d Generative Language API key clients", glApiKeyCount)
+		log.Debugf("successfully initialized %d Generative Language API key clients", glAPIKeyCount)
 	}
 
 	// Update the client list
@@ -272,7 +276,7 @@ func (w *Watcher) reloadClients() {
 	w.clientsMutex.Unlock()
 
 	log.Infof("client reload complete - old: %d clients, new: %d clients (%d auth files + %d GL API keys)",
-		oldClientCount, len(newClients), successfulAuthCount, glApiKeyCount)
+		oldClientCount, len(newClients), successfulAuthCount, glAPIKeyCount)
 
 	// Trigger the callback to update the server
 	if w.reloadCallback != nil {