diff --git a/config.example.yaml b/config.example.yaml index b397be07..09307c33 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -275,9 +275,21 @@ oauth-model-alias: # protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex # params: # JSON path (gjson/sjson syntax) -> value # "generationConfig.thinkingConfig.thinkingBudget": 32768 +# default-raw: # Default raw rules set parameters using raw JSON when missing (must be valid JSON). +# - models: +# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*") +# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex +# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON) +# "generationConfig.responseJsonSchema": "{\"type\":\"object\",\"properties\":{\"answer\":{\"type\":\"string\"}}}" # override: # Override rules always set parameters, overwriting any existing values. # - models: # - name: "gpt-*" # Supports wildcards (e.g., "gpt-*") # protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex # params: # JSON path (gjson/sjson syntax) -> value # "reasoning.effort": "high" +# override-raw: # Override raw rules always set parameters using raw JSON (must be valid JSON). +# - models: +# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*") +# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex +# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON) +# "response_format": "{\"type\":\"json_schema\",\"json_schema\":{\"name\":\"answer\",\"schema\":{\"type\":\"object\"}}}" diff --git a/internal/config/config.go b/internal/config/config.go index c66229a8..0405cfa7 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -6,12 +6,14 @@ package config import ( "bytes" + "encoding/json" "errors" "fmt" "os" "strings" "syscall" + log "github.com/sirupsen/logrus" "golang.org/x/crypto/bcrypt" "gopkg.in/yaml.v3" ) @@ -216,8 +218,12 @@ type AmpUpstreamAPIKeyEntry struct { type PayloadConfig struct { // Default defines rules that only set parameters when they are missing in the payload. Default []PayloadRule `yaml:"default" json:"default"` + // DefaultRaw defines rules that set raw JSON values only when they are missing. + DefaultRaw []PayloadRule `yaml:"default-raw" json:"default-raw"` // Override defines rules that always set parameters, overwriting any existing values. Override []PayloadRule `yaml:"override" json:"override"` + // OverrideRaw defines rules that always set raw JSON values, overwriting any existing values. + OverrideRaw []PayloadRule `yaml:"override-raw" json:"override-raw"` } // PayloadRule describes a single rule targeting a list of models with parameter updates. @@ -225,6 +231,7 @@ type PayloadRule struct { // Models lists model entries with name pattern and protocol constraint. Models []PayloadModelRule `yaml:"models" json:"models"` // Params maps JSON paths (gjson/sjson syntax) to values written into the payload. + // For *-raw rules, values are treated as raw JSON fragments (strings are used as-is). Params map[string]any `yaml:"params" json:"params"` } @@ -540,6 +547,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Normalize global OAuth model name aliases. cfg.SanitizeOAuthModelAlias() + // Validate raw payload rules and drop invalid entries. + cfg.SanitizePayloadRules() + if cfg.legacyMigrationPending { fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...") if !optional && configFile != "" { @@ -556,6 +566,61 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { return &cfg, nil } +// SanitizePayloadRules validates raw JSON payload rule params and drops invalid rules. +func (cfg *Config) SanitizePayloadRules() { + if cfg == nil { + return + } + cfg.Payload.DefaultRaw = sanitizePayloadRawRules(cfg.Payload.DefaultRaw, "default-raw") + cfg.Payload.OverrideRaw = sanitizePayloadRawRules(cfg.Payload.OverrideRaw, "override-raw") +} + +func sanitizePayloadRawRules(rules []PayloadRule, section string) []PayloadRule { + if len(rules) == 0 { + return rules + } + out := make([]PayloadRule, 0, len(rules)) + for i := range rules { + rule := rules[i] + if len(rule.Params) == 0 { + continue + } + invalid := false + for path, value := range rule.Params { + raw, ok := payloadRawString(value) + if !ok { + continue + } + trimmed := bytes.TrimSpace(raw) + if len(trimmed) == 0 || !json.Valid(trimmed) { + log.WithFields(log.Fields{ + "section": section, + "rule_index": i + 1, + "param": path, + }).Warn("payload rule dropped: invalid raw JSON") + invalid = true + break + } + } + if invalid { + continue + } + out = append(out, rule) + } + return out +} + +func payloadRawString(value any) ([]byte, bool) { + switch typed := value.(type) { + case string: + return []byte(typed), true + case []byte: + return typed, true + default: + return nil, false + } +} + // SanitizeOAuthModelAlias normalizes and deduplicates global OAuth model name aliases. // It trims whitespace, normalizes channel keys to lower-case, drops empty entries, // allows multiple aliases per upstream name, and ensures aliases are unique within each channel. diff --git a/internal/runtime/executor/payload_helpers.go b/internal/runtime/executor/payload_helpers.go index 9014af87..b4e03c40 100644 --- a/internal/runtime/executor/payload_helpers.go +++ b/internal/runtime/executor/payload_helpers.go @@ -1,6 +1,7 @@ package executor import ( + "encoding/json" "strings" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -17,7 +18,7 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string return payload } rules := cfg.Payload - if len(rules.Default) == 0 && len(rules.Override) == 0 { + if len(rules.Default) == 0 && len(rules.DefaultRaw) == 0 && len(rules.Override) == 0 && len(rules.OverrideRaw) == 0 { return payload } model = strings.TrimSpace(model) @@ -55,6 +56,35 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string appliedDefaults[fullPath] = struct{}{} } } + // Apply default raw rules: first write wins per field across all matching rules. + for i := range rules.DefaultRaw { + rule := &rules.DefaultRaw[i] + if !payloadRuleMatchesModel(rule, model, protocol) { + continue + } + for path, value := range rule.Params { + fullPath := buildPayloadPath(root, path) + if fullPath == "" { + continue + } + if gjson.GetBytes(source, fullPath).Exists() { + continue + } + if _, ok := appliedDefaults[fullPath]; ok { + continue + } + rawValue, ok := payloadRawValue(value) + if !ok { + continue + } + updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue) + if errSet != nil { + continue + } + out = updated + appliedDefaults[fullPath] = struct{}{} + } + } // Apply override rules: last write wins per field across all matching rules. for i := range rules.Override { rule := &rules.Override[i] @@ -73,6 +103,28 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string out = updated } } + // Apply override raw rules: last write wins per field across all matching rules. + for i := range rules.OverrideRaw { + rule := &rules.OverrideRaw[i] + if !payloadRuleMatchesModel(rule, model, protocol) { + continue + } + for path, value := range rule.Params { + fullPath := buildPayloadPath(root, path) + if fullPath == "" { + continue + } + rawValue, ok := payloadRawValue(value) + if !ok { + continue + } + updated, errSet := sjson.SetRawBytes(out, fullPath, rawValue) + if errSet != nil { + continue + } + out = updated + } + } return out } @@ -116,6 +168,24 @@ func buildPayloadPath(root, path string) string { return r + "." + p } +func payloadRawValue(value any) ([]byte, bool) { + if value == nil { + return nil, false + } + switch typed := value.(type) { + case string: + return []byte(typed), true + case []byte: + return typed, true + default: + raw, errMarshal := json.Marshal(typed) + if errMarshal != nil { + return nil, false + } + return raw, true + } +} + // matchModelPattern performs simple wildcard matching where '*' matches zero or more characters. // Examples: //