mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-19 04:40:52 +08:00
Add FixCLIToolResponse for enhanced function call-response mapping
- Introduced `FixCLIToolResponse` in `translator` to group function calls with corresponding responses. - Updated Gemini handlers to integrate new function for improved response handling. - Enhanced error handling in case response mapping fails.
This commit is contained in:
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
@@ -90,6 +91,19 @@ outLoop:
|
|||||||
template, _ = sjson.SetRaw(template, "request", string(rawJson))
|
template, _ = sjson.SetRaw(template, "request", string(rawJson))
|
||||||
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
||||||
template, _ = sjson.Delete(template, "request.model")
|
template, _ = sjson.Delete(template, "request.model")
|
||||||
|
|
||||||
|
template, errFixCLIToolResponse := translator.FixCLIToolResponse(template)
|
||||||
|
if errFixCLIToolResponse != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: errFixCLIToolResponse.Error(),
|
||||||
|
Type: "server_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
||||||
if systemInstructionResult.Exists() {
|
if systemInstructionResult.Exists() {
|
||||||
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
||||||
@@ -178,6 +192,19 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
|
|||||||
template, _ = sjson.SetRaw(template, "request", string(rawJson))
|
template, _ = sjson.SetRaw(template, "request", string(rawJson))
|
||||||
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
||||||
template, _ = sjson.Delete(template, "request.model")
|
template, _ = sjson.Delete(template, "request.model")
|
||||||
|
|
||||||
|
template, errFixCLIToolResponse := translator.FixCLIToolResponse(template)
|
||||||
|
if errFixCLIToolResponse != nil {
|
||||||
|
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: errFixCLIToolResponse.Error(),
|
||||||
|
Type: "server_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
||||||
if systemInstructionResult.Exists() {
|
if systemInstructionResult.Exists() {
|
||||||
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
||||||
|
|||||||
@@ -2,6 +2,8 @@ package translator
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
@@ -197,3 +199,167 @@ func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content,
|
|||||||
|
|
||||||
return modelName, systemInstruction, contents, tools
|
return modelName, systemInstruction, contents, tools
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// FunctionCallGroup represents a group of function calls and their responses
|
||||||
|
type FunctionCallGroup struct {
|
||||||
|
ModelContent map[string]interface{}
|
||||||
|
FunctionCalls []gjson.Result
|
||||||
|
ResponsesNeeded int
|
||||||
|
}
|
||||||
|
|
||||||
|
// FixCLIToolResponse converts the format from 1.json to 2.json
|
||||||
|
// It groups function calls with their corresponding responses
|
||||||
|
func FixCLIToolResponse(input string) (string, error) {
|
||||||
|
// Parse the input JSON
|
||||||
|
parsed := gjson.Parse(input)
|
||||||
|
|
||||||
|
// Get the contents array
|
||||||
|
contents := parsed.Get("request.contents")
|
||||||
|
if !contents.Exists() {
|
||||||
|
return input, fmt.Errorf("contents not found in input")
|
||||||
|
}
|
||||||
|
|
||||||
|
var newContents []interface{}
|
||||||
|
var pendingGroups []*FunctionCallGroup
|
||||||
|
var collectedResponses []gjson.Result
|
||||||
|
|
||||||
|
// Process each content object
|
||||||
|
contents.ForEach(func(key, value gjson.Result) bool {
|
||||||
|
role := value.Get("role").String()
|
||||||
|
parts := value.Get("parts")
|
||||||
|
|
||||||
|
// Check if this content has function responses
|
||||||
|
var responsePartsInThisContent []gjson.Result
|
||||||
|
parts.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
if part.Get("functionResponse").Exists() {
|
||||||
|
responsePartsInThisContent = append(responsePartsInThisContent, part)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// If this content has function responses, collect them
|
||||||
|
if len(responsePartsInThisContent) > 0 {
|
||||||
|
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
|
||||||
|
|
||||||
|
// Check if any pending groups can be satisfied
|
||||||
|
for i := len(pendingGroups) - 1; i >= 0; i-- {
|
||||||
|
group := pendingGroups[i]
|
||||||
|
if len(collectedResponses) >= group.ResponsesNeeded {
|
||||||
|
// Take the needed responses for this group
|
||||||
|
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||||
|
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||||
|
|
||||||
|
// Create merged function response content
|
||||||
|
var responseParts []interface{}
|
||||||
|
for _, response := range groupResponses {
|
||||||
|
var responseMap map[string]interface{}
|
||||||
|
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||||
|
if errUnmarshal != nil {
|
||||||
|
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
responseParts = append(responseParts, responseMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(responseParts) > 0 {
|
||||||
|
functionResponseContent := map[string]interface{}{
|
||||||
|
"parts": responseParts,
|
||||||
|
"role": "function",
|
||||||
|
}
|
||||||
|
newContents = append(newContents, functionResponseContent)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove this group as it's been satisfied
|
||||||
|
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true // Skip adding this content, responses are merged
|
||||||
|
}
|
||||||
|
|
||||||
|
// If this is a model with function calls, create a new group
|
||||||
|
if role == "model" {
|
||||||
|
var functionCallsInThisModel []gjson.Result
|
||||||
|
parts.ForEach(func(_, part gjson.Result) bool {
|
||||||
|
if part.Get("functionCall").Exists() {
|
||||||
|
functionCallsInThisModel = append(functionCallsInThisModel, part)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
if len(functionCallsInThisModel) > 0 {
|
||||||
|
// Add the model content
|
||||||
|
var contentMap map[string]interface{}
|
||||||
|
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||||
|
if errUnmarshal != nil {
|
||||||
|
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
newContents = append(newContents, contentMap)
|
||||||
|
|
||||||
|
// Create a new group for tracking responses
|
||||||
|
group := &FunctionCallGroup{
|
||||||
|
ModelContent: contentMap,
|
||||||
|
FunctionCalls: functionCallsInThisModel,
|
||||||
|
ResponsesNeeded: len(functionCallsInThisModel),
|
||||||
|
}
|
||||||
|
pendingGroups = append(pendingGroups, group)
|
||||||
|
} else {
|
||||||
|
// Regular model content without function calls
|
||||||
|
var contentMap map[string]interface{}
|
||||||
|
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||||
|
if errUnmarshal != nil {
|
||||||
|
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
newContents = append(newContents, contentMap)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Non-model content (user, etc.)
|
||||||
|
var contentMap map[string]interface{}
|
||||||
|
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||||
|
if errUnmarshal != nil {
|
||||||
|
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
newContents = append(newContents, contentMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handle any remaining pending groups with remaining responses
|
||||||
|
for _, group := range pendingGroups {
|
||||||
|
if len(collectedResponses) >= group.ResponsesNeeded {
|
||||||
|
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||||
|
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||||
|
|
||||||
|
var responseParts []interface{}
|
||||||
|
for _, response := range groupResponses {
|
||||||
|
var responseMap map[string]interface{}
|
||||||
|
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||||
|
if errUnmarshal != nil {
|
||||||
|
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
responseParts = append(responseParts, responseMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(responseParts) > 0 {
|
||||||
|
functionResponseContent := map[string]interface{}{
|
||||||
|
"parts": responseParts,
|
||||||
|
"role": "function",
|
||||||
|
}
|
||||||
|
newContents = append(newContents, functionResponseContent)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the original JSON with the new contents
|
||||||
|
result := input
|
||||||
|
newContentsJSON, _ := json.Marshal(newContents)
|
||||||
|
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user