From ad943b2d4d8e46a181705986e5145c893baf8a66 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Thu, 4 Sep 2025 02:39:56 +0800 Subject: [PATCH] Add reverse mappings for original tool names and improve error logging - Introduced reverse mapping logic for tool names in translators to restore original names when shortened. - Enhanced error handling by logging API response errors consistently across handlers. - Refactored request and response loggers to include API error details, improving debugging capabilities. - Integrated robust tool name shortening and uniqueness mechanisms for OpenAI, Gemini, and Claude requests. - Improved handler retry logic to properly capture and respond to errors. --- internal/api/handlers/claude/code_handlers.go | 13 +- .../handlers/gemini/gemini-cli_handlers.go | 24 +++- .../api/handlers/gemini/gemini_handlers.go | 23 +++- internal/api/handlers/handlers.go | 16 +++ .../api/handlers/openai/openai_handlers.go | 47 ++++++- .../openai/openai_responses_handlers.go | 24 +++- internal/api/middleware/response_writer.go | 12 ++ internal/logging/request_logger.go | 17 ++- .../codex/claude/codex_claude_request.go | 127 +++++++++++++++++- .../codex/claude/codex_claude_response.go | 34 ++++- .../codex/gemini/codex_gemini_request.go | 111 ++++++++++++++- .../codex/gemini/codex_gemini_response.go | 48 ++++++- .../chat-completions/codex_openai_request.go | 124 ++++++++++++++++- .../chat-completions/codex_openai_response.go | 49 ++++++- 14 files changed, 644 insertions(+), 25 deletions(-) diff --git a/internal/api/handlers/claude/code_handlers.go b/internal/api/handlers/claude/code_handlers.go index 75ce2c81..c3004483 100644 --- a/internal/api/handlers/claude/code_handlers.go +++ b/internal/api/handlers/claude/code_handlers.go @@ -139,12 +139,13 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [ } } }() + + var errorResponse *interfaces.ErrorMessage retryCount := 0 // Main client rotation loop with quota management // This 
loop implements a sophisticated load balancing and failover mechanism outLoop: for retryCount <= h.Cfg.RequestRetry { - var errorResponse *interfaces.ErrorMessage cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) @@ -185,6 +186,8 @@ outLoop: // This manages various error conditions and implements retry logic case errInfo, okError := <-errChan: if okError { + errorResponse = errInfo + h.LoggingAPIResponseError(cliCtx, errInfo) // Special handling for quota exceeded errors // If configured, attempt to switch to a different project/client switch errInfo.StatusCode { @@ -221,4 +224,12 @@ outLoop: } } } + + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error()) + flusher.Flush() + cliCancel(errorResponse.Error) + return + } } diff --git a/internal/api/handlers/gemini/gemini-cli_handlers.go b/internal/api/handlers/gemini/gemini-cli_handlers.go index 2537bcad..e105d4b1 100644 --- a/internal/api/handlers/gemini/gemini-cli_handlers.go +++ b/internal/api/handlers/gemini/gemini-cli_handlers.go @@ -169,10 +169,10 @@ func (h *GeminiCLIAPIHandler) handleInternalStreamGenerateContent(c *gin.Context } }() + var errorResponse *interfaces.ErrorMessage retryCount := 0 outLoop: for retryCount <= h.Cfg.RequestRetry { - var errorResponse *interfaces.ErrorMessage cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) @@ -208,6 +208,9 @@ outLoop: // Handle errors from the backend. 
case err, okError := <-errChan: if okError { + errorResponse = err + h.LoggingAPIResponseError(cliCtx, err) + switch err.StatusCode { case 429: if h.Cfg.QuotaExceeded.SwitchProject { @@ -232,6 +235,13 @@ outLoop: } } } + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error()) + flusher.Flush() + cliCancel(errorResponse.Error) + return + } } // handleInternalGenerateContent handles non-streaming content generation requests. @@ -252,9 +262,9 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ } }() + var errorResponse *interfaces.ErrorMessage retryCount := 0 for retryCount <= h.Cfg.RequestRetry { - var errorResponse *interfaces.ErrorMessage cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) @@ -265,6 +275,9 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ resp, err := cliClient.SendRawMessage(cliCtx, modelName, rawJSON, "") if err != nil { + errorResponse = err + h.LoggingAPIResponseError(cliCtx, err) + switch err.StatusCode { case 429: if h.Cfg.QuotaExceeded.SwitchProject { @@ -296,4 +309,11 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ break } } + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = c.Writer.Write([]byte(errorResponse.Error.Error())) + cliCancel(errorResponse.Error) + return + } + } diff --git a/internal/api/handlers/gemini/gemini_handlers.go b/internal/api/handlers/gemini/gemini_handlers.go index 6110576c..132f73bc 100644 --- a/internal/api/handlers/gemini/gemini_handlers.go +++ b/internal/api/handlers/gemini/gemini_handlers.go @@ -221,10 +221,10 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName } }() + var errorResponse *interfaces.ErrorMessage retryCount := 0 outLoop: for retryCount <= h.Cfg.RequestRetry { - var errorResponse *interfaces.ErrorMessage cliClient, 
errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) @@ -263,6 +263,9 @@ outLoop: // Handle errors from the backend. case err, okError := <-errChan: if okError { + errorResponse = err + h.LoggingAPIResponseError(cliCtx, err) + switch err.StatusCode { case 429: if h.Cfg.QuotaExceeded.SwitchProject { @@ -287,6 +290,13 @@ outLoop: } } } + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error()) + flusher.Flush() + cliCancel(errorResponse.Error) + return + } } // handleCountTokens handles token counting requests for Gemini models. @@ -365,9 +375,9 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin } }() + var errorResponse *interfaces.ErrorMessage retryCount := 0 for retryCount <= h.Cfg.RequestRetry { - var errorResponse *interfaces.ErrorMessage cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) @@ -378,6 +388,9 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin resp, err := cliClient.SendRawMessage(cliCtx, modelName, rawJSON, alt) if err != nil { + errorResponse = err + h.LoggingAPIResponseError(cliCtx, err) + switch err.StatusCode { case 429: if h.Cfg.QuotaExceeded.SwitchProject { @@ -409,4 +422,10 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin break } } + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = c.Writer.Write([]byte(errorResponse.Error.Error())) + cliCancel(errorResponse.Error) + return + } } diff --git a/internal/api/handlers/handlers.go b/internal/api/handlers/handlers.go index 8fe28ad7..2eef87ac 100644 --- a/internal/api/handlers/handlers.go +++ b/internal/api/handlers/handlers.go @@ -235,6 +235,22 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c * } } +func (h *BaseAPIHandler) LoggingAPIResponseError(ctx context.Context, err 
*interfaces.ErrorMessage) { + if h.Cfg.RequestLog { + if ginContext, ok := ctx.Value("gin").(*gin.Context); ok { + if apiResponseErrors, isExist := ginContext.Get("API_RESPONSE_ERROR"); isExist { + if slicesAPIResponseError, isOk := apiResponseErrors.([]*interfaces.ErrorMessage); isOk { + slicesAPIResponseError = append(slicesAPIResponseError, err) + ginContext.Set("API_RESPONSE_ERROR", slicesAPIResponseError) + } + } else { + // Create new response data entry + ginContext.Set("API_RESPONSE_ERROR", []*interfaces.ErrorMessage{err}) + } + } + } +} + // APIHandlerCancelFunc is a function type for canceling an API handler's context. // It can optionally accept parameters, which are used for logging the response. type APIHandlerCancelFunc func(params ...interface{}) diff --git a/internal/api/handlers/openai/openai_handlers.go b/internal/api/handlers/openai/openai_handlers.go index 897f04f4..47a5e247 100644 --- a/internal/api/handlers/openai/openai_handlers.go +++ b/internal/api/handlers/openai/openai_handlers.go @@ -387,9 +387,9 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON [] } }() + var errorResponse *interfaces.ErrorMessage retryCount := 0 for retryCount <= h.Cfg.RequestRetry { - var errorResponse *interfaces.ErrorMessage cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) @@ -400,6 +400,9 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON [] resp, err := cliClient.SendRawMessage(cliCtx, modelName, rawJSON, "") if err != nil { + errorResponse = err + h.LoggingAPIResponseError(cliCtx, err) + switch err.StatusCode { case 429: if h.Cfg.QuotaExceeded.SwitchProject { @@ -431,6 +434,12 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON [] break } } + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = c.Writer.Write([]byte(errorResponse.Error.Error())) + cliCancel(errorResponse.Error) + return + } } // 
handleStreamingResponse handles streaming responses for Gemini models. @@ -471,10 +480,10 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt } }() + var errorResponse *interfaces.ErrorMessage retryCount := 0 outLoop: for retryCount <= h.Cfg.RequestRetry { - var errorResponse *interfaces.ErrorMessage cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) @@ -511,6 +520,9 @@ outLoop: // Handle errors from the backend. case err, okError := <-errChan: if okError { + errorResponse = err + h.LoggingAPIResponseError(cliCtx, err) + switch err.StatusCode { case 429: if h.Cfg.QuotaExceeded.SwitchProject { @@ -535,6 +547,13 @@ outLoop: } } } + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error()) + flusher.Flush() + cliCancel(errorResponse.Error) + return + } } // handleCompletionsNonStreamingResponse handles non-streaming completions responses. 
@@ -562,9 +581,9 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, } }() + var errorResponse *interfaces.ErrorMessage retryCount := 0 for retryCount <= h.Cfg.RequestRetry { - var errorResponse *interfaces.ErrorMessage cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) @@ -576,6 +595,9 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, // Send the converted chat completions request resp, err := cliClient.SendRawMessage(cliCtx, modelName, chatCompletionsJSON, "") if err != nil { + errorResponse = err + h.LoggingAPIResponseError(cliCtx, err) + switch err.StatusCode { case 429: if h.Cfg.QuotaExceeded.SwitchProject { @@ -601,6 +623,13 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context, break } } + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = c.Writer.Write([]byte(errorResponse.Error.Error())) + cliCancel(errorResponse.Error) + return + } + } // handleCompletionsStreamingResponse handles streaming completions responses. @@ -644,10 +673,10 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra } }() + var errorResponse *interfaces.ErrorMessage retryCount := 0 outLoop: for retryCount <= h.Cfg.RequestRetry { - var errorResponse *interfaces.ErrorMessage cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) @@ -689,6 +718,9 @@ outLoop: // Handle errors from the backend. 
case err, okError := <-errChan: if okError { + errorResponse = err + h.LoggingAPIResponseError(cliCtx, err) + switch err.StatusCode { case 429: if h.Cfg.QuotaExceeded.SwitchProject { @@ -713,4 +745,11 @@ outLoop: } } } + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error()) + flusher.Flush() + cliCancel(errorResponse.Error) + return + } } diff --git a/internal/api/handlers/openai/openai_responses_handlers.go b/internal/api/handlers/openai/openai_responses_handlers.go index 6d5f0e82..86ccca49 100644 --- a/internal/api/handlers/openai/openai_responses_handlers.go +++ b/internal/api/handlers/openai/openai_responses_handlers.go @@ -115,9 +115,9 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r } }() + var errorResponse *interfaces.ErrorMessage retryCount := 0 for retryCount <= h.Cfg.RequestRetry { - var errorResponse *interfaces.ErrorMessage cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) @@ -128,6 +128,9 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r resp, err := cliClient.SendRawMessage(cliCtx, modelName, rawJSON, "") if err != nil { + errorResponse = err + h.LoggingAPIResponseError(cliCtx, err) + switch err.StatusCode { case 429: if h.Cfg.QuotaExceeded.SwitchProject { @@ -159,6 +162,13 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r break } } + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = c.Writer.Write([]byte(errorResponse.Error.Error())) + cliCancel(errorResponse.Error) + return + } + } // handleStreamingResponse handles streaming responses for Gemini models. 
@@ -199,10 +209,10 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ } }() + var errorResponse *interfaces.ErrorMessage retryCount := 0 outLoop: for retryCount <= h.Cfg.RequestRetry { - var errorResponse *interfaces.ErrorMessage cliClient, errorResponse = h.GetClient(modelName) if errorResponse != nil { c.Status(errorResponse.StatusCode) @@ -238,6 +248,8 @@ outLoop: // Handle errors from the backend. case err, okError := <-errChan: if okError { + errorResponse = err + h.LoggingAPIResponseError(cliCtx, err) switch err.StatusCode { case 429: if h.Cfg.QuotaExceeded.SwitchProject { @@ -262,4 +274,12 @@ outLoop: } } } + + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error()) + flusher.Flush() + cliCancel(errorResponse.Error) + return + } } diff --git a/internal/api/middleware/response_writer.go b/internal/api/middleware/response_writer.go index d8068944..294c2d60 100644 --- a/internal/api/middleware/response_writer.go +++ b/internal/api/middleware/response_writer.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/gin-gonic/gin" + "github.com/luispater/CLIProxyAPI/internal/interfaces" "github.com/luispater/CLIProxyAPI/internal/logging" ) @@ -240,6 +241,16 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error { } } + var slicesAPIResponseError []*interfaces.ErrorMessage + apiResponseError, isExist := c.Get("API_RESPONSE_ERROR") + if isExist { + var ok bool + slicesAPIResponseError, ok = apiResponseError.([]*interfaces.ErrorMessage) + if !ok { + slicesAPIResponseError = nil + } + } + // Log complete non-streaming response return w.logger.LogRequest( w.requestInfo.URL, @@ -251,6 +262,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error { w.body.Bytes(), apiRequestBody, apiResponseBody, + slicesAPIResponseError, ) } diff --git a/internal/logging/request_logger.go b/internal/logging/request_logger.go index 4779c62a..f655ad9f 100644 --- 
a/internal/logging/request_logger.go +++ b/internal/logging/request_logger.go @@ -14,6 +14,8 @@ import ( "regexp" "strings" "time" + + "github.com/luispater/CLIProxyAPI/internal/interfaces" ) // RequestLogger defines the interface for logging HTTP requests and responses. @@ -34,7 +36,7 @@ type RequestLogger interface { // // Returns: // - error: An error if logging fails, nil otherwise - LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte) error + LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error // LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks. // @@ -139,7 +141,7 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) { // // Returns: // - error: An error if logging fails, nil otherwise -func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte) error { +func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error { if !l.enabled { return nil } @@ -161,7 +163,7 @@ func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[st } // Create log content - content := l.formatLogContent(url, method, requestHeaders, body, apiRequest, apiResponse, decompressedResponse, statusCode, responseHeaders) + content := l.formatLogContent(url, method, requestHeaders, body, apiRequest, apiResponse, decompressedResponse, statusCode, responseHeaders, apiResponseErrors) // Write to file if err = 
os.WriteFile(filePath, []byte(content), 0644); err != nil { @@ -310,7 +312,7 @@ func (l *FileRequestLogger) sanitizeForFilename(path string) string { // // Returns: // - string: The formatted log content -func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, apiRequest, apiResponse, response []byte, status int, responseHeaders map[string][]string) string { +func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body, apiRequest, apiResponse, response []byte, status int, responseHeaders map[string][]string, apiResponseErrors []*interfaces.ErrorMessage) string { var content strings.Builder // Request info @@ -320,6 +322,13 @@ func (l *FileRequestLogger) formatLogContent(url, method string, headers map[str content.Write(apiRequest) content.WriteString("\n\n") + for i := 0; i < len(apiResponseErrors); i++ { + content.WriteString("=== API ERROR RESPONSE ===\n") + content.WriteString(fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)) + content.WriteString(apiResponseErrors[i].Error.Error()) + content.WriteString("\n\n") + } + content.WriteString("=== API RESPONSE ===\n") content.Write(apiResponse) content.WriteString("\n\n") diff --git a/internal/translator/codex/claude/codex_claude_request.go b/internal/translator/codex/claude/codex_claude_request.go index 46597cff..b03a4b5e 100644 --- a/internal/translator/codex/claude/codex_claude_request.go +++ b/internal/translator/codex/claude/codex_claude_request.go @@ -8,6 +8,8 @@ package claude import ( "bytes" "fmt" + "strconv" + "strings" "github.com/luispater/CLIProxyAPI/internal/misc" "github.com/tidwall/gjson" @@ -94,7 +96,17 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) // Handle tool use content by creating function call message. 
functionCallMessage := `{"type":"function_call"}` functionCallMessage, _ = sjson.Set(functionCallMessage, "call_id", messageContentResult.Get("id").String()) - functionCallMessage, _ = sjson.Set(functionCallMessage, "name", messageContentResult.Get("name").String()) + { + // Shorten tool name if needed based on declared tools + name := messageContentResult.Get("name").String() + toolMap := buildReverseMapFromClaudeOriginalToShort(rawJSON) + if short, ok := toolMap[name]; ok { + name = short + } else { + name = shortenNameIfNeeded(name) + } + functionCallMessage, _ = sjson.Set(functionCallMessage, "name", name) + } functionCallMessage, _ = sjson.Set(functionCallMessage, "arguments", messageContentResult.Get("input").Raw) template, _ = sjson.SetRaw(template, "input.-1", functionCallMessage) } else if contentType == "tool_result" { @@ -130,10 +142,29 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) template, _ = sjson.SetRaw(template, "tools", `[]`) template, _ = sjson.Set(template, "tool_choice", `auto`) toolResults := toolsResult.Array() + // Build short name map from declared tools + var names []string + for i := 0; i < len(toolResults); i++ { + n := toolResults[i].Get("name").String() + if n != "" { + names = append(names, n) + } + } + shortMap := buildShortNameMap(names) for i := 0; i < len(toolResults); i++ { toolResult := toolResults[i] tool := toolResult.Raw tool, _ = sjson.Set(tool, "type", "function") + // Apply shortened name if needed + if v := toolResult.Get("name"); v.Exists() { + name := v.String() + if short, ok := shortMap[name]; ok { + name = short + } else { + name = shortenNameIfNeeded(name) + } + tool, _ = sjson.Set(tool, "name", name) + } tool, _ = sjson.SetRaw(tool, "parameters", toolResult.Get("input_schema").Raw) tool, _ = sjson.Delete(tool, "input_schema") tool, _ = sjson.Delete(tool, "parameters.$schema") @@ -170,3 +201,97 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) 
return []byte(template) } + +// shortenNameIfNeeded applies a simple shortening rule for a single name. +func shortenNameIfNeeded(name string) string { + const limit = 64 + if len(name) <= limit { + return name + } + if strings.HasPrefix(name, "mcp__") { + idx := strings.LastIndex(name, "__") + if idx > 0 { + cand := "mcp__" + name[idx+2:] + if len(cand) > limit { + return cand[:limit] + } + return cand + } + } + return name[:limit] +} + +// buildShortNameMap ensures uniqueness of shortened names within a request. +func buildShortNameMap(names []string) map[string]string { + const limit = 64 + used := map[string]struct{}{} + m := map[string]string{} + + baseCandidate := func(n string) string { + if len(n) <= limit { + return n + } + if strings.HasPrefix(n, "mcp__") { + idx := strings.LastIndex(n, "__") + if idx > 0 { + cand := "mcp__" + n[idx+2:] + if len(cand) > limit { + cand = cand[:limit] + } + return cand + } + } + return n[:limit] + } + + makeUnique := func(cand string) string { + if _, ok := used[cand]; !ok { + return cand + } + base := cand + for i := 1; ; i++ { + suffix := "~" + strconv.Itoa(i) + allowed := limit - len(suffix) + if allowed < 0 { + allowed = 0 + } + tmp := base + if len(tmp) > allowed { + tmp = tmp[:allowed] + } + tmp = tmp + suffix + if _, ok := used[tmp]; !ok { + return tmp + } + } + } + + for _, n := range names { + cand := baseCandidate(n) + uniq := makeUnique(cand) + used[uniq] = struct{}{} + m[n] = uniq + } + return m +} + +// buildReverseMapFromClaudeOriginalToShort builds original->short map, used to map tool_use names to short. 
+func buildReverseMapFromClaudeOriginalToShort(original []byte) map[string]string { + tools := gjson.GetBytes(original, "tools") + m := map[string]string{} + if !tools.IsArray() { + return m + } + var names []string + arr := tools.Array() + for i := 0; i < len(arr); i++ { + n := arr[i].Get("name").String() + if n != "" { + names = append(names, n) + } + } + if len(names) > 0 { + m = buildShortNameMap(names) + } + return m +} diff --git a/internal/translator/codex/claude/codex_claude_response.go b/internal/translator/codex/claude/codex_claude_response.go index 346dd4a5..704568e1 100644 --- a/internal/translator/codex/claude/codex_claude_response.go +++ b/internal/translator/codex/claude/codex_claude_response.go @@ -122,7 +122,15 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}` template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int()) template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String()) - template, _ = sjson.Set(template, "content_block.name", itemResult.Get("name").String()) + { + // Restore original tool name if shortened + name := itemResult.Get("name").String() + rev := buildReverseMapFromClaudeOriginalShortToOriginal(originalRequestRawJSON) + if orig, ok := rev[name]; ok { + name = orig + } + template, _ = sjson.Set(template, "content_block.name", name) + } output = "event: content_block_start\n" output += fmt.Sprintf("data: %s\n\n", template) @@ -171,3 +179,27 @@ func ConvertCodexResponseToClaude(_ context.Context, _ string, originalRequestRa func ConvertCodexResponseToClaudeNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, _ []byte, _ *any) string { return "" } + +// buildReverseMapFromClaudeOriginalShortToOriginal builds a map[short]original from original Claude request tools. 
+func buildReverseMapFromClaudeOriginalShortToOriginal(original []byte) map[string]string { + tools := gjson.GetBytes(original, "tools") + rev := map[string]string{} + if !tools.IsArray() { + return rev + } + var names []string + arr := tools.Array() + for i := 0; i < len(arr); i++ { + n := arr[i].Get("name").String() + if n != "" { + names = append(names, n) + } + } + if len(names) > 0 { + m := buildShortNameMap(names) + for orig, short := range m { + rev[short] = orig + } + } + return rev +} diff --git a/internal/translator/codex/gemini/codex_gemini_request.go b/internal/translator/codex/gemini/codex_gemini_request.go index 0dded5cb..bf5f9e8a 100644 --- a/internal/translator/codex/gemini/codex_gemini_request.go +++ b/internal/translator/codex/gemini/codex_gemini_request.go @@ -10,6 +10,7 @@ import ( "crypto/rand" "fmt" "math/big" + "strconv" "strings" "github.com/luispater/CLIProxyAPI/internal/misc" @@ -46,6 +47,27 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) root := gjson.ParseBytes(rawJSON) + // Pre-compute tool name shortening map from declared functionDeclarations + shortMap := map[string]string{} + if tools := root.Get("tools"); tools.IsArray() { + var names []string + tarr := tools.Array() + for i := 0; i < len(tarr); i++ { + fns := tarr[i].Get("functionDeclarations") + if !fns.IsArray() { + continue + } + for _, fn := range fns.Array() { + if v := fn.Get("name"); v.Exists() { + names = append(names, v.String()) + } + } + } + if len(names) > 0 { + shortMap = buildShortNameMap(names) + } + } + // helper for generating paired call IDs in the form: call_ // Gemini uses sequential pairing across possibly multiple in-flight // functionCalls, so we keep a FIFO queue of generated call IDs and @@ -124,7 +146,13 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) if fc := p.Get("functionCall"); fc.Exists() { fn := `{"type":"function_call"}` if name := fc.Get("name"); name.Exists() { - fn, _ = 
sjson.Set(fn, "name", name.String()) + n := name.String() + if short, ok := shortMap[n]; ok { + n = short + } else { + n = shortenNameIfNeeded(n) + } + fn, _ = sjson.Set(fn, "name", n) } if args := fc.Get("args"); args.Exists() { fn, _ = sjson.Set(fn, "arguments", args.Raw) @@ -185,7 +213,13 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) tool := `{}` tool, _ = sjson.Set(tool, "type", "function") if v := fn.Get("name"); v.Exists() { - tool, _ = sjson.Set(tool, "name", v.String()) + name := v.String() + if short, ok := shortMap[name]; ok { + name = short + } else { + name = shortenNameIfNeeded(name) + } + tool, _ = sjson.Set(tool, "name", name) } if v := fn.Get("description"); v.Exists() { tool, _ = sjson.Set(tool, "description", v.String()) @@ -227,3 +261,76 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) return []byte(out) } + +// shortenNameIfNeeded applies the simple shortening rule for a single name. +func shortenNameIfNeeded(name string) string { + const limit = 64 + if len(name) <= limit { + return name + } + if strings.HasPrefix(name, "mcp__") { + idx := strings.LastIndex(name, "__") + if idx > 0 { + cand := "mcp__" + name[idx+2:] + if len(cand) > limit { + return cand[:limit] + } + return cand + } + } + return name[:limit] +} + +// buildShortNameMap ensures uniqueness of shortened names within a request. 
+func buildShortNameMap(names []string) map[string]string { + const limit = 64 + used := map[string]struct{}{} + m := map[string]string{} + + baseCandidate := func(n string) string { + if len(n) <= limit { + return n + } + if strings.HasPrefix(n, "mcp__") { + idx := strings.LastIndex(n, "__") + if idx > 0 { + cand := "mcp__" + n[idx+2:] + if len(cand) > limit { + cand = cand[:limit] + } + return cand + } + } + return n[:limit] + } + + makeUnique := func(cand string) string { + if _, ok := used[cand]; !ok { + return cand + } + base := cand + for i := 1; ; i++ { + suffix := "~" + strconv.Itoa(i) + allowed := limit - len(suffix) + if allowed < 0 { + allowed = 0 + } + tmp := base + if len(tmp) > allowed { + tmp = tmp[:allowed] + } + tmp = tmp + suffix + if _, ok := used[tmp]; !ok { + return tmp + } + } + } + + for _, n := range names { + cand := baseCandidate(n) + uniq := makeUnique(cand) + used[uniq] = struct{}{} + m[n] = uniq + } + return m +} diff --git a/internal/translator/codex/gemini/codex_gemini_response.go b/internal/translator/codex/gemini/codex_gemini_response.go index f915ce4e..67559ac2 100644 --- a/internal/translator/codex/gemini/codex_gemini_response.go +++ b/internal/translator/codex/gemini/codex_gemini_response.go @@ -80,7 +80,15 @@ func ConvertCodexResponseToGemini(_ context.Context, modelName string, originalR if itemType == "function_call" { // Create function call part functionCall := `{"functionCall":{"name":"","args":{}}}` - functionCall, _ = sjson.Set(functionCall, "functionCall.name", itemResult.Get("name").String()) + { + // Restore original tool name if shortened + n := itemResult.Get("name").String() + rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON) + if orig, ok := rev[n]; ok { + n = orig + } + functionCall, _ = sjson.Set(functionCall, "functionCall.name", n) + } // Parse and set arguments argsStr := itemResult.Get("arguments").String() @@ -250,7 +258,14 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, 
modelName string, hasToolCall = true functionCall := map[string]interface{}{ "functionCall": map[string]interface{}{ - "name": value.Get("name").String(), + "name": func() string { + n := value.Get("name").String() + rev := buildReverseMapFromGeminiOriginal(originalRequestRawJSON) + if orig, ok := rev[n]; ok { + return orig + } + return n + }(), "args": map[string]interface{}{}, }, } @@ -292,6 +307,35 @@ func ConvertCodexResponseToGeminiNonStream(_ context.Context, modelName string, return "" } +// buildReverseMapFromGeminiOriginal builds a map[short]original from original Gemini request tools. +func buildReverseMapFromGeminiOriginal(original []byte) map[string]string { + tools := gjson.GetBytes(original, "tools") + rev := map[string]string{} + if !tools.IsArray() { + return rev + } + var names []string + tarr := tools.Array() + for i := 0; i < len(tarr); i++ { + fns := tarr[i].Get("functionDeclarations") + if !fns.IsArray() { + continue + } + for _, fn := range fns.Array() { + if v := fn.Get("name"); v.Exists() { + names = append(names, v.String()) + } + } + } + if len(names) > 0 { + m := buildShortNameMap(names) + for orig, short := range m { + rev[short] = orig + } + } + return rev +} + // mustMarshalJSON marshals a value to JSON, panicking on error. 
func mustMarshalJSON(v interface{}) string { data, err := json.Marshal(v) diff --git a/internal/translator/codex/openai/chat-completions/codex_openai_request.go b/internal/translator/codex/openai/chat-completions/codex_openai_request.go index fb098471..de493cbc 100644 --- a/internal/translator/codex/openai/chat-completions/codex_openai_request.go +++ b/internal/translator/codex/openai/chat-completions/codex_openai_request.go @@ -9,6 +9,9 @@ package chat_completions import ( "bytes" + "strconv" + "strings" + "github.com/luispater/CLIProxyAPI/internal/misc" "github.com/tidwall/gjson" "github.com/tidwall/sjson" @@ -67,6 +70,31 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b // Model out, _ = sjson.Set(out, "model", modelName) + // Build tool name shortening map from original tools (if any) + originalToolNameMap := map[string]string{} + { + tools := gjson.GetBytes(rawJSON, "tools") + if tools.IsArray() && len(tools.Array()) > 0 { + // Collect original tool names + var names []string + arr := tools.Array() + for i := 0; i < len(arr); i++ { + t := arr[i] + if t.Get("type").String() == "function" { + fn := t.Get("function") + if fn.Exists() { + if v := fn.Get("name"); v.Exists() { + names = append(names, v.String()) + } + } + } + } + if len(names) > 0 { + originalToolNameMap = buildShortNameMap(names) + } + } + } + // Extract system instructions from first system message (string or text object) messages := gjson.GetBytes(rawJSON, "messages") instructions := misc.CodexInstructions @@ -177,7 +205,15 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b funcCall := `{}` funcCall, _ = sjson.Set(funcCall, "type", "function_call") funcCall, _ = sjson.Set(funcCall, "call_id", tc.Get("id").String()) - funcCall, _ = sjson.Set(funcCall, "name", tc.Get("function.name").String()) + { + name := tc.Get("function.name").String() + if short, ok := originalToolNameMap[name]; ok { + name = short + } else { + name = 
// shortenNameIfNeeded applies the simple shortening rule for a single name.
//
// A name of at most 64 bytes is returned unchanged. A longer name that starts
// with "mcp__" is reduced to "mcp__" plus the text after its last "__"; if that
// is still too long, or the name has no "mcp__" prefix, the result is cut to at
// most 64 bytes. The cut is moved back to a UTF-8 character boundary so the
// returned string never ends in a partial multi-byte sequence.
func shortenNameIfNeeded(name string) string {
	const limit = 64
	if len(name) <= limit {
		return name
	}

	short := name
	if strings.HasPrefix(name, "mcp__") {
		// Keep the prefix and the last segment after "__".
		if idx := strings.LastIndex(name, "__"); idx > 0 {
			short = "mcp__" + name[idx+2:]
		}
	}
	if len(short) <= limit {
		return short
	}

	// Back up over continuation bytes (10xxxxxx) so no character is split.
	cut := limit
	for cut > 0 && short[cut]&0xC0 == 0x80 {
		cut--
	}
	return short[:cut]
}

// buildShortNameMap generates unique short names (<=64) for the given list of names.
// It preserves the "mcp__" prefix with the last segment when possible and ensures uniqueness
// by appending suffixes like "~1", "~2" if needed.
// buildShortNameMap returns a mapping from every distinct name in names to a
// unique name of at most 64 bytes.
//
// Names that already fit are mapped to themselves and are reserved before any
// shortening happens, so a collision can never rename a tool that needed no
// shortening. A longer name that starts with "mcp__" is reduced to "mcp__" plus
// the text after its last "__"; any other long name is truncated. When the
// shortened form is already taken, "~1", "~2", ... is appended (trimming the
// base so the result still fits). A name listed more than once is mapped once.
// Truncation never splits a multi-byte UTF-8 sequence.
func buildShortNameMap(names []string) map[string]string {
	const limit = 64

	used := make(map[string]struct{}, len(names))
	short := make(map[string]string, len(names))

	// clip cuts s down to at most size bytes, backing up over UTF-8
	// continuation bytes (10xxxxxx) so a character is never cut in half.
	clip := func(s string, size int) string {
		if size < 0 {
			size = 0
		}
		if len(s) <= size {
			return s
		}
		for size > 0 && s[size]&0xC0 == 0x80 {
			size--
		}
		return s[:size]
	}

	// candidate produces the preferred short form of a name longer than limit.
	candidate := func(n string) string {
		if strings.HasPrefix(n, "mcp__") {
			if idx := strings.LastIndex(n, "__"); idx > 0 {
				return clip("mcp__"+n[idx+2:], limit)
			}
		}
		return clip(n, limit)
	}

	// unique returns cand itself when it is free, otherwise the first
	// "<trimmed cand>~N" that is not yet in use.
	unique := func(cand string) string {
		if _, taken := used[cand]; !taken {
			return cand
		}
		for i := 1; ; i++ {
			suffix := "~" + strconv.Itoa(i)
			name := clip(cand, limit-len(suffix)) + suffix
			if _, taken := used[name]; !taken {
				return name
			}
		}
	}

	// Pass 1: names that fit keep their exact spelling and claim it up front.
	for _, n := range names {
		if len(n) <= limit {
			used[n] = struct{}{}
			short[n] = n
		}
	}

	// Pass 2: only over-long names remain; repeats were mapped the first time.
	for _, n := range names {
		if _, done := short[n]; done {
			continue
		}
		name := unique(candidate(n))
		used[name] = struct{}{}
		short[n] = name
	}
	return short
}
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", name) + } functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String()) template, _ = sjson.Set(template, "choices.0.delta.role", "assistant") template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate) @@ -244,7 +253,12 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original } if nameResult := outputItem.Get("name"); nameResult.Exists() { - functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", nameResult.String()) + n := nameResult.String() + rev := buildReverseMapFromOriginalOpenAI(originalRequestRawJSON) + if orig, ok := rev[n]; ok { + n = orig + } + functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", n) } if argsResult := outputItem.Get("arguments"); argsResult.Exists() { @@ -289,3 +303,34 @@ func ConvertCodexResponseToOpenAINonStream(_ context.Context, _ string, original } return "" } + +// buildReverseMapFromOriginalOpenAI builds a map of shortened tool name -> original tool name +// from the original OpenAI-style request JSON using the same shortening logic. +func buildReverseMapFromOriginalOpenAI(original []byte) map[string]string { + tools := gjson.GetBytes(original, "tools") + rev := map[string]string{} + if tools.IsArray() && len(tools.Array()) > 0 { + var names []string + arr := tools.Array() + for i := 0; i < len(arr); i++ { + t := arr[i] + if t.Get("type").String() != "function" { + continue + } + fn := t.Get("function") + if !fn.Exists() { + continue + } + if v := fn.Get("name"); v.Exists() { + names = append(names, v.String()) + } + } + if len(names) > 0 { + m := buildShortNameMap(names) + for orig, short := range m { + rev[short] = orig + } + } + } + return rev +}