mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 04:10:51 +08:00
refactor(util, executor): optimize payload handling and schema processing
- Replaced repetitive string operations with a centralized `escapeGJSONPathKey` function. - Streamlined handling of JSON schema cleaning for Gemini and Antigravity requests. - Improved payload management by transitioning from byte slices to strings for processing. - Removed unnecessary cloning of byte slices in several places.
This commit is contained in:
@@ -1280,51 +1280,40 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
|||||||
payload = geminiToAntigravity(modelName, payload, projectID)
|
payload = geminiToAntigravity(modelName, payload, projectID)
|
||||||
payload, _ = sjson.SetBytes(payload, "model", modelName)
|
payload, _ = sjson.SetBytes(payload, "model", modelName)
|
||||||
|
|
||||||
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
|
useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high")
|
||||||
strJSON := string(payload)
|
payloadStr := string(payload)
|
||||||
paths := make([]string, 0)
|
paths := make([]string, 0)
|
||||||
util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths)
|
util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths)
|
||||||
for _, p := range paths {
|
for _, p := range paths {
|
||||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
payloadStr, _ = util.RenameKey(payloadStr, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||||
}
|
|
||||||
|
|
||||||
// Use the centralized schema cleaner to handle unsupported keywords,
|
|
||||||
// const->enum conversion, and flattening of types/anyOf.
|
|
||||||
strJSON = util.CleanJSONSchemaForAntigravity(strJSON)
|
|
||||||
payload = []byte(strJSON)
|
|
||||||
} else {
|
|
||||||
strJSON := string(payload)
|
|
||||||
paths := make([]string, 0)
|
|
||||||
util.Walk(gjson.Parse(strJSON), "", "parametersJsonSchema", &paths)
|
|
||||||
for _, p := range paths {
|
|
||||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
|
||||||
}
|
|
||||||
// Clean tool schemas for Gemini to remove unsupported JSON Schema keywords
|
|
||||||
// without adding empty-schema placeholders.
|
|
||||||
strJSON = util.CleanJSONSchemaForGemini(strJSON)
|
|
||||||
payload = []byte(strJSON)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
|
if useAntigravitySchema {
|
||||||
systemInstructionPartsResult := gjson.GetBytes(payload, "request.systemInstruction.parts")
|
payloadStr = util.CleanJSONSchemaForAntigravity(payloadStr)
|
||||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.role", "user")
|
} else {
|
||||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.0.text", systemInstruction)
|
payloadStr = util.CleanJSONSchemaForGemini(payloadStr)
|
||||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
}
|
||||||
|
|
||||||
|
if useAntigravitySchema {
|
||||||
|
systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
|
||||||
|
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
|
||||||
|
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||||
|
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||||
|
|
||||||
if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
||||||
for _, partResult := range systemInstructionPartsResult.Array() {
|
for _, partResult := range systemInstructionPartsResult.Array() {
|
||||||
payload, _ = sjson.SetRawBytes(payload, "request.systemInstruction.parts.-1", []byte(partResult.Raw))
|
payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.Contains(modelName, "claude") {
|
if strings.Contains(modelName, "claude") {
|
||||||
payload, _ = sjson.SetBytes(payload, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||||
} else {
|
} else {
|
||||||
payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.maxOutputTokens")
|
payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens")
|
||||||
}
|
}
|
||||||
|
|
||||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), strings.NewReader(payloadStr))
|
||||||
if errReq != nil {
|
if errReq != nil {
|
||||||
return nil, errReq
|
return nil, errReq
|
||||||
}
|
}
|
||||||
@@ -1346,11 +1335,15 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
|||||||
authLabel = auth.Label
|
authLabel = auth.Label
|
||||||
authType, authValue = auth.AccountInfo()
|
authType, authValue = auth.AccountInfo()
|
||||||
}
|
}
|
||||||
|
var payloadLog []byte
|
||||||
|
if e.cfg != nil && e.cfg.RequestLog {
|
||||||
|
payloadLog = []byte(payloadStr)
|
||||||
|
}
|
||||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
URL: requestURL.String(),
|
URL: requestURL.String(),
|
||||||
Method: http.MethodPost,
|
Method: http.MethodPost,
|
||||||
Headers: httpReq.Header.Clone(),
|
Headers: httpReq.Header.Clone(),
|
||||||
Body: payload,
|
Body: payloadLog,
|
||||||
Provider: e.Identifier(),
|
Provider: e.Identifier(),
|
||||||
AuthID: authID,
|
AuthID: authID,
|
||||||
AuthLabel: authLabel,
|
AuthLabel: authLabel,
|
||||||
|
|||||||
@@ -3,7 +3,6 @@
|
|||||||
package chat_completions
|
package chat_completions
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -28,7 +27,7 @@ const geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator"
|
|||||||
// Returns:
|
// Returns:
|
||||||
// - []byte: The transformed request data in Gemini CLI API format
|
// - []byte: The transformed request data in Gemini CLI API format
|
||||||
func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||||
rawJSON := bytes.Clone(inputRawJSON)
|
rawJSON := inputRawJSON
|
||||||
// Base envelope (no default thinkingConfig)
|
// Base envelope (no default thinkingConfig)
|
||||||
out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`)
|
out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`)
|
||||||
|
|
||||||
|
|||||||
@@ -667,6 +667,9 @@ func orDefault(val, def string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func escapeGJSONPathKey(key string) string {
|
func escapeGJSONPathKey(key string) string {
|
||||||
|
if strings.IndexAny(key, ".*?") == -1 {
|
||||||
|
return key
|
||||||
|
}
|
||||||
return gjsonPathKeyReplacer.Replace(key)
|
return gjsonPathKeyReplacer.Replace(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ package util
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
@@ -33,15 +32,15 @@ func Walk(value gjson.Result, path, field string, paths *[]string) {
|
|||||||
// . -> \.
|
// . -> \.
|
||||||
// * -> \*
|
// * -> \*
|
||||||
// ? -> \?
|
// ? -> \?
|
||||||
var keyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
keyStr := key.String()
|
||||||
safeKey := keyReplacer.Replace(key.String())
|
safeKey := escapeGJSONPathKey(keyStr)
|
||||||
|
|
||||||
if path == "" {
|
if path == "" {
|
||||||
childPath = safeKey
|
childPath = safeKey
|
||||||
} else {
|
} else {
|
||||||
childPath = path + "." + safeKey
|
childPath = path + "." + safeKey
|
||||||
}
|
}
|
||||||
if key.String() == field {
|
if keyStr == field {
|
||||||
*paths = append(*paths, childPath)
|
*paths = append(*paths, childPath)
|
||||||
}
|
}
|
||||||
Walk(val, childPath, field, paths)
|
Walk(val, childPath, field, paths)
|
||||||
@@ -87,15 +86,6 @@ func RenameKey(jsonStr, oldKeyPath, newKeyPath string) (string, error) {
|
|||||||
return finalJson, nil
|
return finalJson, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func DeleteKey(jsonStr, keyName string) string {
|
|
||||||
paths := make([]string, 0)
|
|
||||||
Walk(gjson.Parse(jsonStr), "", keyName, &paths)
|
|
||||||
for _, p := range paths {
|
|
||||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
|
||||||
}
|
|
||||||
return jsonStr
|
|
||||||
}
|
|
||||||
|
|
||||||
// FixJSON converts non-standard JSON that uses single quotes for strings into
|
// FixJSON converts non-standard JSON that uses single quotes for strings into
|
||||||
// RFC 8259-compliant JSON by converting those single-quoted strings to
|
// RFC 8259-compliant JSON by converting those single-quoted strings to
|
||||||
// double-quoted strings with proper escaping.
|
// double-quoted strings with proper escaping.
|
||||||
|
|||||||
@@ -155,20 +155,6 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
|
|||||||
return map[string]any{idempotencyKeyMetadataKey: key}
|
return map[string]any{idempotencyKeyMetadataKey: key}
|
||||||
}
|
}
|
||||||
|
|
||||||
func mergeMetadata(base, overlay map[string]any) map[string]any {
|
|
||||||
if len(base) == 0 && len(overlay) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
out := make(map[string]any, len(base)+len(overlay))
|
|
||||||
for k, v := range base {
|
|
||||||
out[k] = v
|
|
||||||
}
|
|
||||||
for k, v := range overlay {
|
|
||||||
out[k] = v
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// BaseAPIHandler contains the handlers for API endpoints.
|
// BaseAPIHandler contains the handlers for API endpoints.
|
||||||
// It holds a pool of clients to interact with the backend service and manages
|
// It holds a pool of clients to interact with the backend service and manages
|
||||||
// load balancing, client selection, and configuration.
|
// load balancing, client selection, and configuration.
|
||||||
@@ -398,7 +384,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
|||||||
opts := coreexecutor.Options{
|
opts := coreexecutor.Options{
|
||||||
Stream: false,
|
Stream: false,
|
||||||
Alt: alt,
|
Alt: alt,
|
||||||
OriginalRequest: cloneBytes(rawJSON),
|
OriginalRequest: rawJSON,
|
||||||
SourceFormat: sdktranslator.FromString(handlerType),
|
SourceFormat: sdktranslator.FromString(handlerType),
|
||||||
}
|
}
|
||||||
opts.Metadata = reqMeta
|
opts.Metadata = reqMeta
|
||||||
@@ -437,7 +423,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
|||||||
opts := coreexecutor.Options{
|
opts := coreexecutor.Options{
|
||||||
Stream: false,
|
Stream: false,
|
||||||
Alt: alt,
|
Alt: alt,
|
||||||
OriginalRequest: cloneBytes(rawJSON),
|
OriginalRequest: rawJSON,
|
||||||
SourceFormat: sdktranslator.FromString(handlerType),
|
SourceFormat: sdktranslator.FromString(handlerType),
|
||||||
}
|
}
|
||||||
opts.Metadata = reqMeta
|
opts.Metadata = reqMeta
|
||||||
@@ -479,7 +465,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
opts := coreexecutor.Options{
|
opts := coreexecutor.Options{
|
||||||
Stream: true,
|
Stream: true,
|
||||||
Alt: alt,
|
Alt: alt,
|
||||||
OriginalRequest: cloneBytes(rawJSON),
|
OriginalRequest: rawJSON,
|
||||||
SourceFormat: sdktranslator.FromString(handlerType),
|
SourceFormat: sdktranslator.FromString(handlerType),
|
||||||
}
|
}
|
||||||
opts.Metadata = reqMeta
|
opts.Metadata = reqMeta
|
||||||
@@ -668,17 +654,6 @@ func cloneBytes(src []byte) []byte {
|
|||||||
return dst
|
return dst
|
||||||
}
|
}
|
||||||
|
|
||||||
func cloneMetadata(src map[string]any) map[string]any {
|
|
||||||
if len(src) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
dst := make(map[string]any, len(src))
|
|
||||||
for k, v := range src {
|
|
||||||
dst[k] = v
|
|
||||||
}
|
|
||||||
return dst
|
|
||||||
}
|
|
||||||
|
|
||||||
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
|
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
|
||||||
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
|
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
|
||||||
status := http.StatusInternalServerError
|
status := http.StatusInternalServerError
|
||||||
|
|||||||
Reference in New Issue
Block a user