diff --git a/internal/api/gemini-handlers.go b/internal/api/gemini-handlers.go index c6fa8226..2b62bf09 100644 --- a/internal/api/gemini-handlers.go +++ b/internal/api/gemini-handlers.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "github.com/gin-gonic/gin" + "github.com/luispater/CLIProxyAPI/internal/api/translator" "github.com/luispater/CLIProxyAPI/internal/client" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -90,6 +91,19 @@ outLoop: template, _ = sjson.SetRaw(template, "request", string(rawJson)) template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) template, _ = sjson.Delete(template, "request.model") + + template, errFixCLIToolResponse := translator.FixCLIToolResponse(template) + if errFixCLIToolResponse != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse{ + Error: ErrorDetail{ + Message: errFixCLIToolResponse.Error(), + Type: "server_error", + }, + }) + cliCancel() + return + } + systemInstructionResult := gjson.Get(template, "request.system_instruction") if systemInstructionResult.Exists() { template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) @@ -178,6 +192,19 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) { template, _ = sjson.SetRaw(template, "request", string(rawJson)) template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String()) template, _ = sjson.Delete(template, "request.model") + + template, errFixCLIToolResponse := translator.FixCLIToolResponse(template) + if errFixCLIToolResponse != nil { + c.JSON(http.StatusInternalServerError, ErrorResponse{ + Error: ErrorDetail{ + Message: errFixCLIToolResponse.Error(), + Type: "server_error", + }, + }) + cliCancel() + return + } + systemInstructionResult := gjson.Get(template, "request.system_instruction") if systemInstructionResult.Exists() { template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw) diff --git a/internal/api/translator/request.go b/internal/api/translator/request.go index 64517f5b..d8c3a75a 100644 --- a/internal/api/translator/request.go +++ b/internal/api/translator/request.go @@ -2,6 +2,8 @@ package translator import ( "encoding/json" + "fmt" + "github.com/tidwall/sjson" "strings" "github.com/luispater/CLIProxyAPI/internal/client" @@ -197,3 +199,167 @@ func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content, return modelName, systemInstruction, contents, tools } + +// FunctionCallGroup represents a group of function calls and their responses +type FunctionCallGroup struct { + ModelContent map[string]interface{} + FunctionCalls []gjson.Result + ResponsesNeeded int +} + +// FixCLIToolResponse converts the format from 1.json to 2.json +// It groups function calls with their corresponding responses +func FixCLIToolResponse(input string) (string, error) { + // Parse the input JSON + parsed := gjson.Parse(input) + + // Get the contents array + contents := parsed.Get("request.contents") + if !contents.Exists() { + return input, fmt.Errorf("contents not found in input") + } + + var newContents []interface{} + var pendingGroups []*FunctionCallGroup + var collectedResponses []gjson.Result + + // Process each content object + contents.ForEach(func(key, value gjson.Result) bool { + role := value.Get("role").String() + parts := value.Get("parts") + + // Check if this content has function responses + var responsePartsInThisContent []gjson.Result + parts.ForEach(func(_, part gjson.Result) bool { + if part.Get("functionResponse").Exists() { + responsePartsInThisContent = append(responsePartsInThisContent, part) + } + return true + }) + + // If this content has function responses, collect them + if len(responsePartsInThisContent) > 0 { + collectedResponses = append(collectedResponses, responsePartsInThisContent...) + + // Check if any pending groups can be satisfied + for i := len(pendingGroups) - 1; i >= 0; i-- { + group := pendingGroups[i] + if len(collectedResponses) >= group.ResponsesNeeded { + // Take the needed responses for this group + groupResponses := collectedResponses[:group.ResponsesNeeded] + collectedResponses = collectedResponses[group.ResponsesNeeded:] + + // Create merged function response content + var responseParts []interface{} + for _, response := range groupResponses { + var responseMap map[string]interface{} + errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap) + if errUnmarshal != nil { + log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal) + continue + } + responseParts = append(responseParts, responseMap) + } + + if len(responseParts) > 0 { + functionResponseContent := map[string]interface{}{ + "parts": responseParts, + "role": "function", + } + newContents = append(newContents, functionResponseContent) + } + + // Remove this group as it's been satisfied + pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...) + break + } + } + + return true // Skip adding this content, responses are merged + } + + // If this is a model with function calls, create a new group + if role == "model" { + var functionCallsInThisModel []gjson.Result + parts.ForEach(func(_, part gjson.Result) bool { + if part.Get("functionCall").Exists() { + functionCallsInThisModel = append(functionCallsInThisModel, part) + } + return true + }) + + if len(functionCallsInThisModel) > 0 { + // Add the model content + var contentMap map[string]interface{} + errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) + if errUnmarshal != nil { + log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal) + return true + } + newContents = append(newContents, contentMap) + + // Create a new group for tracking responses + group := &FunctionCallGroup{ + ModelContent: contentMap, + FunctionCalls: functionCallsInThisModel, + ResponsesNeeded: len(functionCallsInThisModel), + } + pendingGroups = append(pendingGroups, group) + } else { + // Regular model content without function calls + var contentMap map[string]interface{} + errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) + if errUnmarshal != nil { + log.Warnf("failed to unmarshal content: %v\n", errUnmarshal) + return true + } + newContents = append(newContents, contentMap) + } + } else { + // Non-model content (user, etc.) + var contentMap map[string]interface{} + errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap) + if errUnmarshal != nil { + log.Warnf("failed to unmarshal content: %v\n", errUnmarshal) + return true + } + newContents = append(newContents, contentMap) + } + + return true + }) + + // Handle any remaining pending groups with remaining responses + for _, group := range pendingGroups { + if len(collectedResponses) >= group.ResponsesNeeded { + groupResponses := collectedResponses[:group.ResponsesNeeded] + collectedResponses = collectedResponses[group.ResponsesNeeded:] + + var responseParts []interface{} + for _, response := range groupResponses { + var responseMap map[string]interface{} + errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap) + if errUnmarshal != nil { + log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal) + continue + } + responseParts = append(responseParts, responseMap) + } + + if len(responseParts) > 0 { + functionResponseContent := map[string]interface{}{ + "parts": responseParts, + "role": "function", + } + newContents = append(newContents, functionResponseContent) + } + } + } + + // Update the original JSON with the new contents + result := input + newContentsJSON, _ := json.Marshal(newContents) + result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON)) + + return result, nil +}