mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-04 05:20:52 +08:00
Compare commits
118 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
653439698e | ||
|
|
89254cfc97 | ||
|
|
6bd9a034f7 | ||
|
|
26fc65b051 | ||
|
|
ed5ec5b55c | ||
|
|
df777650ac | ||
|
|
10f8c795ac | ||
|
|
3e4858a624 | ||
|
|
1231dc9cda | ||
|
|
c84ff42bcd | ||
|
|
8a5db02165 | ||
|
|
d7afb6eb0c | ||
|
|
bbd1fe890a | ||
|
|
f607231efa | ||
|
|
2039062845 | ||
|
|
99478d13a8 | ||
|
|
69d3a80fc3 | ||
|
|
9e268ad103 | ||
|
|
9d9b9e7a0d | ||
|
|
13aa82f3f3 | ||
|
|
05e55d7dc5 | ||
|
|
1b358c931c | ||
|
|
ca09db21ff | ||
|
|
718ff7a73f | ||
|
|
fa70b220e9 | ||
|
|
1b8cb7b77b | ||
|
|
774f1fbc17 | ||
|
|
cfa8ddb59f | ||
|
|
39597267ae | ||
|
|
393e38f2c0 | ||
|
|
d1220de02d | ||
|
|
13eb5268de | ||
|
|
88798816f2 | ||
|
|
598f0af19b | ||
|
|
a33f5d31fc | ||
|
|
506699fba1 | ||
|
|
68a27772b3 | ||
|
|
de87fb622b | ||
|
|
f27672f6cf | ||
|
|
28420c14e4 | ||
|
|
0bd221ff41 | ||
|
|
5fda6f8ef3 | ||
|
|
9b956f6338 | ||
|
|
09923f654c | ||
|
|
ae7b972649 | ||
|
|
47885e3710 | ||
|
|
4b9a260b37 | ||
|
|
2c743c8f0b | ||
|
|
9f2c278ee6 | ||
|
|
aea337cfe2 | ||
|
|
811f8f8b4f | ||
|
|
27734a23b1 | ||
|
|
1b8e538a77 | ||
|
|
41c2385aca | ||
|
|
d605985f45 | ||
|
|
d52b28b147 | ||
|
|
4afe1f42ca | ||
|
|
7481c0eaa0 | ||
|
|
ffdfad8482 | ||
|
|
6586f08584 | ||
|
|
f49e887fe6 | ||
|
|
a5b3ff11fd | ||
|
|
084558f200 | ||
|
|
b602eae215 | ||
|
|
d02bf9c243 | ||
|
|
26a5f67df2 | ||
|
|
600fd42a83 | ||
|
|
670685139a | ||
|
|
52b6306388 | ||
|
|
521ec6f1b8 | ||
|
|
b0c5d9640a | ||
|
|
ef8e94e992 | ||
|
|
9df96a4bb4 | ||
|
|
28a428ae2f | ||
|
|
b326ec3641 | ||
|
|
fcecbc7d46 | ||
|
|
f4007f53ba | ||
|
|
5a812a1e93 | ||
|
|
5e624cc7b1 | ||
|
|
3af24597ee | ||
|
|
e0be6c5786 | ||
|
|
88b101ebf5 | ||
|
|
d9a65745df | ||
|
|
97ab623d42 | ||
|
|
14aa6cc7e8 | ||
|
|
3bc489254b | ||
|
|
4c07ea41c3 | ||
|
|
f6720f8dfa | ||
|
|
e19ab3a066 | ||
|
|
8f1dd69e72 | ||
|
|
f26da24a2f | ||
|
|
8e4fbcaa7d | ||
|
|
09c339953d | ||
|
|
367a05bdf6 | ||
|
|
d20b71deb9 | ||
|
|
712ce9f781 | ||
|
|
a4a3274a55 | ||
|
|
716aa71f6e | ||
|
|
e8976f9898 | ||
|
|
8496cc2444 | ||
|
|
5ef2d59e05 | ||
|
|
07bb89ae80 | ||
|
|
27a5ad8ec2 | ||
|
|
707b07c5f5 | ||
|
|
4a764afd76 | ||
|
|
ecf49d574b | ||
|
|
5a75ef8ffd | ||
|
|
07279f8746 | ||
|
|
71f788b13a | ||
|
|
59c62dc580 | ||
|
|
d5310a3300 | ||
|
|
f0a3eb574e | ||
|
|
bb15855443 | ||
|
|
14ce6aebd1 | ||
|
|
2fe83723f2 | ||
|
|
cd8c86c6fb | ||
|
|
52d5fd1a67 | ||
|
|
07d21463ca |
@@ -27,4 +27,8 @@ config.yaml
|
||||
bin/*
|
||||
.claude/*
|
||||
.vscode/*
|
||||
.gemini/*
|
||||
.serena/*
|
||||
.agent/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
|
||||
23
.github/workflows/pr-test-build.yml
vendored
Normal file
23
.github/workflows/pr-test-build.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: pr-test-build
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- name: Build
|
||||
run: |
|
||||
go build -o test-output ./cmd/server
|
||||
rm -f test-output
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -30,7 +30,11 @@ GEMINI.md
|
||||
# Tooling metadata
|
||||
.vscode/*
|
||||
.claude/*
|
||||
.gemini/*
|
||||
.serena/*
|
||||
.agent/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
|
||||
@@ -59,7 +59,7 @@ CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and A
|
||||
- **Model mapping** to route unavailable models to alternatives (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
|
||||
- Security-first design with localhost-only management endpoints
|
||||
|
||||
**→ [Complete Amp CLI Integration Guide](docs/amp-cli-integration.md)**
|
||||
**→ [Complete Amp CLI Integration Guide](https://help.router-for.me/agent-client/amp-cli.html)**
|
||||
|
||||
## SDK Docs
|
||||
|
||||
|
||||
@@ -57,7 +57,7 @@ CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支
|
||||
- 智能模型回退与自动路由
|
||||
- 以安全为先的设计,管理端点仅限 localhost
|
||||
|
||||
**→ [Amp CLI 完整集成指南](docs/amp-cli-integration_CN.md)**
|
||||
**→ [Amp CLI 完整集成指南](https://help.router-for.me/cn/agent-client/amp-cli.html)**
|
||||
|
||||
## SDK 文档
|
||||
|
||||
|
||||
@@ -405,7 +405,7 @@ func main() {
|
||||
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
|
||||
if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||
if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
|
||||
log.Errorf("failed to configure log output: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -25,6 +25,9 @@ remote-management:
|
||||
# Disable the bundled management control panel asset download and HTTP route when true.
|
||||
disable-control-panel: false
|
||||
|
||||
# GitHub repository for the management control panel. Accepts a repository URL or releases API URL.
|
||||
panel-github-repository: "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
|
||||
# Authentication directory (supports ~ for home directory)
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
@@ -39,12 +42,19 @@ debug: false
|
||||
# When true, write application logs to rotating files instead of stdout
|
||||
logging-to-file: false
|
||||
|
||||
# Maximum total size (MB) of log files under the logs directory. When exceeded, the oldest log
|
||||
# files are deleted until within the limit. Set to 0 to disable.
|
||||
logs-max-total-size-mb: 0
|
||||
|
||||
# When false, disable in-memory usage statistics aggregation
|
||||
usage-statistics-enabled: false
|
||||
|
||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
|
||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||
force-model-prefix: false
|
||||
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
@@ -62,6 +72,7 @@ ws-auth: false
|
||||
# Gemini API keys
|
||||
# gemini-api-key:
|
||||
# - api-key: "AIzaSy...01"
|
||||
# prefix: "test" # optional: require calls like "test/gemini-3-pro-preview" to target this credential
|
||||
# base-url: "https://generativelanguage.googleapis.com"
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -76,6 +87,7 @@ ws-auth: false
|
||||
# Codex API keys
|
||||
# codex-api-key:
|
||||
# - api-key: "sk-atSM..."
|
||||
# prefix: "test" # optional: require calls like "test/gpt-5-codex" to target this credential
|
||||
# base-url: "https://www.example.com" # use the custom codex API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -90,6 +102,7 @@ ws-auth: false
|
||||
# claude-api-key:
|
||||
# - api-key: "sk-atSM..." # use the official claude API key, no need to set the base url
|
||||
# - api-key: "sk-atSM..."
|
||||
# prefix: "test" # optional: require calls like "test/claude-sonnet-latest" to target this credential
|
||||
# base-url: "https://www.example.com" # use the custom claude API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -106,6 +119,7 @@ ws-auth: false
|
||||
# OpenAI compatibility providers
|
||||
# openai-compatibility:
|
||||
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
||||
# prefix: "test" # optional: require calls like "test/kimi-k2" to target this provider's credentials
|
||||
# base-url: "https://openrouter.ai/api/v1" # The base URL of the provider.
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -120,6 +134,7 @@ ws-auth: false
|
||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
||||
# vertex-api-key:
|
||||
# - api-key: "vk-123..." # x-goog-api-key header
|
||||
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||
# headers:
|
||||
@@ -136,8 +151,8 @@ ws-auth: false
|
||||
# upstream-url: "https://ampcode.com"
|
||||
# # Optional: Override API key for Amp upstream (otherwise uses env or file)
|
||||
# upstream-api-key: ""
|
||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended)
|
||||
# restrict-management-to-localhost: true
|
||||
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (default: false)
|
||||
# restrict-management-to-localhost: false
|
||||
# # Force model mappings to run before checking local API keys (default: false)
|
||||
# force-model-mappings: false
|
||||
# # Amp Model Mappings
|
||||
|
||||
@@ -1,443 +0,0 @@
|
||||
# Amp CLI Integration Guide
|
||||
|
||||
This guide explains how to use CLIProxyAPI with Amp CLI and Amp IDE extensions, enabling you to use your existing Google/ChatGPT/Claude subscriptions (via OAuth) with Amp's CLI.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Which Providers Should You Authenticate?](#which-providers-should-you-authenticate)
|
||||
- [Architecture](#architecture)
|
||||
- [Configuration](#configuration)
|
||||
- [Model Mapping Configuration](#model-mapping-configuration)
|
||||
- [Setup](#setup)
|
||||
- [Usage](#usage)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
## Overview
|
||||
|
||||
The Amp CLI integration adds specialized routing to support Amp's API patterns while maintaining full compatibility with all existing CLIProxyAPI features. This allows you to use both traditional CLIProxyAPI features and Amp CLI with the same proxy server.
|
||||
|
||||
### Key Features
|
||||
|
||||
- **Provider route aliases**: Maps Amp's `/api/provider/{provider}/v1...` patterns to CLIProxyAPI handlers
|
||||
- **Management proxy**: Forwards OAuth and account management requests to Amp's control plane
|
||||
- **Smart fallback**: Automatically routes unconfigured models to ampcode.com
|
||||
- **Model mapping**: Route unavailable models to alternatives you have access to (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
|
||||
- **Secret management**: Configurable precedence (config > env > file) with 5-minute caching
|
||||
- **Security-first**: Management routes restricted to localhost by default
|
||||
- **Automatic gzip handling**: Decompresses responses from Amp upstream
|
||||
|
||||
### What You Can Do
|
||||
|
||||
- Use Amp CLI with your Google account (Gemini 3 Pro Preview, Gemini 2.5 Pro, Gemini 2.5 Flash)
|
||||
- Use Amp CLI with your ChatGPT Plus/Pro subscription (GPT-5, GPT-5 Codex models)
|
||||
- Use Amp CLI with your Claude Pro/Max subscription (Claude Sonnet 4.5, Opus 4.1)
|
||||
- Use Amp IDE extensions (VS Code, Cursor, Windsurf, etc.) with the same proxy
|
||||
- Run multiple CLI tools (Factory + Amp) through one proxy server
|
||||
- Route unconfigured models automatically through ampcode.com
|
||||
|
||||
### Which Providers Should You Authenticate?
|
||||
|
||||
**Important**: The providers you need to authenticate depend on which models and features your installed version of Amp currently uses. Amp employs different providers for various agent modes and specialized subagents:
|
||||
|
||||
- **Smart mode**: Uses Google/Gemini models (Gemini 3 Pro)
|
||||
- **Rush mode**: Uses Anthropic/Claude models (Claude Haiku 4.5)
|
||||
- **Oracle subagent**: Uses OpenAI/GPT models (GPT-5 medium reasoning)
|
||||
- **Librarian subagent**: Uses Anthropic/Claude models (Claude Sonnet 4.5)
|
||||
- **Search subagent**: Uses Anthropic/Claude models (Claude Haiku 4.5)
|
||||
- **Review feature**: Uses Google/Gemini models (Gemini 2.5 Flash-Lite)
|
||||
|
||||
For the most current information about which models Amp uses, see the **[Amp Models Documentation](https://ampcode.com/models)**.
|
||||
|
||||
#### Fallback Behavior
|
||||
|
||||
CLIProxyAPI uses a smart fallback system:
|
||||
|
||||
1. **Provider authenticated locally** (`--login`, `--codex-login`, `--claude-login`):
|
||||
- Requests use **your OAuth subscription** (ChatGPT Plus/Pro, Claude Pro/Max, Google account)
|
||||
- You benefit from your subscription's included usage quotas
|
||||
- No Amp credits consumed
|
||||
|
||||
2. **Provider NOT authenticated locally**:
|
||||
- Requests automatically forward to **ampcode.com**
|
||||
- Uses Amp's backend provider connections
|
||||
- **Requires Amp credits** if the provider is paid (OpenAI, Anthropic paid tiers)
|
||||
- May result in errors if Amp credit balance is insufficient
|
||||
|
||||
**Recommendation**: Authenticate all providers you have subscriptions for to maximize value and minimize Amp credit usage. If you don't have subscriptions to all providers Amp uses, ensure you have sufficient Amp credits available for fallback requests.
|
||||
|
||||
## Architecture
|
||||
|
||||
### Request Flow
|
||||
|
||||
```
|
||||
Amp CLI/IDE
|
||||
↓
|
||||
├─ Provider API requests (/api/provider/{provider}/v1/...)
|
||||
│ ↓
|
||||
│ ├─ Model configured locally?
|
||||
│ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers)
|
||||
│ │ NO ↓
|
||||
│ │ ├─ Model mapping configured?
|
||||
│ │ │ YES → Rewrite model → Use local handler (free)
|
||||
│ │ │ NO → Forward to ampcode.com (uses Amp credits)
|
||||
│ ↓
|
||||
│ Response
|
||||
│
|
||||
└─ Management requests (/api/auth, /api/user, /api/threads, ...)
|
||||
↓
|
||||
├─ Localhost check (security)
|
||||
↓
|
||||
└─ Reverse proxy to ampcode.com
|
||||
↓
|
||||
Response (auto-decompressed if gzipped)
|
||||
```
|
||||
|
||||
### Components
|
||||
|
||||
The Amp integration is implemented as a modular routing module (`internal/api/modules/amp/`) with these components:
|
||||
|
||||
1. **Route Aliases** (`routes.go`): Maps Amp-style paths to standard handlers
|
||||
2. **Reverse Proxy** (`proxy.go`): Forwards management requests to ampcode.com
|
||||
3. **Fallback Handler** (`fallback_handlers.go`): Routes unconfigured models to ampcode.com
|
||||
4. **Secret Management** (`secret.go`): Multi-source API key resolution with caching
|
||||
5. **Main Module** (`amp.go`): Orchestrates registration and configuration
|
||||
|
||||
## Configuration
|
||||
|
||||
### Basic Configuration
|
||||
|
||||
Add these fields to your `config.yaml`:
|
||||
|
||||
```yaml
|
||||
# Amp upstream control plane (required for management routes)
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
|
||||
# Optional: Override API key (otherwise uses env or file)
|
||||
# amp-upstream-api-key: "your-amp-api-key"
|
||||
|
||||
# Security: restrict management routes to localhost (recommended)
|
||||
amp-restrict-management-to-localhost: true
|
||||
```
|
||||
|
||||
### Model Mapping Configuration
|
||||
|
||||
When Amp CLI requests a model that you don't have access to, you can configure mappings to route those requests to alternative models that you DO have available. This avoids consuming Amp credits for models you could handle locally.
|
||||
|
||||
```yaml
|
||||
# Route unavailable models to alternatives
|
||||
amp-model-mappings:
|
||||
# Example: Route Claude Opus 4.5 requests to Claude Sonnet 4
|
||||
- from: "claude-opus-4.5"
|
||||
to: "claude-sonnet-4"
|
||||
|
||||
# Example: Route GPT-5 requests to Gemini 2.5 Pro
|
||||
- from: "gpt-5"
|
||||
to: "gemini-2.5-pro"
|
||||
|
||||
# Example: Map older model names to newer versions
|
||||
- from: "claude-3-opus-20240229"
|
||||
to: "claude-3-5-sonnet-20241022"
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
|
||||
1. Amp CLI requests a model (e.g., `claude-opus-4.5`)
|
||||
2. CLIProxyAPI checks if a local provider is available for that model
|
||||
3. If not available, it checks the model mappings
|
||||
4. If a mapping exists, the request is rewritten to use the target model
|
||||
5. The request is then handled locally (free, using your OAuth subscription)
|
||||
|
||||
**Benefits:**
|
||||
- **Save Amp credits**: Use your local subscriptions instead of forwarding to ampcode.com
|
||||
- **Hot-reload**: Mappings can be updated without restarting the proxy
|
||||
- **Structured logging**: Clear logs show when mappings are applied
|
||||
|
||||
**Routing Decision Logs:**
|
||||
|
||||
The proxy logs each routing decision with structured fields:
|
||||
|
||||
```
|
||||
[AMP] Using local provider for model: gemini-2.5-pro # Local provider (free)
|
||||
[AMP] Model mapped: claude-opus-4.5 -> claude-sonnet-4 # Mapping applied (free)
|
||||
[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: gpt-5 # Fallback (costs credits)
|
||||
```
|
||||
|
||||
### Secret Resolution Precedence
|
||||
|
||||
The Amp module resolves API keys using this precedence order:
|
||||
|
||||
| Source | Key | Priority | Cache |
|
||||
|--------|-----|----------|-------|
|
||||
| Config file | `amp-upstream-api-key` | High | No |
|
||||
| Environment | `AMP_API_KEY` | Medium | No |
|
||||
| Amp secrets file | `~/.local/share/amp/secrets.json` | Low | 5 min |
|
||||
|
||||
**Recommendation**: Use the Amp secrets file (lowest precedence) for normal usage. This file is automatically managed by `amp login`.
|
||||
|
||||
### Security Settings
|
||||
|
||||
**`amp-restrict-management-to-localhost`** (default: `true`)
|
||||
|
||||
When enabled, management routes (`/api/auth`, `/api/user`, `/api/threads`, etc.) only accept connections from localhost (127.0.0.1, ::1). This prevents:
|
||||
- Drive-by browser attacks
|
||||
- Remote access to management endpoints
|
||||
- CORS-based attacks
|
||||
- Header spoofing attacks (e.g., `X-Forwarded-For: 127.0.0.1`)
|
||||
|
||||
#### How It Works
|
||||
|
||||
This restriction uses the **actual TCP connection address** (`RemoteAddr`), not HTTP headers like `X-Forwarded-For`. This prevents header spoofing attacks but has important implications:
|
||||
|
||||
- ✅ **Works for direct connections**: Running CLIProxyAPI directly on your machine or server
|
||||
- ⚠️ **May not work behind reverse proxies**: If deploying behind nginx, Cloudflare, or other proxies, the connection will appear to come from the proxy's IP, not localhost
|
||||
|
||||
#### Reverse Proxy Deployments
|
||||
|
||||
If you need to run CLIProxyAPI behind a reverse proxy (nginx, Caddy, Cloudflare Tunnel, etc.):
|
||||
|
||||
1. **Disable the localhost restriction**:
|
||||
```yaml
|
||||
amp-restrict-management-to-localhost: false
|
||||
```
|
||||
|
||||
2. **Use alternative security measures**:
|
||||
- Firewall rules restricting access to management routes
|
||||
- Proxy-level authentication (HTTP Basic Auth, OAuth)
|
||||
- Network-level isolation (VPN, Tailscale, Cloudflare Access)
|
||||
- Bind CLIProxyAPI to `127.0.0.1` only and access via SSH tunnel
|
||||
|
||||
3. **Example nginx configuration** (blocks external access to management routes):
|
||||
```nginx
|
||||
location /api/auth { deny all; }
|
||||
location /api/user { deny all; }
|
||||
location /api/threads { deny all; }
|
||||
location /api/internal { deny all; }
|
||||
```
|
||||
|
||||
**Important**: Only disable `amp-restrict-management-to-localhost` if you understand the security implications and have other protections in place.
|
||||
|
||||
## Setup
|
||||
|
||||
### 1. Configure CLIProxyAPI
|
||||
|
||||
Create or edit `config.yaml`:
|
||||
|
||||
```yaml
|
||||
port: 8317
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# Amp integration
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
amp-restrict-management-to-localhost: true
|
||||
|
||||
# Other standard settings...
|
||||
debug: false
|
||||
logging-to-file: true
|
||||
```
|
||||
|
||||
### 2. Authenticate with Providers
|
||||
|
||||
Run OAuth login for the providers you want to use:
|
||||
|
||||
**Google Account (Gemini 2.5 Pro, Gemini 2.5 Flash, Gemini 3 Pro Preview):**
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
**ChatGPT Plus/Pro (GPT-5, GPT-5 Codex):**
|
||||
```bash
|
||||
./cli-proxy-api --codex-login
|
||||
```
|
||||
|
||||
**Claude Pro/Max (Claude Sonnet 4.5, Opus 4.1):**
|
||||
```bash
|
||||
./cli-proxy-api --claude-login
|
||||
```
|
||||
|
||||
Tokens are saved to:
|
||||
- Gemini: `~/.cli-proxy-api/gemini-<email>.json`
|
||||
- OpenAI Codex: `~/.cli-proxy-api/codex-<email>.json`
|
||||
- Claude: `~/.cli-proxy-api/claude-<email>.json`
|
||||
|
||||
### 3. Start the Proxy
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --config config.yaml
|
||||
```
|
||||
|
||||
Or run in background with tmux (recommended for remote servers):
|
||||
|
||||
```bash
|
||||
tmux new-session -d -s proxy "./cli-proxy-api --config config.yaml"
|
||||
```
|
||||
|
||||
### 4. Configure Amp CLI
|
||||
|
||||
#### Option A: Settings File
|
||||
|
||||
Edit `~/.config/amp/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"amp.url": "http://localhost:8317"
|
||||
}
|
||||
```
|
||||
|
||||
#### Option B: Environment Variable
|
||||
|
||||
```bash
|
||||
export AMP_URL=http://localhost:8317
|
||||
```
|
||||
|
||||
### 5. Login and Use Amp
|
||||
|
||||
Login through the proxy (proxied to ampcode.com):
|
||||
|
||||
```bash
|
||||
amp login
|
||||
```
|
||||
|
||||
Use Amp as normal:
|
||||
|
||||
```bash
|
||||
amp "Write a hello world program in Python"
|
||||
```
|
||||
|
||||
### 6. (Optional) Configure Amp IDE Extension
|
||||
|
||||
The proxy also works with Amp IDE extensions for VS Code, Cursor, Windsurf, etc.
|
||||
|
||||
1. Open Amp extension settings in your IDE
|
||||
2. Set **Amp URL** to `http://localhost:8317`
|
||||
3. Login with your Amp account
|
||||
4. Start using Amp in your IDE
|
||||
|
||||
Both CLI and IDE can use the proxy simultaneously.
|
||||
|
||||
## Usage
|
||||
|
||||
### Supported Routes
|
||||
|
||||
#### Provider Aliases (Always Available)
|
||||
|
||||
These routes work even without `amp-upstream-url` configured:
|
||||
|
||||
- `/api/provider/openai/v1/chat/completions`
|
||||
- `/api/provider/openai/v1/responses`
|
||||
- `/api/provider/anthropic/v1/messages`
|
||||
- `/api/provider/google/v1beta/models/:action`
|
||||
|
||||
Amp CLI calls these routes with your OAuth-authenticated models configured in CLIProxyAPI.
|
||||
|
||||
#### Management Routes (Require `amp-upstream-url`)
|
||||
|
||||
These routes are proxied to ampcode.com:
|
||||
|
||||
- `/api/auth` - Authentication
|
||||
- `/api/user` - User profile
|
||||
- `/api/meta` - Metadata
|
||||
- `/api/threads` - Conversation threads
|
||||
- `/api/telemetry` - Usage telemetry
|
||||
- `/api/internal` - Internal APIs
|
||||
|
||||
**Security**: Restricted to localhost by default.
|
||||
|
||||
### Model Fallback Behavior
|
||||
|
||||
When Amp requests a model:
|
||||
|
||||
1. **Check local configuration**: Does CLIProxyAPI have OAuth tokens for this model's provider?
|
||||
2. **If YES**: Route to local handler (use your OAuth subscription)
|
||||
3. **If NO**: Check if a model mapping exists
|
||||
4. **If mapping exists**: Rewrite request to mapped model → Route to local handler (free)
|
||||
5. **If no mapping**: Forward to ampcode.com (uses Amp credits)
|
||||
|
||||
This enables seamless mixed usage:
|
||||
- Models you've configured (Gemini, ChatGPT, Claude) → Your OAuth subscriptions
|
||||
- Models with mappings configured → Routed to alternative local models (free)
|
||||
- Models you haven't configured and have no mapping → Amp's default providers (uses credits)
|
||||
|
||||
### Example API Calls
|
||||
|
||||
**Chat completion with local OAuth:**
|
||||
```bash
|
||||
curl http://localhost:8317/api/provider/openai/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**Management endpoint (localhost only):**
|
||||
```bash
|
||||
curl http://localhost:8317/api/user
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
|---------|--------------|-----|
|
||||
| 404 on `/api/provider/...` | Incorrect route path | Ensure exact path: `/api/provider/{provider}/v1...` |
|
||||
| 403 on `/api/user` | Non-localhost request | Run from same machine or disable `amp-restrict-management-to-localhost` (not recommended) |
|
||||
| 401/403 from provider | Missing/expired OAuth | Re-run `--codex-login` or `--claude-login` |
|
||||
| Amp gzip errors | Response decompression issue | Update to latest build; auto-decompression should handle this |
|
||||
| Models not using proxy | Wrong Amp URL | Verify `amp.url` setting or `AMP_URL` environment variable |
|
||||
| CORS errors | Protected management endpoint | Use CLI/terminal, not browser |
|
||||
|
||||
### Diagnostics
|
||||
|
||||
**Check proxy logs:**
|
||||
```bash
|
||||
# If logging-to-file: true
|
||||
tail -f logs/requests.log
|
||||
|
||||
# If running in tmux
|
||||
tmux attach-session -t proxy
|
||||
```
|
||||
|
||||
**Enable debug mode** (temporarily):
|
||||
```yaml
|
||||
debug: true
|
||||
```
|
||||
|
||||
**Test basic connectivity:**
|
||||
```bash
|
||||
# Check if proxy is running
|
||||
curl http://localhost:8317/v1/models
|
||||
|
||||
# Check Amp-specific route
|
||||
curl http://localhost:8317/api/provider/openai/v1/models
|
||||
```
|
||||
|
||||
**Verify Amp configuration:**
|
||||
```bash
|
||||
# Check if Amp is using proxy
|
||||
amp config get amp.url
|
||||
|
||||
# Or check environment
|
||||
echo $AMP_URL
|
||||
```
|
||||
|
||||
### Security Checklist
|
||||
|
||||
- ✅ Keep `amp-restrict-management-to-localhost: true` (default)
|
||||
- ✅ Don't expose proxy publicly (bind to localhost or use firewall/VPN)
|
||||
- ✅ Use the Amp secrets file (`~/.local/share/amp/secrets.json`) managed by `amp login`
|
||||
- ✅ Rotate OAuth tokens periodically by re-running login commands
|
||||
- ✅ Store config and auth-dir on encrypted disk if handling sensitive data
|
||||
- ✅ Keep proxy binary up to date for security fixes
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [CLIProxyAPI Main Documentation](https://help.router-for.me/)
|
||||
- [Amp CLI Official Manual](https://ampcode.com/manual)
|
||||
- [Management API Reference](https://help.router-for.me/management/api)
|
||||
- [SDK Documentation](sdk-usage.md)
|
||||
|
||||
## Disclaimer
|
||||
|
||||
This integration is for personal/educational use. Using reverse proxies or alternate API bases may violate provider Terms of Service. You are solely responsible for how you use this software. Accounts may be rate-limited, locked, or banned. No warranties. Use at your own risk.
|
||||
@@ -1,392 +0,0 @@
|
||||
# Amp CLI 集成指南
|
||||
|
||||
本指南说明如何在 Amp CLI 和 Amp IDE 扩展中使用 CLIProxyAPI,通过 OAuth 让你能够把已有的 Google/ChatGPT/Claude 订阅与 Amp 的 CLI 一起使用。
|
||||
|
||||
## 目录
|
||||
|
||||
- [概述](#概述)
|
||||
- [应该认证哪些服务提供商?](#应该认证哪些服务提供商)
|
||||
- [架构](#架构)
|
||||
- [配置](#配置)
|
||||
- [设置](#设置)
|
||||
- [用法](#用法)
|
||||
- [故障排查](#故障排查)
|
||||
|
||||
## 概述
|
||||
|
||||
Amp CLI 集成为 Amp 的 API 模式添加了专用路由,同时保持与现有 CLIProxyAPI 功能的完全兼容。这样你可以在同一个代理服务器上同时使用传统 CLIProxyAPI 功能和 Amp CLI。
|
||||
|
||||
### 主要特性
|
||||
|
||||
- **提供者路由别名**:将 Amp 的 `/api/provider/{provider}/v1...` 路径映射到 CLIProxyAPI 处理器
|
||||
- **管理代理**:将 OAuth 和账号管理请求转发到 Amp 控制平面
|
||||
- **智能回退**:自动将未配置的模型路由到 ampcode.com
|
||||
- **密钥管理**:可配置优先级(配置 > 环境变量 > 文件),缓存 5 分钟
|
||||
- **安全优先**:管理路由默认限制为 localhost
|
||||
- **自动 gzip 处理**:自动解压来自 Amp 上游的响应
|
||||
|
||||
### 你可以做什么
|
||||
|
||||
- 使用 Amp CLI 搭配你的 Google 账号(Gemini 3 Pro Preview、Gemini 2.5 Pro、Gemini 2.5 Flash)
|
||||
- 使用 Amp CLI 搭配你的 ChatGPT Plus/Pro 订阅(GPT-5、GPT-5 Codex 模型)
|
||||
- 使用 Amp CLI 搭配你的 Claude Pro/Max 订阅(Claude Sonnet 4.5、Opus 4.1)
|
||||
- 将 Amp IDE 扩展(VS Code、Cursor、Windsurf 等)与同一个代理一起使用
|
||||
- 通过一个代理同时运行多个 CLI 工具(Factory + Amp)
|
||||
- 将未配置的模型自动路由到 ampcode.com
|
||||
|
||||
### 应该认证哪些服务提供商?
|
||||
|
||||
**重要**:需要认证的提供商取决于你安装的 Amp 版本当前使用的模型和功能。Amp 的不同智能模式和子代理会使用不同的提供商:
|
||||
|
||||
- **Smart 模式**:使用 Google/Gemini 模型(Gemini 3 Pro)
|
||||
- **Rush 模式**:使用 Anthropic/Claude 模型(Claude Haiku 4.5)
|
||||
- **Oracle 子代理**:使用 OpenAI/GPT 模型(GPT-5 medium reasoning)
|
||||
- **Librarian 子代理**:使用 Anthropic/Claude 模型(Claude Sonnet 4.5)
|
||||
- **Search 子代理**:使用 Anthropic/Claude 模型(Claude Haiku 4.5)
|
||||
- **Review 功能**:使用 Google/Gemini 模型(Gemini 2.5 Flash-Lite)
|
||||
|
||||
有关 Amp 当前使用哪些模型的最新信息,请参阅 **[Amp 模型文档](https://ampcode.com/models)**。
|
||||
|
||||
#### 回退行为
|
||||
|
||||
CLIProxyAPI 采用智能回退机制:
|
||||
|
||||
1. **本地已认证提供商**(`--login`、`--codex-login`、`--claude-login`):
|
||||
- 请求使用**你的 OAuth 订阅**(ChatGPT Plus/Pro、Claude Pro/Max、Google 账号)
|
||||
- 享受订阅自带的额度
|
||||
- 不消耗 Amp 额度
|
||||
|
||||
2. **本地未认证提供商**:
|
||||
- 请求自动转发到 **ampcode.com**
|
||||
- 使用 Amp 的后端提供商连接
|
||||
- 如果提供商是付费的(OpenAI、Anthropic 付费档),**需要消耗 Amp 额度**
|
||||
- 若 Amp 额度不足,可能产生错误
|
||||
|
||||
**建议**:对你有订阅的所有提供商都进行认证,以最大化价值并尽量减少 Amp 额度消耗。如果没有覆盖 Amp 使用的全部提供商,请确保为回退请求准备足够的 Amp 额度。
|
||||
|
||||
## 架构
|
||||
|
||||
### 请求流
|
||||
|
||||
```
|
||||
Amp CLI/IDE
|
||||
↓
|
||||
├─ Provider API requests (/api/provider/{provider}/v1/...)
|
||||
│ ↓
|
||||
│ ├─ Model configured locally?
|
||||
│ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers)
|
||||
│ │ NO → Forward to ampcode.com (reverse proxy)
|
||||
│ ↓
|
||||
│ Response
|
||||
│
|
||||
└─ Management requests (/api/auth, /api/user, /api/threads, ...)
|
||||
↓
|
||||
├─ Localhost check (security)
|
||||
↓
|
||||
└─ Reverse proxy to ampcode.com
|
||||
↓
|
||||
Response (auto-decompressed if gzipped)
|
||||
```
|
||||
|
||||
### 组件
|
||||
|
||||
Amp 集成以模块化路由模块(`internal/api/modules/amp/`)实现,包含以下组件:
|
||||
|
||||
1. **路由别名**(`routes.go`):将 Amp 风格的路径映射到标准处理器
|
||||
2. **反向代理**(`proxy.go`):将管理请求转发到 ampcode.com
|
||||
3. **回退处理器**(`fallback_handlers.go`):将未配置的模型路由到 ampcode.com
|
||||
4. **密钥管理**(`secret.go`):多来源 API 密钥解析并带缓存
|
||||
5. **主模块**(`amp.go`):负责注册和配置
|
||||
|
||||
## 配置
|
||||
|
||||
### 基础配置
|
||||
|
||||
在 `config.yaml` 中新增以下字段:
|
||||
|
||||
```yaml
|
||||
# Amp 上游控制平面(管理路由必需)
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
|
||||
# 可选:覆盖 API key(否则使用环境变量或文件)
|
||||
# amp-upstream-api-key: "your-amp-api-key"
|
||||
|
||||
# 安全性:将管理路由限制为 localhost(推荐)
|
||||
amp-restrict-management-to-localhost: true
|
||||
```
|
||||
|
||||
### 密钥解析优先级
|
||||
|
||||
Amp 模块以如下优先级解析 API key:
|
||||
|
||||
| 来源 | 键名 | 优先级 | 缓存 |
|
||||
|------|------|--------|------|
|
||||
| 配置文件 | `amp-upstream-api-key` | 高 | 无 |
|
||||
| 环境变量 | `AMP_API_KEY` | 中 | 无 |
|
||||
| Amp 密钥文件 | `~/.local/share/amp/secrets.json` | 低 | 5 分钟 |
|
||||
|
||||
**建议**:日常使用时采用 Amp 密钥文件(最低优先级)。该文件由 `amp login` 自动管理。
|
||||
|
||||
### 安全设置
|
||||
|
||||
**`amp-restrict-management-to-localhost`**(默认:`true`)
|
||||
|
||||
启用后,管理路由(`/api/auth`、`/api/user`、`/api/threads` 等)只接受来自 localhost(127.0.0.1、::1)的连接,可防止:
|
||||
- 浏览器探测式攻击
|
||||
- 对管理端点的远程访问
|
||||
- 基于 CORS 的攻击
|
||||
- 伪造头攻击(例如 `X-Forwarded-For: 127.0.0.1`)
|
||||
|
||||
#### 工作原理
|
||||
|
||||
此限制使用**实际的 TCP 连接地址**(`RemoteAddr`),而非 `X-Forwarded-For` 等 HTTP 头,能防止头部伪造,但有重要影响:
|
||||
|
||||
- ✅ **直接连接可用**:在本机或服务器直接运行 CLIProxyAPI 时适用
|
||||
- ⚠️ **可能不适用于反向代理场景**:部署在 nginx、Cloudflare 等代理后,请求源会显示为代理 IP 而非 localhost
|
||||
|
||||
#### 反向代理部署
|
||||
|
||||
若需要在反向代理(nginx、Caddy、Cloudflare Tunnel 等)后运行 CLIProxyAPI:
|
||||
|
||||
1. **关闭 localhost 限制**:
|
||||
```yaml
|
||||
amp-restrict-management-to-localhost: false
|
||||
```
|
||||
|
||||
2. **使用替代安全措施**:
|
||||
- 防火墙规则限制管理路由访问
|
||||
- 代理层认证(HTTP Basic Auth、OAuth)
|
||||
- 网络隔离(VPN、Tailscale、Cloudflare Access)
|
||||
- 将 CLIProxyAPI 仅绑定 `127.0.0.1`,并通过 SSH 隧道访问
|
||||
|
||||
3. **nginx 示例配置**(阻止外部访问管理路由):
|
||||
```nginx
|
||||
location /api/auth { deny all; }
|
||||
location /api/user { deny all; }
|
||||
location /api/threads { deny all; }
|
||||
location /api/internal { deny all; }
|
||||
```
|
||||
|
||||
**重要**:只有在理解安全影响并已采取其他防护措施时,才关闭 `amp-restrict-management-to-localhost`。
|
||||
|
||||
## 设置
|
||||
|
||||
### 1. 配置 CLIProxyAPI
|
||||
|
||||
创建或编辑 `config.yaml`:
|
||||
|
||||
```yaml
|
||||
port: 8317
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# Amp 集成
|
||||
amp-upstream-url: "https://ampcode.com"
|
||||
amp-restrict-management-to-localhost: true
|
||||
|
||||
# 其他常规设置...
|
||||
debug: false
|
||||
logging-to-file: true
|
||||
```
|
||||
|
||||
### 2. 认证提供商
|
||||
|
||||
为要使用的提供商执行 OAuth 登录:
|
||||
|
||||
**Google 账号(Gemini 2.5 Pro、Gemini 2.5 Flash、Gemini 3 Pro Preview):**
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
**ChatGPT Plus/Pro(GPT-5、GPT-5 Codex):**
|
||||
```bash
|
||||
./cli-proxy-api --codex-login
|
||||
```
|
||||
|
||||
**Claude Pro/Max(Claude Sonnet 4.5、Opus 4.1):**
|
||||
```bash
|
||||
./cli-proxy-api --claude-login
|
||||
```
|
||||
|
||||
令牌会保存到:
|
||||
- Gemini: `~/.cli-proxy-api/gemini-<email>.json`
|
||||
- OpenAI Codex: `~/.cli-proxy-api/codex-<email>.json`
|
||||
- Claude: `~/.cli-proxy-api/claude-<email>.json`
|
||||
|
||||
### 3. 启动代理
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --config config.yaml
|
||||
```
|
||||
|
||||
或使用 tmux 在后台运行(推荐用于远程服务器):
|
||||
|
||||
```bash
|
||||
tmux new-session -d -s proxy "./cli-proxy-api --config config.yaml"
|
||||
```
|
||||
|
||||
### 4. 配置 Amp CLI
|
||||
|
||||
#### 方案 A:配置文件
|
||||
|
||||
编辑 `~/.config/amp/settings.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"amp.url": "http://localhost:8317"
|
||||
}
|
||||
```
|
||||
|
||||
#### 方案 B:环境变量
|
||||
|
||||
```bash
|
||||
export AMP_URL=http://localhost:8317
|
||||
```
|
||||
|
||||
### 5. 登录并使用 Amp
|
||||
|
||||
通过代理登录(请求会被代理到 ampcode.com):
|
||||
|
||||
```bash
|
||||
amp login
|
||||
```
|
||||
|
||||
像平常一样使用 Amp:
|
||||
|
||||
```bash
|
||||
amp "Write a hello world program in Python"
|
||||
```
|
||||
|
||||
### 6. (可选)配置 Amp IDE 扩展
|
||||
|
||||
该代理同样适用于 VS Code、Cursor、Windsurf 等 Amp IDE 扩展。
|
||||
|
||||
1. 在 IDE 中打开 Amp 扩展设置
|
||||
2. 将 **Amp URL** 设置为 `http://localhost:8317`
|
||||
3. 用你的 Amp 账号登录
|
||||
4. 在 IDE 中开始使用 Amp
|
||||
|
||||
CLI 和 IDE 可同时使用该代理。
|
||||
|
||||
## 用法
|
||||
|
||||
### 支持的路由
|
||||
|
||||
#### 提供商别名(始终可用)
|
||||
|
||||
这些路由即使未配置 `amp-upstream-url` 也可使用:
|
||||
|
||||
- `/api/provider/openai/v1/chat/completions`
|
||||
- `/api/provider/openai/v1/responses`
|
||||
- `/api/provider/anthropic/v1/messages`
|
||||
- `/api/provider/google/v1beta/models/:action`
|
||||
|
||||
Amp CLI 会使用你在 CLIProxyAPI 中通过 OAuth 认证的模型来调用这些路由。
|
||||
|
||||
#### 管理路由(需要 `amp-upstream-url`)
|
||||
|
||||
这些路由会被代理到 ampcode.com:
|
||||
|
||||
- `/api/auth` - 认证
|
||||
- `/api/user` - 用户资料
|
||||
- `/api/meta` - 元数据
|
||||
- `/api/threads` - 会话线程
|
||||
- `/api/telemetry` - 使用遥测
|
||||
- `/api/internal` - 内部 API
|
||||
|
||||
**安全性**:默认限制为 localhost。
|
||||
|
||||
### 模型回退行为
|
||||
|
||||
当 Amp 请求模型时:
|
||||
|
||||
1. **检查本地配置**:CLIProxyAPI 是否有该模型提供商的 OAuth 令牌?
|
||||
2. **如果有**:路由到本地处理器(使用你的 OAuth 订阅)
|
||||
3. **如果没有**:转发到 ampcode.com(使用 Amp 的默认路由)
|
||||
|
||||
这实现了无缝混用:
|
||||
- 你已配置的模型(Gemini、ChatGPT、Claude)→ 你的 OAuth 订阅
|
||||
- 未配置的模型 → Amp 的默认提供商
|
||||
|
||||
### 示例 API 调用
|
||||
|
||||
**使用本地 OAuth 的聊天补全:**
|
||||
```bash
|
||||
curl http://localhost:8317/api/provider/openai/v1/chat/completions \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"model": "gpt-5",
|
||||
"messages": [{"role": "user", "content": "Hello"}]
|
||||
}'
|
||||
```
|
||||
|
||||
**管理端点(仅限 localhost):**
|
||||
```bash
|
||||
curl http://localhost:8317/api/user
|
||||
```
|
||||
|
||||
## 故障排查
|
||||
|
||||
### 常见问题
|
||||
|
||||
| 症状 | 可能原因 | 解决方案 |
|
||||
|------|----------|----------|
|
||||
| `/api/provider/...` 返回 404 | 路径错误 | 确保路径准确:`/api/provider/{provider}/v1...` |
|
||||
| `/api/user` 返回 403 | 非 localhost 请求 | 在同一机器上访问,或关闭 `amp-restrict-management-to-localhost`(不推荐) |
|
||||
| 提供商返回 401/403 | OAuth 缺失或过期 | 重新运行 `--codex-login` 或 `--claude-login` |
|
||||
| Amp gzip 错误 | 响应解压问题 | 更新到最新构建;自动解压应能处理 |
|
||||
| 模型未走代理 | Amp URL 设置错误 | 检查 `amp.url` 设置或 `AMP_URL` 环境变量 |
|
||||
| CORS 错误 | 受保护的管理端点 | 使用 CLI/终端而非浏览器 |
|
||||
|
||||
### 诊断
|
||||
|
||||
**查看代理日志:**
|
||||
```bash
|
||||
# 若 logging-to-file: true
|
||||
tail -f logs/requests.log
|
||||
|
||||
# 若运行在 tmux 中
|
||||
tmux attach-session -t proxy
|
||||
```
|
||||
|
||||
**临时开启调试模式:**
|
||||
```yaml
|
||||
debug: true
|
||||
```
|
||||
|
||||
**测试基础连通性:**
|
||||
```bash
|
||||
# 检查代理是否运行
|
||||
curl http://localhost:8317/v1/models
|
||||
|
||||
# 检查 Amp 特定路由
|
||||
curl http://localhost:8317/api/provider/openai/v1/models
|
||||
```
|
||||
|
||||
**验证 Amp 配置:**
|
||||
```bash
|
||||
# 检查 Amp 是否使用代理
|
||||
amp config get amp.url
|
||||
|
||||
# 或检查环境变量
|
||||
echo $AMP_URL
|
||||
```
|
||||
|
||||
### 安全清单
|
||||
|
||||
- ✅ 保持 `amp-restrict-management-to-localhost: true`(默认)
|
||||
- ✅ 不要将代理暴露到公共网络(绑定到 localhost 或使用防火墙/VPN)
|
||||
- ✅ 使用 `amp login` 管理的 Amp 密钥文件(`~/.local/share/amp/secrets.json`)
|
||||
- ✅ 定期重新登录轮换 OAuth 令牌
|
||||
- ✅ 若处理敏感数据,使用加密磁盘存储配置和 auth-dir
|
||||
- ✅ 保持代理二进制为最新版本以获取安全修复
|
||||
|
||||
## 其他资源
|
||||
|
||||
- [CLIProxyAPI 主文档](https://help.router-for.me/)
|
||||
- [Amp CLI 官方手册](https://ampcode.com/manual)
|
||||
- [管理 API 参考](https://help.router-for.me/management/api)
|
||||
- [SDK 文档](sdk-usage.md)
|
||||
|
||||
## 免责声明
|
||||
|
||||
此集成仅用于个人或教育用途。使用反向代理或替代 API 基址可能违反提供商的服务条款。你需要对自己的使用方式负责。账号可能会被限速、锁定或封禁。软件不附带任何保证,使用风险自负。
|
||||
@@ -23,13 +23,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/logging"
|
||||
sdktr "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
|
||||
10
go.mod
10
go.mod
@@ -18,8 +18,8 @@ require (
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
github.com/tiktoken-go/tokenizer v0.7.0
|
||||
golang.org/x/crypto v0.43.0
|
||||
golang.org/x/net v0.46.0
|
||||
golang.org/x/crypto v0.45.0
|
||||
golang.org/x/net v0.47.0
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
@@ -68,9 +68,9 @@ require (
|
||||
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/sync v0.17.0 // indirect
|
||||
golang.org/x/sys v0.37.0 // indirect
|
||||
golang.org/x/text v0.30.0 // indirect
|
||||
golang.org/x/sync v0.18.0 // indirect
|
||||
golang.org/x/sys v0.38.0 // indirect
|
||||
golang.org/x/text v0.31.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
gopkg.in/ini.v1 v1.67.0 // indirect
|
||||
)
|
||||
|
||||
24
go.sum
24
go.sum
@@ -160,22 +160,22 @@ github.com/ugorji/go/codec v1.2.12/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZ
|
||||
golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
|
||||
golang.org/x/arch v0.8.0 h1:3wRIsP3pM4yUptoR96otTUOXI367OS0+c9eeRi9doIc=
|
||||
golang.org/x/arch v0.8.0/go.mod h1:FEVrYAQjsQXMVJ1nsMoVVXPZg6p2JE2mx8psSWTDQys=
|
||||
golang.org/x/crypto v0.43.0 h1:dduJYIi3A3KOfdGOHX8AVZ/jGiyPa3IbBozJ5kNuE04=
|
||||
golang.org/x/crypto v0.43.0/go.mod h1:BFbav4mRNlXJL4wNeejLpWxB7wMbc79PdRGhWKncxR0=
|
||||
golang.org/x/net v0.46.0 h1:giFlY12I07fugqwPuWJi68oOnpfqFnJIJzaIIm2JVV4=
|
||||
golang.org/x/net v0.46.0/go.mod h1:Q9BGdFy1y4nkUwiLvT5qtyhAnEHgnQ/zd8PfU6nc210=
|
||||
golang.org/x/crypto v0.45.0 h1:jMBrvKuj23MTlT0bQEOBcAE0mjg8mK9RXFhRH6nyF3Q=
|
||||
golang.org/x/crypto v0.45.0/go.mod h1:XTGrrkGJve7CYK7J8PEww4aY7gM3qMCElcJQ8n8JdX4=
|
||||
golang.org/x/net v0.47.0 h1:Mx+4dIFzqraBXUugkia1OOvlD6LemFo1ALMHjrXDOhY=
|
||||
golang.org/x/net v0.47.0/go.mod h1:/jNxtkgq5yWUGYkaZGqo27cfGZ1c5Nen03aYrrKpVRU=
|
||||
golang.org/x/oauth2 v0.30.0 h1:dnDm7JmhM45NNpd8FDDeLhK6FwqbOf4MLCM9zb1BOHI=
|
||||
golang.org/x/oauth2 v0.30.0/go.mod h1:B++QgG3ZKulg6sRPGD/mqlHQs5rB3Ml9erfeDY7xKlU=
|
||||
golang.org/x/sync v0.17.0 h1:l60nONMj9l5drqw6jlhIELNv9I0A4OFgRsG9k2oT9Ug=
|
||||
golang.org/x/sync v0.17.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sync v0.18.0 h1:kr88TuHDroi+UVf+0hZnirlk8o8T+4MrK6mr60WkH/I=
|
||||
golang.org/x/sync v0.18.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.37.0 h1:fdNQudmxPjkdUTPnLn5mdQv7Zwvbvpaxqs831goi9kQ=
|
||||
golang.org/x/sys v0.37.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q=
|
||||
golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss=
|
||||
golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k=
|
||||
golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM=
|
||||
golang.org/x/sys v0.38.0 h1:3yZWxaJjBmCWXqhN1qh02AkOnCQ1poK6oF+a7xWL6Gc=
|
||||
golang.org/x/sys v0.38.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks=
|
||||
golang.org/x/term v0.37.0 h1:8EGAD0qCmHYZg6J17DvsMy9/wJ7/D/4pV/wfnld5lTU=
|
||||
golang.org/x/term v0.37.0/go.mod h1:5pB4lxRNYYVZuTLmy8oR2BH8dflOR+IbTYFD8fi3254=
|
||||
golang.org/x/text v0.31.0 h1:aC8ghyu4JhP8VojJ2lEHBnochRno1sgL6nEi9WGFGMM=
|
||||
golang.org/x/text v0.31.0/go.mod h1:tKRAlv61yKIjGGHX/4tP1LTbc13YSec1pxVEWXzfoeM=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
|
||||
@@ -26,6 +26,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -35,10 +36,6 @@ import (
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
var (
|
||||
oauthStatus = make(map[string]string)
|
||||
)
|
||||
|
||||
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
||||
|
||||
const (
|
||||
@@ -266,6 +263,54 @@ func (h *Handler) ListAuthFiles(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"files": files})
|
||||
}
|
||||
|
||||
// GetAuthFileModels returns the models supported by a specific auth file
|
||||
func (h *Handler) GetAuthFileModels(c *gin.Context) {
|
||||
name := c.Query("name")
|
||||
if name == "" {
|
||||
c.JSON(400, gin.H{"error": "name is required"})
|
||||
return
|
||||
}
|
||||
|
||||
// Try to find auth ID via authManager
|
||||
var authID string
|
||||
if h.authManager != nil {
|
||||
auths := h.authManager.List()
|
||||
for _, auth := range auths {
|
||||
if auth.FileName == name || auth.ID == name {
|
||||
authID = auth.ID
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if authID == "" {
|
||||
authID = name // fallback to filename as ID
|
||||
}
|
||||
|
||||
// Get models from registry
|
||||
reg := registry.GetGlobalRegistry()
|
||||
models := reg.GetModelsForClient(authID)
|
||||
|
||||
result := make([]gin.H, 0, len(models))
|
||||
for _, m := range models {
|
||||
entry := gin.H{
|
||||
"id": m.ID,
|
||||
}
|
||||
if m.DisplayName != "" {
|
||||
entry["display_name"] = m.DisplayName
|
||||
}
|
||||
if m.Type != "" {
|
||||
entry["type"] = m.Type
|
||||
}
|
||||
if m.OwnedBy != "" {
|
||||
entry["owned_by"] = m.OwnedBy
|
||||
}
|
||||
result = append(result, entry)
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"models": result})
|
||||
}
|
||||
|
||||
// List auth files from disk when the auth manager is unavailable.
|
||||
func (h *Handler) listAuthFilesFromDisk(c *gin.Context) {
|
||||
entries, err := os.ReadDir(h.cfg.AuthDir)
|
||||
@@ -737,6 +782,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "anthropic")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
|
||||
@@ -763,7 +810,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
||||
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
|
||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||
}
|
||||
data, errRead := os.ReadFile(path)
|
||||
@@ -788,13 +835,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
if errStr := resultMap["error"]; errStr != "" {
|
||||
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||
log.Error(claude.GetUserFriendlyMessage(oauthErr))
|
||||
oauthStatus[state] = "Bad request"
|
||||
SetOAuthSessionError(state, "Bad request")
|
||||
return
|
||||
}
|
||||
if resultMap["state"] != state {
|
||||
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
|
||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||
oauthStatus[state] = "State code error"
|
||||
SetOAuthSessionError(state, "State code error")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -827,7 +874,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
if errDo != nil {
|
||||
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -838,7 +885,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
var tResp struct {
|
||||
@@ -851,7 +898,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
}
|
||||
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
||||
log.Errorf("failed to parse token response: %v", errU)
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
return
|
||||
}
|
||||
bundle := &claude.ClaudeAuthBundle{
|
||||
@@ -876,7 +923,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -885,10 +932,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use Claude services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -919,6 +965,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
|
||||
authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||
|
||||
RegisterOAuthSession(state, "gemini")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/google/callback")
|
||||
@@ -947,7 +995,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("oauth flow timed out")
|
||||
oauthStatus[state] = "OAuth flow timed out"
|
||||
SetOAuthSessionError(state, "OAuth flow timed out")
|
||||
return
|
||||
}
|
||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||
@@ -956,13 +1004,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := m["error"]; errStr != "" {
|
||||
log.Errorf("Authentication failed: %s", errStr)
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
return
|
||||
}
|
||||
authCode = m["code"]
|
||||
if authCode == "" {
|
||||
log.Errorf("Authentication failed: code not found")
|
||||
oauthStatus[state] = "Authentication failed: code not found"
|
||||
SetOAuthSessionError(state, "Authentication failed: code not found")
|
||||
return
|
||||
}
|
||||
break
|
||||
@@ -974,7 +1022,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
token, err := conf.Exchange(ctx, authCode)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to exchange token: %v", err)
|
||||
oauthStatus[state] = "Failed to exchange token"
|
||||
SetOAuthSessionError(state, "Failed to exchange token")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -985,7 +1033,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if errNewRequest != nil {
|
||||
log.Errorf("Could not get user info: %v", errNewRequest)
|
||||
oauthStatus[state] = "Could not get user info"
|
||||
SetOAuthSessionError(state, "Could not get user info")
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@@ -994,7 +1042,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
resp, errDo := authHTTPClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.Errorf("Failed to execute request: %v", errDo)
|
||||
oauthStatus[state] = "Failed to execute request"
|
||||
SetOAuthSessionError(state, "Failed to execute request")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1006,7 +1054,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1015,7 +1063,6 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
fmt.Printf("Authenticated user email: %s\n", email)
|
||||
} else {
|
||||
fmt.Println("Failed to get user email from token")
|
||||
oauthStatus[state] = "Failed to get user email from token"
|
||||
}
|
||||
|
||||
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
|
||||
@@ -1023,7 +1070,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
jsonData, _ := json.Marshal(token)
|
||||
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
|
||||
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
|
||||
oauthStatus[state] = "Failed to unmarshal token"
|
||||
SetOAuthSessionError(state, "Failed to unmarshal token")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1049,7 +1096,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
|
||||
if errGetClient != nil {
|
||||
log.Errorf("failed to get authenticated client: %v", errGetClient)
|
||||
oauthStatus[state] = "Failed to get authenticated client"
|
||||
SetOAuthSessionError(state, "Failed to get authenticated client")
|
||||
return
|
||||
}
|
||||
fmt.Println("Authentication successful.")
|
||||
@@ -1059,12 +1106,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
||||
if errAll != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
return
|
||||
}
|
||||
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.ProjectID = strings.Join(projects, ",")
|
||||
@@ -1072,26 +1119,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
} else {
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
log.Error("Onboarding did not return a project ID")
|
||||
oauthStatus[state] = "Failed to resolve project ID"
|
||||
SetOAuthSessionError(state, "Failed to resolve project ID")
|
||||
return
|
||||
}
|
||||
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the selected project")
|
||||
oauthStatus[state] = "Cloud AI API not enabled"
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1114,15 +1161,14 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save token to file: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save token to file"
|
||||
SetOAuthSessionError(state, "Failed to save token to file")
|
||||
return
|
||||
}
|
||||
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1158,6 +1204,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "codex")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/codex/callback")
|
||||
@@ -1186,7 +1234,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
if time.Now().After(deadline) {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
||||
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
|
||||
return
|
||||
}
|
||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||
@@ -1196,12 +1244,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
if errStr := m["error"]; errStr != "" {
|
||||
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||
log.Error(codex.GetUserFriendlyMessage(oauthErr))
|
||||
oauthStatus[state] = "Bad Request"
|
||||
SetOAuthSessionError(state, "Bad Request")
|
||||
return
|
||||
}
|
||||
if m["state"] != state {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
|
||||
oauthStatus[state] = "State code error"
|
||||
SetOAuthSessionError(state, "State code error")
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
return
|
||||
}
|
||||
@@ -1232,14 +1280,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
||||
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||
return
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
|
||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
return
|
||||
}
|
||||
@@ -1250,7 +1298,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
log.Errorf("failed to parse token response: %v", errU)
|
||||
return
|
||||
}
|
||||
@@ -1288,7 +1336,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
return
|
||||
}
|
||||
@@ -1297,10 +1345,9 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use Codex services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1341,6 +1388,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
params.Set("state", state)
|
||||
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
|
||||
|
||||
RegisterOAuthSession(state, "antigravity")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/antigravity/callback")
|
||||
@@ -1367,7 +1416,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("oauth flow timed out")
|
||||
oauthStatus[state] = "OAuth flow timed out"
|
||||
SetOAuthSessionError(state, "OAuth flow timed out")
|
||||
return
|
||||
}
|
||||
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
|
||||
@@ -1376,18 +1425,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
||||
log.Errorf("Authentication failed: %s", errStr)
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
return
|
||||
}
|
||||
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
|
||||
log.Errorf("Authentication failed: state mismatch")
|
||||
oauthStatus[state] = "Authentication failed: state mismatch"
|
||||
SetOAuthSessionError(state, "Authentication failed: state mismatch")
|
||||
return
|
||||
}
|
||||
authCode = strings.TrimSpace(payload["code"])
|
||||
if authCode == "" {
|
||||
log.Error("Authentication failed: code not found")
|
||||
oauthStatus[state] = "Authentication failed: code not found"
|
||||
SetOAuthSessionError(state, "Authentication failed: code not found")
|
||||
return
|
||||
}
|
||||
break
|
||||
@@ -1406,7 +1455,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
||||
if errNewRequest != nil {
|
||||
log.Errorf("Failed to build token request: %v", errNewRequest)
|
||||
oauthStatus[state] = "Failed to build token request"
|
||||
SetOAuthSessionError(state, "Failed to build token request")
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
@@ -1414,7 +1463,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.Errorf("Failed to execute token request: %v", errDo)
|
||||
oauthStatus[state] = "Failed to exchange token"
|
||||
SetOAuthSessionError(state, "Failed to exchange token")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1426,7 +1475,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1438,7 +1487,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
}
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
||||
log.Errorf("Failed to parse token response: %v", errDecode)
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1447,7 +1496,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if errInfoReq != nil {
|
||||
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
||||
oauthStatus[state] = "Failed to build user info request"
|
||||
SetOAuthSessionError(state, "Failed to build user info request")
|
||||
return
|
||||
}
|
||||
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
||||
@@ -1455,7 +1504,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
infoResp, errInfo := httpClient.Do(infoReq)
|
||||
if errInfo != nil {
|
||||
log.Errorf("Failed to execute user info request: %v", errInfo)
|
||||
oauthStatus[state] = "Failed to execute user info request"
|
||||
SetOAuthSessionError(state, "Failed to execute user info request")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1474,7 +1523,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
} else {
|
||||
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
||||
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1522,11 +1571,11 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save token to file: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save token to file"
|
||||
SetOAuthSessionError(state, "Failed to save token to file")
|
||||
return
|
||||
}
|
||||
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
if projectID != "" {
|
||||
fmt.Printf("Using GCP project: %s\n", projectID)
|
||||
@@ -1534,7 +1583,6 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
fmt.Println("You can now use Antigravity services through this CLI")
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1556,11 +1604,13 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
}
|
||||
authURL := deviceFlow.VerificationURIComplete
|
||||
|
||||
RegisterOAuthSession(state, "qwen")
|
||||
|
||||
go func() {
|
||||
fmt.Println("Waiting for authentication...")
|
||||
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
||||
if errPollForToken != nil {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errPollForToken)
|
||||
return
|
||||
}
|
||||
@@ -1579,16 +1629,15 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
fmt.Println("You can now use Qwen services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1601,6 +1650,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
||||
authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)
|
||||
|
||||
RegisterOAuthSession(state, "iflow")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
|
||||
@@ -1627,7 +1678,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
var resultMap map[string]string
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: timeout waiting for callback")
|
||||
return
|
||||
}
|
||||
@@ -1640,26 +1691,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %s\n", errStr)
|
||||
return
|
||||
}
|
||||
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: state mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(resultMap["code"])
|
||||
if code == "" {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: code missing")
|
||||
return
|
||||
}
|
||||
|
||||
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
|
||||
if errExchange != nil {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errExchange)
|
||||
return
|
||||
}
|
||||
@@ -1681,7 +1732,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
return
|
||||
}
|
||||
@@ -1691,10 +1742,9 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use iFlow services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -2130,16 +2180,24 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
}
|
||||
|
||||
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
state := c.Query("state")
|
||||
if err, ok := oauthStatus[state]; ok {
|
||||
if err != "" {
|
||||
c.JSON(200, gin.H{"status": "error", "error": err})
|
||||
} else {
|
||||
c.JSON(200, gin.H{"status": "wait"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
if state == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
return
|
||||
}
|
||||
delete(oauthStatus, state)
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
|
||||
return
|
||||
}
|
||||
|
||||
_, status, ok := GetOAuthSession(state)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
return
|
||||
}
|
||||
if status != "" {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "error", "error": status})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||
}
|
||||
|
||||
100
internal/api/handlers/management/oauth_callback.go
Normal file
100
internal/api/handlers/management/oauth_callback.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type oauthCallbackRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func (h *Handler) PostOAuthCallback(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"})
|
||||
return
|
||||
}
|
||||
|
||||
var req oauthCallbackRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
canonicalProvider, err := NormalizeOAuthProvider(req.Provider)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"})
|
||||
return
|
||||
}
|
||||
|
||||
state := strings.TrimSpace(req.State)
|
||||
code := strings.TrimSpace(req.Code)
|
||||
errMsg := strings.TrimSpace(req.Error)
|
||||
|
||||
if rawRedirect := strings.TrimSpace(req.RedirectURL); rawRedirect != "" {
|
||||
u, errParse := url.Parse(rawRedirect)
|
||||
if errParse != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid redirect_url"})
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
if state == "" {
|
||||
state = strings.TrimSpace(q.Get("state"))
|
||||
}
|
||||
if code == "" {
|
||||
code = strings.TrimSpace(q.Get("code"))
|
||||
}
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(q.Get("error"))
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(q.Get("error_description"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"})
|
||||
return
|
||||
}
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
|
||||
return
|
||||
}
|
||||
if code == "" && errMsg == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "code or error is required"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionProvider, sessionStatus, ok := GetOAuthSession(state)
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"})
|
||||
return
|
||||
}
|
||||
if sessionStatus != "" {
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(sessionProvider, canonicalProvider) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "provider does not match state"})
|
||||
return
|
||||
}
|
||||
|
||||
if _, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, canonicalProvider, state, code, errMsg); errWrite != nil {
|
||||
if errors.Is(errWrite, errOAuthSessionNotPending) {
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
}
|
||||
258
internal/api/handlers/management/oauth_sessions.go
Normal file
258
internal/api/handlers/management/oauth_sessions.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
oauthSessionTTL = 10 * time.Minute
|
||||
maxOAuthStateLength = 128
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidOAuthState = errors.New("invalid oauth state")
|
||||
errUnsupportedOAuthFlow = errors.New("unsupported oauth provider")
|
||||
errOAuthSessionNotPending = errors.New("oauth session is not pending")
|
||||
)
|
||||
|
||||
type oauthSession struct {
|
||||
Provider string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type oauthSessionStore struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
sessions map[string]oauthSession
|
||||
}
|
||||
|
||||
func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore {
|
||||
if ttl <= 0 {
|
||||
ttl = oauthSessionTTL
|
||||
}
|
||||
return &oauthSessionStore{
|
||||
ttl: ttl,
|
||||
sessions: make(map[string]oauthSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) {
|
||||
for state, session := range s.sessions {
|
||||
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
|
||||
delete(s.sessions, state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Register(state, provider string) {
|
||||
state = strings.TrimSpace(state)
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if state == "" || provider == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
s.sessions[state] = oauthSession{
|
||||
Provider: provider,
|
||||
Status: "",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(s.ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) SetError(state, message string) {
|
||||
state = strings.TrimSpace(state)
|
||||
message = strings.TrimSpace(message)
|
||||
if state == "" {
|
||||
return
|
||||
}
|
||||
if message == "" {
|
||||
message = "Authentication failed"
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
session.Status = message
|
||||
session.ExpiresAt = now.Add(s.ttl)
|
||||
s.sessions[state] = session
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Complete(state string) {
|
||||
state = strings.TrimSpace(state)
|
||||
if state == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
delete(s.sessions, state)
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
|
||||
state = strings.TrimSpace(state)
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) IsPending(state, provider string) bool {
|
||||
state = strings.TrimSpace(state)
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if session.Status != "" {
|
||||
return false
|
||||
}
|
||||
if provider == "" {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(session.Provider, provider)
|
||||
}
|
||||
|
||||
var oauthSessions = newOAuthSessionStore(oauthSessionTTL)
|
||||
|
||||
func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) }
|
||||
|
||||
func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) }
|
||||
|
||||
func CompleteOAuthSession(state string) { oauthSessions.Complete(state) }
|
||||
|
||||
func GetOAuthSession(state string) (provider string, status string, ok bool) {
|
||||
session, ok := oauthSessions.Get(state)
|
||||
if !ok {
|
||||
return "", "", false
|
||||
}
|
||||
return session.Provider, session.Status, true
|
||||
}
|
||||
|
||||
func IsOAuthSessionPending(state, provider string) bool {
|
||||
return oauthSessions.IsPending(state, provider)
|
||||
}
|
||||
|
||||
func ValidateOAuthState(state string) error {
|
||||
trimmed := strings.TrimSpace(state)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("%w: empty", errInvalidOAuthState)
|
||||
}
|
||||
if len(trimmed) > maxOAuthStateLength {
|
||||
return fmt.Errorf("%w: too long", errInvalidOAuthState)
|
||||
}
|
||||
if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") {
|
||||
return fmt.Errorf("%w: contains path separator", errInvalidOAuthState)
|
||||
}
|
||||
if strings.Contains(trimmed, "..") {
|
||||
return fmt.Errorf("%w: contains '..'", errInvalidOAuthState)
|
||||
}
|
||||
for _, r := range trimmed {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
default:
|
||||
return fmt.Errorf("%w: invalid character", errInvalidOAuthState)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NormalizeOAuthProvider(provider string) (string, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "anthropic", "claude":
|
||||
return "anthropic", nil
|
||||
case "codex", "openai":
|
||||
return "codex", nil
|
||||
case "gemini", "google":
|
||||
return "gemini", nil
|
||||
case "iflow", "i-flow":
|
||||
return "iflow", nil
|
||||
case "antigravity", "anti-gravity":
|
||||
return "antigravity", nil
|
||||
case "qwen":
|
||||
return "qwen", nil
|
||||
default:
|
||||
return "", errUnsupportedOAuthFlow
|
||||
}
|
||||
}
|
||||
|
||||
type oauthCallbackFilePayload struct {
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) {
|
||||
if strings.TrimSpace(authDir) == "" {
|
||||
return "", fmt.Errorf("auth dir is empty")
|
||||
}
|
||||
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state)
|
||||
filePath := filepath.Join(authDir, fileName)
|
||||
payload := oauthCallbackFilePayload{
|
||||
Code: strings.TrimSpace(code),
|
||||
State: strings.TrimSpace(state),
|
||||
Error: strings.TrimSpace(errorMessage),
|
||||
}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal oauth callback payload: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(filePath, data, 0o600); err != nil {
|
||||
return "", fmt.Errorf("write oauth callback file: %w", err)
|
||||
}
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) {
|
||||
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !IsOAuthSessionPending(state, canonicalProvider) {
|
||||
return "", errOAuthSessionNotPending
|
||||
}
|
||||
return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage)
|
||||
}
|
||||
@@ -71,22 +71,64 @@ func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
|
||||
n, err := w.ResponseWriter.Write(data)
|
||||
|
||||
// THEN: Handle logging based on response type
|
||||
if w.isStreaming {
|
||||
if w.isStreaming && w.chunkChannel != nil {
|
||||
// For streaming responses: Send to async logging channel (non-blocking)
|
||||
if w.chunkChannel != nil {
|
||||
select {
|
||||
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
||||
default: // Channel full, skip logging to avoid blocking
|
||||
}
|
||||
select {
|
||||
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
||||
default: // Channel full, skip logging to avoid blocking
|
||||
}
|
||||
} else {
|
||||
// For non-streaming responses: Buffer complete response
|
||||
return n, err
|
||||
}
|
||||
|
||||
if w.shouldBufferResponseBody() {
|
||||
w.body.Write(data)
|
||||
}
|
||||
|
||||
return n, err
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) shouldBufferResponseBody() bool {
|
||||
if w.logger != nil && w.logger.IsEnabled() {
|
||||
return true
|
||||
}
|
||||
if !w.logOnErrorOnly {
|
||||
return false
|
||||
}
|
||||
status := w.statusCode
|
||||
if status == 0 {
|
||||
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok && statusWriter != nil {
|
||||
status = statusWriter.Status()
|
||||
} else {
|
||||
status = http.StatusOK
|
||||
}
|
||||
}
|
||||
return status >= http.StatusBadRequest
|
||||
}
|
||||
|
||||
// WriteString wraps the underlying ResponseWriter's WriteString method to capture response data.
|
||||
// Some handlers (and fmt/io helpers) write via io.StringWriter; without this override, those writes
|
||||
// bypass Write() and would be missing from request logs.
|
||||
func (w *ResponseWriterWrapper) WriteString(data string) (int, error) {
|
||||
w.ensureHeadersCaptured()
|
||||
|
||||
// CRITICAL: Write to client first (zero latency)
|
||||
n, err := w.ResponseWriter.WriteString(data)
|
||||
|
||||
// THEN: Capture for logging
|
||||
if w.isStreaming && w.chunkChannel != nil {
|
||||
select {
|
||||
case w.chunkChannel <- []byte(data):
|
||||
default:
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
if w.shouldBufferResponseBody() {
|
||||
w.body.WriteString(data)
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
|
||||
// WriteHeader wraps the underlying ResponseWriter's WriteHeader method.
|
||||
// It captures the status code, detects if the response is streaming based on the Content-Type header,
|
||||
// and initializes the appropriate logging mechanism (standard or streaming).
|
||||
@@ -160,12 +202,16 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check request body for streaming indicators
|
||||
if w.requestInfo.Body != nil {
|
||||
// If a concrete Content-Type is already set (e.g., application/json for error responses),
|
||||
// treat it as non-streaming instead of inferring from the request payload.
|
||||
if strings.TrimSpace(contentType) != "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Only fall back to request payload hints when Content-Type is not set yet.
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
bodyStr := string(w.requestInfo.Body)
|
||||
if strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`) {
|
||||
return true
|
||||
}
|
||||
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
|
||||
}
|
||||
|
||||
return false
|
||||
@@ -221,7 +267,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if w.isStreaming {
|
||||
if w.isStreaming && w.streamWriter != nil {
|
||||
if w.chunkChannel != nil {
|
||||
close(w.chunkChannel)
|
||||
w.chunkChannel = nil
|
||||
@@ -233,24 +279,19 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
}
|
||||
|
||||
// Write API Request and Response to the streaming log before closing
|
||||
if w.streamWriter != nil {
|
||||
apiRequest := w.extractAPIRequest(c)
|
||||
if len(apiRequest) > 0 {
|
||||
_ = w.streamWriter.WriteAPIRequest(apiRequest)
|
||||
}
|
||||
apiResponse := w.extractAPIResponse(c)
|
||||
if len(apiResponse) > 0 {
|
||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||
}
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
apiRequest := w.extractAPIRequest(c)
|
||||
if len(apiRequest) > 0 {
|
||||
_ = w.streamWriter.WriteAPIRequest(apiRequest)
|
||||
}
|
||||
apiResponse := w.extractAPIResponse(c)
|
||||
if len(apiResponse) > 0 {
|
||||
_ = w.streamWriter.WriteAPIResponse(apiResponse)
|
||||
}
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
if forceLog {
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
w.streamWriter = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -335,26 +376,3 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
apiResponseErrors,
|
||||
)
|
||||
}
|
||||
|
||||
// Status returns the HTTP response status code captured by the wrapper.
|
||||
// It defaults to 200 if WriteHeader has not been called.
|
||||
func (w *ResponseWriterWrapper) Status() int {
|
||||
if w.statusCode == 0 {
|
||||
return 200 // Default status code
|
||||
}
|
||||
return w.statusCode
|
||||
}
|
||||
|
||||
// Size returns the size of the response body in bytes for non-streaming responses.
|
||||
// For streaming responses, it returns -1, as the total size is unknown.
|
||||
func (w *ResponseWriterWrapper) Size() int {
|
||||
if w.isStreaming {
|
||||
return -1 // Unknown size for streaming responses
|
||||
}
|
||||
return w.body.Len()
|
||||
}
|
||||
|
||||
// Written returns true if the response header has been written (i.e., a status code has been set).
|
||||
func (w *ResponseWriterWrapper) Written() bool {
|
||||
return w.statusCode != 0
|
||||
}
|
||||
|
||||
@@ -137,7 +137,8 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// Register management proxy routes once; middleware will gate access when upstream is unavailable.
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler)
|
||||
// Pass auth middleware to require valid API key for all management routes.
|
||||
m.registerManagementRoutes(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
// If no upstream URL, skip proxy routes but provider aliases are still available
|
||||
if upstreamURL == "" {
|
||||
@@ -187,9 +188,6 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
|
||||
if oldSettings != nil && oldSettings.RestrictManagementToLocalhost != newSettings.RestrictManagementToLocalhost {
|
||||
m.setRestrictToLocalhost(newSettings.RestrictManagementToLocalhost)
|
||||
if !newSettings.RestrictManagementToLocalhost {
|
||||
log.Warnf("amp management routes now accessible from any IP - this is insecure!")
|
||||
}
|
||||
}
|
||||
|
||||
newUpstreamURL := strings.TrimSpace(newSettings.UpstreamURL)
|
||||
|
||||
@@ -146,6 +146,9 @@ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
m := &AmpModule{enabled: true}
|
||||
ms := NewMultiSourceSecretWithPath("", p, time.Minute)
|
||||
m.secretSource = ms
|
||||
m.lastConfig = &config.AmpCode{
|
||||
UpstreamAPIKey: "old-key",
|
||||
}
|
||||
|
||||
// Warm the cache
|
||||
if _, err := ms.Get(context.Background()); err != nil {
|
||||
@@ -157,7 +160,7 @@ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update config - should invalidate cache
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x"}}); err != nil {
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x", UpstreamAPIKey: "new-key"}}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
|
||||
fields["cost"] = "amp_credits"
|
||||
fields["source"] = "ampcode.com"
|
||||
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
log.WithFields(fields).Warnf("forwarding to ampcode.com (uses amp credits) - model_id: %s | To use local provider, add to config: ampcode.model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
|
||||
case RouteTypeNoProvider:
|
||||
fields["cost"] = "none"
|
||||
@@ -134,7 +134,43 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
}
|
||||
|
||||
// Normalize model (handles dynamic thinking suffixes)
|
||||
normalizedModel, _ := util.NormalizeThinkingModel(modelName)
|
||||
normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName)
|
||||
thinkingSuffix := ""
|
||||
if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) {
|
||||
thinkingSuffix = modelName[len(normalizedModel):]
|
||||
}
|
||||
|
||||
resolveMappedModel := func() (string, []string) {
|
||||
if fh.modelMapper == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
mappedModel := fh.modelMapper.MapModel(modelName)
|
||||
if mappedModel == "" {
|
||||
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
||||
}
|
||||
mappedModel = strings.TrimSpace(mappedModel)
|
||||
if mappedModel == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
|
||||
// already specifies its own thinking suffix.
|
||||
if thinkingSuffix != "" {
|
||||
_, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel)
|
||||
if mappedThinkingMetadata == nil {
|
||||
mappedModel += thinkingSuffix
|
||||
}
|
||||
}
|
||||
|
||||
mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel)
|
||||
mappedProviders := util.GetProviderName(mappedBaseModel)
|
||||
if len(mappedProviders) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return mappedModel, mappedProviders
|
||||
}
|
||||
|
||||
// Track resolved model for logging (may change if mapping is applied)
|
||||
resolvedModel := normalizedModel
|
||||
@@ -147,21 +183,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
if forceMappings {
|
||||
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||
// This allows users to route Amp requests to their preferred OAuth providers
|
||||
if fh.modelMapper != nil {
|
||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||
// Mapping found - check if we have a provider for the mapped model
|
||||
mappedProviders := util.GetProviderName(mappedModel)
|
||||
if len(mappedProviders) > 0 {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
}
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
|
||||
// If no mapping applied, check for local providers
|
||||
@@ -174,21 +204,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
|
||||
if len(providers) == 0 {
|
||||
// No providers configured - check if we have a model mapping
|
||||
if fh.modelMapper != nil {
|
||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||
// Mapping found - check if we have a provider for the mapped model
|
||||
mappedProviders := util.GetProviderName(mappedModel)
|
||||
if len(mappedProviders) > 0 {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
}
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -222,14 +246,14 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
// Log: Model was mapped to another model
|
||||
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||
rewriter := NewResponseRewriter(c.Writer, normalizedModel)
|
||||
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||
c.Writer = rewriter
|
||||
// Filter Anthropic-Beta header only for local handling paths
|
||||
filterAntropicBetaHeader(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
rewriter.Flush()
|
||||
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, normalizedModel)
|
||||
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName)
|
||||
} else if len(providers) > 0 {
|
||||
// Log: Using local provider (free)
|
||||
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
||||
|
||||
73
internal/api/modules/amp/fallback_handlers_test.go
Normal file
73
internal/api/modules/amp/fallback_handlers_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{
|
||||
{ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-amp-fallback")
|
||||
|
||||
mapper := NewModelMapper([]config.AmpModelMapping{
|
||||
{From: "gpt-5.2", To: "test/gpt-5.2"},
|
||||
})
|
||||
|
||||
fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil)
|
||||
|
||||
handler := func(c *gin.Context) {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"model": req.Model,
|
||||
"seen_model": req.Model,
|
||||
})
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/chat/completions", fallback.WrapHandler(handler))
|
||||
|
||||
reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Model string `json:"model"`
|
||||
SeenModel string `json:"seen_model"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Failed to parse response JSON: %v", err)
|
||||
}
|
||||
|
||||
if resp.Model != "gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model)
|
||||
}
|
||||
if resp.SeenModel != "test/gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel)
|
||||
}
|
||||
}
|
||||
@@ -59,7 +59,8 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
}
|
||||
|
||||
// Verify target model has available providers
|
||||
providers := util.GetProviderName(targetModel)
|
||||
normalizedTarget, _ := util.NormalizeThinkingModel(targetModel)
|
||||
providers := util.GetProviderName(normalizedTarget)
|
||||
if len(providers) == 0 {
|
||||
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
||||
return ""
|
||||
|
||||
@@ -71,6 +71,25 @@ func TestModelMapper_MapModel_WithProvider(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{
|
||||
{ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-thinking")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "gpt-5.2-alias", To: "gpt-5.2(xhigh)"},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
result := mapper.MapModel("gpt-5.2-alias")
|
||||
if result != "gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected gpt-5.2(xhigh), got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{
|
||||
|
||||
@@ -41,6 +41,11 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
originalDirector(req)
|
||||
req.Host = parsed.Host
|
||||
|
||||
// Remove client's Authorization header - it was only used for CLI Proxy API authentication
|
||||
// We will set our own Authorization using the configured upstream-api-key
|
||||
req.Header.Del("Authorization")
|
||||
req.Header.Del("X-Api-Key")
|
||||
|
||||
// Preserve correlation headers for debugging
|
||||
if req.Header.Get("X-Request-ID") == "" {
|
||||
// Could generate one here if needed
|
||||
@@ -50,7 +55,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
// Users going through ampcode.com proxy are paying for the service and should get all features
|
||||
// including 1M context window (context-1m-2025-08-07)
|
||||
|
||||
// Inject API key from secret source (precedence: config > env > file)
|
||||
// Inject API key from secret source (only uses upstream-api-key from config)
|
||||
if key, err := secretSource.Get(req.Context()); err == nil && key != "" {
|
||||
req.Header.Set("X-Api-Key", key)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", key))
|
||||
|
||||
@@ -39,7 +39,13 @@ func (rw *ResponseRewriter) Write(data []byte) (int, error) {
|
||||
}
|
||||
|
||||
if rw.isStreaming {
|
||||
return rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||
if err == nil {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
return rw.body.Write(data)
|
||||
}
|
||||
|
||||
@@ -95,10 +95,25 @@ func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// wrapManagementAuth skips auth for selected management paths while keeping authentication elsewhere.
|
||||
func wrapManagementAuth(auth gin.HandlerFunc, prefixes ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
for _, prefix := range prefixes {
|
||||
if strings.HasPrefix(path, prefix) && (len(path) == len(prefix) || path[len(prefix)] == '/') {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
auth(c)
|
||||
}
|
||||
}
|
||||
|
||||
// registerManagementRoutes registers Amp management proxy routes
|
||||
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
||||
// Uses dynamic middleware and proxy getter for hot-reload support.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler) {
|
||||
// The auth middleware validates Authorization header against configured API keys.
|
||||
func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *handlers.BaseAPIHandler, auth gin.HandlerFunc) {
|
||||
ampAPI := engine.Group("/api")
|
||||
|
||||
// Always disable CORS for management routes to prevent browser-based attacks
|
||||
@@ -107,8 +122,11 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// Apply dynamic localhost-only restriction (hot-reloadable via m.IsRestrictedToLocalhost())
|
||||
ampAPI.Use(m.localhostOnlyMiddleware())
|
||||
|
||||
if !m.IsRestrictedToLocalhost() {
|
||||
log.Warn("amp management routes are NOT restricted to localhost - this is insecure!")
|
||||
// Apply authentication middleware - requires valid API key in Authorization header
|
||||
var authWithBypass gin.HandlerFunc
|
||||
if auth != nil {
|
||||
ampAPI.Use(auth)
|
||||
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs")
|
||||
}
|
||||
|
||||
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
||||
@@ -154,7 +172,14 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// Root-level routes that AMP CLI expects without /api prefix
|
||||
// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
|
||||
rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
|
||||
if authWithBypass != nil {
|
||||
rootMiddleware = append(rootMiddleware, authWithBypass)
|
||||
}
|
||||
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/docs/*path", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
@@ -262,7 +287,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
v1betaAmp := provider.Group("/v1beta")
|
||||
{
|
||||
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1betaAmp.POST("/models/:action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,9 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
m.setProxy(proxy)
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
m.registerManagementRoutes(r, base)
|
||||
m.registerManagementRoutes(r, base, nil)
|
||||
srv := httptest.NewServer(r)
|
||||
defer srv.Close()
|
||||
|
||||
managementPaths := []struct {
|
||||
path string
|
||||
@@ -63,11 +65,17 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
for _, path := range managementPaths {
|
||||
t.Run(path.path, func(t *testing.T) {
|
||||
proxyCalled = false
|
||||
req := httptest.NewRequest(path.method, path.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
req, err := http.NewRequest(path.method, srv.URL+path.path, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build request: %v", err)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
t.Fatalf("route %s not registered", path.path)
|
||||
}
|
||||
if !proxyCalled {
|
||||
|
||||
@@ -230,13 +230,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
|
||||
|
||||
// Create server instance
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s := &Server{
|
||||
engine: engine,
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
|
||||
cfg: cfg,
|
||||
accessManager: accessManager,
|
||||
requestLogger: requestLogger,
|
||||
@@ -334,8 +330,8 @@ func (s *Server) setupRoutes() {
|
||||
v1beta.Use(AuthMiddleware(s.accessManager))
|
||||
{
|
||||
v1beta.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1beta.POST("/models/:action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
v1beta.POST("/models/*action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
|
||||
// Root endpoint
|
||||
@@ -358,10 +354,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
// Persist to a temporary file keyed by state
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-anthropic-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -371,9 +368,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-codex-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -383,9 +382,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-gemini-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -395,9 +396,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-iflow-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -407,9 +410,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-antigravity-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -568,6 +573,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
|
||||
|
||||
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
|
||||
mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels)
|
||||
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||
@@ -580,6 +586,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||
}
|
||||
}
|
||||
@@ -608,7 +615,7 @@ func (s *Server) serveManagementControlPanel(c *gin.Context) {
|
||||
|
||||
if _, err := os.Stat(filePath); err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), managementasset.StaticDir(s.configFilePath), cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
c.AbortWithStatus(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
@@ -837,11 +844,20 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
}
|
||||
}
|
||||
|
||||
if oldCfg != nil && oldCfg.LoggingToFile != cfg.LoggingToFile {
|
||||
if err := logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
||||
if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
|
||||
log.Errorf("failed to reconfigure log output: %v", err)
|
||||
} else {
|
||||
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
|
||||
if oldCfg == nil {
|
||||
log.Debug("log output configuration refreshed")
|
||||
} else {
|
||||
if oldCfg.LoggingToFile != cfg.LoggingToFile {
|
||||
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
|
||||
}
|
||||
if oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
||||
log.Debugf("logs_max_total_size_mb updated from %d to %d", oldCfg.LogsMaxTotalSizeMB, cfg.LogsMaxTotalSizeMB)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -918,17 +934,11 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
// Save YAML snapshot for next comparison
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s.handlers.OpenAICompatProviders = providerNames
|
||||
|
||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||
|
||||
if !cfg.RemoteManagement.DisableControlPanel {
|
||||
staticDir := managementasset.StaticDir(s.configFilePath)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL)
|
||||
go managementasset.EnsureLatestManagementHTML(context.Background(), staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
}
|
||||
if s.mgmt != nil {
|
||||
s.mgmt.SetConfig(cfg)
|
||||
|
||||
@@ -12,14 +12,15 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
|
||||
// Config represents the application's configuration, loaded from a YAML file.
|
||||
type Config struct {
|
||||
config.SDKConfig `yaml:",inline"`
|
||||
SDKConfig `yaml:",inline"`
|
||||
// Host is the network host/interface on which the API server will bind.
|
||||
// Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access.
|
||||
Host string `yaml:"host" json:"-"`
|
||||
@@ -41,6 +42,10 @@ type Config struct {
|
||||
// LoggingToFile controls whether application logs are written to rotating files or stdout.
|
||||
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
|
||||
|
||||
// LogsMaxTotalSizeMB limits the total size (in MB) of log files under the logs directory.
|
||||
// When exceeded, the oldest log files are deleted until within the limit. Set to 0 to disable.
|
||||
LogsMaxTotalSizeMB int `yaml:"logs-max-total-size-mb" json:"logs-max-total-size-mb"`
|
||||
|
||||
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
|
||||
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
|
||||
|
||||
@@ -104,6 +109,9 @@ type RemoteManagement struct {
|
||||
SecretKey string `yaml:"secret-key"`
|
||||
// DisableControlPanel skips serving and syncing the bundled management UI when true.
|
||||
DisableControlPanel bool `yaml:"disable-control-panel"`
|
||||
// PanelGitHubRepository overrides the GitHub repository used to fetch the management panel asset.
|
||||
// Accepts either a repository URL (https://github.com/org/repo) or an API releases endpoint.
|
||||
PanelGitHubRepository string `yaml:"panel-github-repository"`
|
||||
}
|
||||
|
||||
// QuotaExceeded defines the behavior when API quota limits are exceeded.
|
||||
@@ -139,7 +147,7 @@ type AmpCode struct {
|
||||
|
||||
// RestrictManagementToLocalhost restricts Amp management routes (/api/user, /api/threads, etc.)
|
||||
// to only accept connections from localhost (127.0.0.1, ::1). When true, prevents drive-by
|
||||
// browser attacks and remote access to management endpoints. Default: true (recommended).
|
||||
// browser attacks and remote access to management endpoints. Default: false (API key auth is sufficient).
|
||||
RestrictManagementToLocalhost bool `yaml:"restrict-management-to-localhost" json:"restrict-management-to-localhost"`
|
||||
|
||||
// ModelMappings defines model name mappings for Amp CLI requests.
|
||||
@@ -182,6 +190,9 @@ type ClaudeKey struct {
|
||||
// APIKey is the authentication key for accessing Claude API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Claude API endpoint.
|
||||
// If empty, the default Claude API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
@@ -214,6 +225,9 @@ type CodexKey struct {
|
||||
// APIKey is the authentication key for accessing Codex API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Codex API endpoint.
|
||||
// If empty, the default Codex API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
@@ -234,6 +248,9 @@ type GeminiKey struct {
|
||||
// APIKey is the authentication key for accessing Gemini API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL optionally overrides the Gemini API endpoint.
|
||||
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
||||
|
||||
@@ -253,6 +270,9 @@ type OpenAICompatibility struct {
|
||||
// Name is the identifier for this OpenAI compatibility configuration.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the external OpenAI-compatible API endpoint.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
|
||||
@@ -325,9 +345,11 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Set defaults before unmarshal so that absent keys keep defaults.
|
||||
cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
|
||||
cfg.LoggingToFile = false
|
||||
cfg.LogsMaxTotalSizeMB = 0
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.DisableCooling = false
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = true // Default to secure: only localhost access
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
||||
if optional {
|
||||
// In cloud deploy mode, if YAML parsing fails, return empty config instead of error.
|
||||
@@ -363,6 +385,15 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
_ = SaveConfigPreserveCommentsUpdateNestedScalar(configFile, []string{"remote-management", "secret-key"}, hashed)
|
||||
}
|
||||
|
||||
cfg.RemoteManagement.PanelGitHubRepository = strings.TrimSpace(cfg.RemoteManagement.PanelGitHubRepository)
|
||||
if cfg.RemoteManagement.PanelGitHubRepository == "" {
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
}
|
||||
|
||||
if cfg.LogsMaxTotalSizeMB < 0 {
|
||||
cfg.LogsMaxTotalSizeMB = 0
|
||||
}
|
||||
|
||||
// Sync request authentication providers with inline API keys for backwards compatibility.
|
||||
syncInlineAccessProvider(&cfg)
|
||||
|
||||
@@ -411,6 +442,7 @@ func (cfg *Config) SanitizeOpenAICompatibility() {
|
||||
for i := range cfg.OpenAICompatibility {
|
||||
e := cfg.OpenAICompatibility[i]
|
||||
e.Name = strings.TrimSpace(e.Name)
|
||||
e.Prefix = normalizeModelPrefix(e.Prefix)
|
||||
e.BaseURL = strings.TrimSpace(e.BaseURL)
|
||||
e.Headers = NormalizeHeaders(e.Headers)
|
||||
if e.BaseURL == "" {
|
||||
@@ -431,6 +463,7 @@ func (cfg *Config) SanitizeCodexKeys() {
|
||||
out := make([]CodexKey, 0, len(cfg.CodexKey))
|
||||
for i := range cfg.CodexKey {
|
||||
e := cfg.CodexKey[i]
|
||||
e.Prefix = normalizeModelPrefix(e.Prefix)
|
||||
e.BaseURL = strings.TrimSpace(e.BaseURL)
|
||||
e.Headers = NormalizeHeaders(e.Headers)
|
||||
e.ExcludedModels = NormalizeExcludedModels(e.ExcludedModels)
|
||||
@@ -449,6 +482,7 @@ func (cfg *Config) SanitizeClaudeKeys() {
|
||||
}
|
||||
for i := range cfg.ClaudeKey {
|
||||
entry := &cfg.ClaudeKey[i]
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||
}
|
||||
@@ -468,6 +502,7 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
||||
if entry.APIKey == "" {
|
||||
continue
|
||||
}
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
@@ -481,6 +516,18 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
||||
cfg.GeminiKey = out
|
||||
}
|
||||
|
||||
func normalizeModelPrefix(prefix string) string {
|
||||
trimmed := strings.TrimSpace(prefix)
|
||||
trimmed = strings.Trim(trimmed, "/")
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(trimmed, "/") {
|
||||
return ""
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func syncInlineAccessProvider(cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
@@ -653,7 +700,7 @@ func sanitizeConfigForPersist(cfg *Config) *Config {
|
||||
}
|
||||
clone := *cfg
|
||||
clone.SDKConfig = cfg.SDKConfig
|
||||
clone.SDKConfig.Access = config.AccessConfig{}
|
||||
clone.SDKConfig.Access = AccessConfig{}
|
||||
return &clone
|
||||
}
|
||||
|
||||
|
||||
87
internal/config/sdk_config.go
Normal file
87
internal/config/sdk_config.go
Normal file
@@ -0,0 +1,87 @@
|
||||
// Package config provides configuration management for the CLI Proxy API server.
|
||||
// It handles loading and parsing YAML configuration files, and provides structured
|
||||
// access to application settings including server port, authentication directory,
|
||||
// debug settings, proxy configuration, and API keys.
|
||||
package config
|
||||
|
||||
// SDKConfig represents the application's configuration, loaded from a YAML file.
|
||||
type SDKConfig struct {
|
||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
||||
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
||||
// credentials as well.
|
||||
ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"`
|
||||
|
||||
// RequestLog enables or disables detailed request logging functionality.
|
||||
RequestLog bool `yaml:"request-log" json:"request-log"`
|
||||
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
|
||||
// Access holds request authentication provider configuration.
|
||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
||||
}
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
type AccessConfig struct {
|
||||
// Providers lists configured authentication providers.
|
||||
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
|
||||
}
|
||||
|
||||
// AccessProvider describes a request authentication provider entry.
|
||||
type AccessProvider struct {
|
||||
// Name is the instance identifier for the provider.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Type selects the provider implementation registered via the SDK.
|
||||
Type string `yaml:"type" json:"type"`
|
||||
|
||||
// SDK optionally names a third-party SDK module providing this provider.
|
||||
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
|
||||
|
||||
// APIKeys lists inline keys for providers that require them.
|
||||
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
|
||||
|
||||
// Config passes provider-specific options to the implementation.
|
||||
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
|
||||
AccessProviderTypeConfigAPIKey = "config-api-key"
|
||||
|
||||
// DefaultAccessProviderName is applied when no provider name is supplied.
|
||||
DefaultAccessProviderName = "config-inline"
|
||||
)
|
||||
|
||||
// ConfigAPIKeyProvider returns the first inline API key provider if present.
|
||||
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
for i := range c.Access.Providers {
|
||||
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
|
||||
if c.Access.Providers[i].Name == "" {
|
||||
c.Access.Providers[i].Name = DefaultAccessProviderName
|
||||
}
|
||||
return &c.Access.Providers[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
|
||||
// It returns nil when no keys are supplied.
|
||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
provider := &AccessProvider{
|
||||
Name: DefaultAccessProviderName,
|
||||
Type: AccessProviderTypeConfigAPIKey,
|
||||
APIKeys: append([]string(nil), keys...),
|
||||
}
|
||||
return provider
|
||||
}
|
||||
@@ -13,6 +13,9 @@ type VertexCompatKey struct {
|
||||
// Maps to the x-goog-api-key header.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Vertex-compatible API endpoint.
|
||||
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
||||
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
|
||||
@@ -53,6 +56,7 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||
if entry.APIKey == "" {
|
||||
continue
|
||||
}
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
if entry.BaseURL == "" {
|
||||
// BaseURL is required for Vertex API key entries
|
||||
|
||||
@@ -72,39 +72,45 @@ func SetupBaseLogger() {
|
||||
}
|
||||
|
||||
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
|
||||
func ConfigureLogOutput(loggingToFile bool) error {
|
||||
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
|
||||
// until the total size is within the limit.
|
||||
func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
|
||||
SetupBaseLogger()
|
||||
|
||||
writerMu.Lock()
|
||||
defer writerMu.Unlock()
|
||||
|
||||
logDir := "logs"
|
||||
if base := util.WritablePath(); base != "" {
|
||||
logDir = filepath.Join(base, "logs")
|
||||
}
|
||||
|
||||
protectedPath := ""
|
||||
if loggingToFile {
|
||||
logDir := "logs"
|
||||
if base := util.WritablePath(); base != "" {
|
||||
logDir = filepath.Join(base, "logs")
|
||||
}
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return fmt.Errorf("logging: failed to create log directory: %w", err)
|
||||
}
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
}
|
||||
protectedPath = filepath.Join(logDir, "main.log")
|
||||
logWriter = &lumberjack.Logger{
|
||||
Filename: filepath.Join(logDir, "main.log"),
|
||||
Filename: protectedPath,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 0,
|
||||
MaxAge: 0,
|
||||
Compress: false,
|
||||
}
|
||||
log.SetOutput(logWriter)
|
||||
return nil
|
||||
} else {
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
logWriter = nil
|
||||
}
|
||||
log.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
logWriter = nil
|
||||
}
|
||||
log.SetOutput(os.Stdout)
|
||||
configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -112,6 +118,8 @@ func closeLogOutputs() {
|
||||
writerMu.Lock()
|
||||
defer writerMu.Unlock()
|
||||
|
||||
stopLogDirCleanerLocked()
|
||||
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
logWriter = nil
|
||||
|
||||
166
internal/logging/log_dir_cleaner.go
Normal file
166
internal/logging/log_dir_cleaner.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const logDirCleanerInterval = time.Minute
|
||||
|
||||
var logDirCleanerCancel context.CancelFunc
|
||||
|
||||
func configureLogDirCleanerLocked(logDir string, maxTotalSizeMB int, protectedPath string) {
|
||||
stopLogDirCleanerLocked()
|
||||
|
||||
if maxTotalSizeMB <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
maxBytes := int64(maxTotalSizeMB) * 1024 * 1024
|
||||
if maxBytes <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
dir := strings.TrimSpace(logDir)
|
||||
if dir == "" {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
logDirCleanerCancel = cancel
|
||||
go runLogDirCleaner(ctx, filepath.Clean(dir), maxBytes, strings.TrimSpace(protectedPath))
|
||||
}
|
||||
|
||||
func stopLogDirCleanerLocked() {
|
||||
if logDirCleanerCancel == nil {
|
||||
return
|
||||
}
|
||||
logDirCleanerCancel()
|
||||
logDirCleanerCancel = nil
|
||||
}
|
||||
|
||||
func runLogDirCleaner(ctx context.Context, logDir string, maxBytes int64, protectedPath string) {
|
||||
ticker := time.NewTicker(logDirCleanerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
cleanOnce := func() {
|
||||
deleted, errClean := enforceLogDirSizeLimit(logDir, maxBytes, protectedPath)
|
||||
if errClean != nil {
|
||||
log.WithError(errClean).Warn("logging: failed to enforce log directory size limit")
|
||||
return
|
||||
}
|
||||
if deleted > 0 {
|
||||
log.Debugf("logging: removed %d old log file(s) to enforce log directory size limit", deleted)
|
||||
}
|
||||
}
|
||||
|
||||
cleanOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
cleanOnce()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func enforceLogDirSizeLimit(logDir string, maxBytes int64, protectedPath string) (int, error) {
|
||||
if maxBytes <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
dir := strings.TrimSpace(logDir)
|
||||
if dir == "" {
|
||||
return 0, nil
|
||||
}
|
||||
dir = filepath.Clean(dir)
|
||||
|
||||
entries, errRead := os.ReadDir(dir)
|
||||
if errRead != nil {
|
||||
if os.IsNotExist(errRead) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, errRead
|
||||
}
|
||||
|
||||
protected := strings.TrimSpace(protectedPath)
|
||||
if protected != "" {
|
||||
protected = filepath.Clean(protected)
|
||||
}
|
||||
|
||||
type logFile struct {
|
||||
path string
|
||||
size int64
|
||||
modTime time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
files []logFile
|
||||
total int64
|
||||
)
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !isLogFileName(name) {
|
||||
continue
|
||||
}
|
||||
info, errInfo := entry.Info()
|
||||
if errInfo != nil {
|
||||
continue
|
||||
}
|
||||
if !info.Mode().IsRegular() {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(dir, name)
|
||||
files = append(files, logFile{
|
||||
path: path,
|
||||
size: info.Size(),
|
||||
modTime: info.ModTime(),
|
||||
})
|
||||
total += info.Size()
|
||||
}
|
||||
|
||||
if total <= maxBytes {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].modTime.Before(files[j].modTime)
|
||||
})
|
||||
|
||||
deleted := 0
|
||||
for _, file := range files {
|
||||
if total <= maxBytes {
|
||||
break
|
||||
}
|
||||
if protected != "" && filepath.Clean(file.path) == protected {
|
||||
continue
|
||||
}
|
||||
if errRemove := os.Remove(file.path); errRemove != nil {
|
||||
log.WithError(errRemove).Warnf("logging: failed to remove old log file: %s", filepath.Base(file.path))
|
||||
continue
|
||||
}
|
||||
total -= file.size
|
||||
deleted++
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func isLogFileName(name string) bool {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
lower := strings.ToLower(trimmed)
|
||||
return strings.HasSuffix(lower, ".log") || strings.HasSuffix(lower, ".log.gz")
|
||||
}
|
||||
70
internal/logging/log_dir_cleaner_test.go
Normal file
70
internal/logging/log_dir_cleaner_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEnforceLogDirSizeLimitDeletesOldest(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
writeLogFile(t, filepath.Join(dir, "old.log"), 60, time.Unix(1, 0))
|
||||
writeLogFile(t, filepath.Join(dir, "mid.log"), 60, time.Unix(2, 0))
|
||||
protected := filepath.Join(dir, "main.log")
|
||||
writeLogFile(t, protected, 60, time.Unix(3, 0))
|
||||
|
||||
deleted, err := enforceLogDirSizeLimit(dir, 120, protected)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("expected 1 deleted file, got %d", deleted)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, "old.log")); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected old.log to be removed, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(dir, "mid.log")); err != nil {
|
||||
t.Fatalf("expected mid.log to remain, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(protected); err != nil {
|
||||
t.Fatalf("expected protected main.log to remain, stat error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceLogDirSizeLimitSkipsProtected(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
protected := filepath.Join(dir, "main.log")
|
||||
writeLogFile(t, protected, 200, time.Unix(1, 0))
|
||||
writeLogFile(t, filepath.Join(dir, "other.log"), 50, time.Unix(2, 0))
|
||||
|
||||
deleted, err := enforceLogDirSizeLimit(dir, 100, protected)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("expected 1 deleted file, got %d", deleted)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(protected); err != nil {
|
||||
t.Fatalf("expected protected main.log to remain, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(dir, "other.log")); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected other.log to be removed, stat error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeLogFile(t *testing.T, path string, size int, modTime time.Time) {
|
||||
t.Helper()
|
||||
|
||||
data := make([]byte, size)
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
if err := os.Chtimes(path, modTime, modTime); err != nil {
|
||||
t.Fatalf("set times: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -23,10 +24,10 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
managementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
defaultManagementReleaseURL = "https://api.github.com/repos/router-for-me/Cli-Proxy-API-Management-Center/releases/latest"
|
||||
managementAssetName = "management.html"
|
||||
httpUserAgent = "CLIProxyAPI-management-updater"
|
||||
updateCheckInterval = 3 * time.Hour
|
||||
)
|
||||
|
||||
// ManagementFileName exposes the control panel asset filename.
|
||||
@@ -97,7 +98,7 @@ func runAutoUpdater(ctx context.Context) {
|
||||
|
||||
configPath, _ := schedulerConfigPath.Load().(string)
|
||||
staticDir := StaticDir(configPath)
|
||||
EnsureLatestManagementHTML(ctx, staticDir, cfg.ProxyURL)
|
||||
EnsureLatestManagementHTML(ctx, staticDir, cfg.ProxyURL, cfg.RemoteManagement.PanelGitHubRepository)
|
||||
}
|
||||
|
||||
runOnce()
|
||||
@@ -181,7 +182,7 @@ func FilePath(configFilePath string) string {
|
||||
// EnsureLatestManagementHTML checks the latest management.html asset and updates the local copy when needed.
|
||||
// The function is designed to run in a background goroutine and will never panic.
|
||||
// It enforces a 3-hour rate limit to avoid frequent checks on config/auth file changes.
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string) {
|
||||
func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL string, panelRepository string) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
@@ -214,6 +215,7 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
return
|
||||
}
|
||||
|
||||
releaseURL := resolveReleaseURL(panelRepository)
|
||||
client := newHTTPClient(proxyURL)
|
||||
|
||||
localPath := filepath.Join(staticDir, managementAssetName)
|
||||
@@ -225,7 +227,7 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
localHash = ""
|
||||
}
|
||||
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client)
|
||||
asset, remoteHash, err := fetchLatestAsset(ctx, client, releaseURL)
|
||||
if err != nil {
|
||||
log.WithError(err).Warn("failed to fetch latest management release information")
|
||||
return
|
||||
@@ -254,8 +256,44 @@ func EnsureLatestManagementHTML(ctx context.Context, staticDir string, proxyURL
|
||||
log.Infof("management asset updated successfully (hash=%s)", downloadedHash)
|
||||
}
|
||||
|
||||
func fetchLatestAsset(ctx context.Context, client *http.Client) (*releaseAsset, string, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, managementReleaseURL, nil)
|
||||
func resolveReleaseURL(repo string) string {
|
||||
repo = strings.TrimSpace(repo)
|
||||
if repo == "" {
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
parsed, err := url.Parse(repo)
|
||||
if err != nil || parsed.Host == "" {
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
host := strings.ToLower(parsed.Host)
|
||||
parsed.Path = strings.TrimSuffix(parsed.Path, "/")
|
||||
|
||||
if host == "api.github.com" {
|
||||
if !strings.HasSuffix(strings.ToLower(parsed.Path), "/releases/latest") {
|
||||
parsed.Path = parsed.Path + "/releases/latest"
|
||||
}
|
||||
return parsed.String()
|
||||
}
|
||||
|
||||
if host == "github.com" {
|
||||
parts := strings.Split(strings.Trim(parsed.Path, "/"), "/")
|
||||
if len(parts) >= 2 && parts[0] != "" && parts[1] != "" {
|
||||
repoName := strings.TrimSuffix(parts[1], ".git")
|
||||
return fmt.Sprintf("https://api.github.com/repos/%s/%s/releases/latest", parts[0], repoName)
|
||||
}
|
||||
}
|
||||
|
||||
return defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
func fetchLatestAsset(ctx context.Context, client *http.Client, releaseURL string) (*releaseAsset, string, error) {
|
||||
if strings.TrimSpace(releaseURL) == "" {
|
||||
releaseURL = defaultManagementReleaseURL
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, releaseURL, nil)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("create release request: %w", err)
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ func CodexInstructionsForModel(modelName, systemInstructions string) (bool, stri
|
||||
lastCodexMaxPrompt := ""
|
||||
last51Prompt := ""
|
||||
last52Prompt := ""
|
||||
last52CodexPrompt := ""
|
||||
// lastReviewPrompt := ""
|
||||
for _, entry := range entries {
|
||||
content, _ := codexInstructionsDir.ReadFile("codex_instructions/" + entry.Name())
|
||||
@@ -36,12 +37,16 @@ func CodexInstructionsForModel(modelName, systemInstructions string) (bool, stri
|
||||
last51Prompt = string(content)
|
||||
} else if strings.HasPrefix(entry.Name(), "gpt_5_2_prompt.md") {
|
||||
last52Prompt = string(content)
|
||||
} else if strings.HasPrefix(entry.Name(), "gpt-5.2-codex_prompt.md") {
|
||||
last52CodexPrompt = string(content)
|
||||
} else if strings.HasPrefix(entry.Name(), "review_prompt.md") {
|
||||
// lastReviewPrompt = string(content)
|
||||
}
|
||||
}
|
||||
if strings.Contains(modelName, "codex-max") {
|
||||
return false, lastCodexMaxPrompt
|
||||
} else if strings.Contains(modelName, "5.2-codex") {
|
||||
return false, last52CodexPrompt
|
||||
} else if strings.Contains(modelName, "codex") {
|
||||
return false, lastCodexPrompt
|
||||
} else if strings.Contains(modelName, "5.1") {
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
|
||||
|
||||
## General
|
||||
|
||||
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
||||
|
||||
## Editing constraints
|
||||
|
||||
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
|
||||
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like "Assigns the value to the variable", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
|
||||
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
|
||||
- You may be in a dirty git worktree.
|
||||
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
|
||||
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
|
||||
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
|
||||
* If the changes are in unrelated files, just ignore them and don't revert them.
|
||||
- Do not amend a commit unless explicitly requested to do so.
|
||||
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
|
||||
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
|
||||
|
||||
## Plan tool
|
||||
|
||||
When using the planning tool:
|
||||
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
|
||||
- Do not make single-step plans.
|
||||
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
|
||||
|
||||
## Codex CLI harness, sandboxing, and approvals
|
||||
|
||||
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
|
||||
|
||||
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
|
||||
- **read-only**: The sandbox only permits reading files.
|
||||
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
|
||||
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
|
||||
|
||||
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
|
||||
- **restricted**: Requires approval
|
||||
- **enabled**: No approval needed
|
||||
|
||||
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
|
||||
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe "read" commands.
|
||||
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
|
||||
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
|
||||
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
|
||||
|
||||
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
|
||||
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
|
||||
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
|
||||
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
|
||||
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
|
||||
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
|
||||
- (for all of these, you should weigh alternative paths that do not require approval)
|
||||
|
||||
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
|
||||
|
||||
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
|
||||
|
||||
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to "never", in which case never ask for approvals.
|
||||
|
||||
When requesting approval to execute a command that will require escalated privileges:
|
||||
- Provide the `sandbox_permissions` parameter with the value `"require_escalated"`
|
||||
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
|
||||
|
||||
## Special user requests
|
||||
|
||||
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
|
||||
- If the user asks for a "review", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
|
||||
|
||||
## Frontend tasks
|
||||
When doing frontend design tasks, avoid collapsing into "AI slop" or safe, average-looking layouts.
|
||||
Aim for interfaces that feel intentional, bold, and a bit surprising.
|
||||
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
|
||||
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
|
||||
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
|
||||
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
|
||||
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
|
||||
- Ensure the page loads properly on both desktop and mobile
|
||||
|
||||
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
|
||||
|
||||
## Presenting your work and final message
|
||||
|
||||
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||
|
||||
- Default: be very concise; friendly coding teammate tone.
|
||||
- Ask only when needed; suggest ideas; mirror the user's style.
|
||||
- For substantial work, summarize clearly; follow final‑answer formatting.
|
||||
- Skip heavy formatting for simple confirmations.
|
||||
- Don't dump large files you've written; reference paths only.
|
||||
- No "save/copy this file" - User is on the same machine.
|
||||
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
|
||||
- For code changes:
|
||||
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with "summary", just jump right in.
|
||||
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
|
||||
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
|
||||
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
|
||||
|
||||
### Final answer structure and style guidelines
|
||||
|
||||
- Plain text; CLI handles styling. Use structure only when it helps scanability.
|
||||
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
|
||||
- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
|
||||
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
|
||||
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
|
||||
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
|
||||
- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no "above/below"; parallel wording.
|
||||
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
|
||||
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
|
||||
- File References: When referencing files in your response follow the below rules:
|
||||
* Use inline code to make file paths clickable.
|
||||
* Each reference should have a stand alone path. Even if it's the same file.
|
||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||
* Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||
* Do not use URIs like file://, vscode://, or https://.
|
||||
* Do not provide range of lines
|
||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
|
||||
@@ -160,7 +160,22 @@ func GetGeminiModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Gemini 3 Flash Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
@@ -175,7 +190,7 @@ func GetGeminiModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -240,7 +255,22 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
@@ -255,7 +285,7 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -317,11 +347,26 @@ func GetGeminiCLIModels() []*ModelInfo {
|
||||
Name: "models/gemini-3-pro-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Pro Preview",
|
||||
Description: "Gemini 3 Pro Preview",
|
||||
Description: "Our most intelligent model with SOTA reasoning and multimodal understanding, and powerful agentic and vibe coding capabilities",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -387,7 +432,22 @@ func GetAIStudioModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-pro-latest",
|
||||
@@ -582,6 +642,20 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2-codex",
|
||||
Object: "model",
|
||||
Created: 1765440000,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.2",
|
||||
DisplayName: "GPT 5.2 Codex",
|
||||
Description: "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
|
||||
ContextLength: 400000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -630,6 +704,13 @@ func GetQwenModels() []*ModelInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// iFlowThinkingSupport is a shared ThinkingSupport configuration for iFlow models
|
||||
// that support thinking mode via chat_template_kwargs.enable_thinking (boolean toggle).
|
||||
// Uses level-based configuration so standard normalization flows apply before conversion.
|
||||
var iFlowThinkingSupport = &ThinkingSupport{
|
||||
Levels: []string{"none", "auto", "minimal", "low", "medium", "high", "xhigh"},
|
||||
}
|
||||
|
||||
// GetIFlowModels returns supported models for iFlow OAuth accounts.
|
||||
func GetIFlowModels() []*ModelInfo {
|
||||
entries := []struct {
|
||||
@@ -645,19 +726,20 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2", Created: 1764576000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000},
|
||||
{ID: "deepseek-v3.1", DisplayName: "DeepSeek-V3.1-Terminus", Description: "DeepSeek V3.1 Terminus", Created: 1756339200},
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||
{ID: "deepseek-v3", DisplayName: "DeepSeek-V3-671B", Description: "DeepSeek V3 671B", Created: 1734307200},
|
||||
{ID: "qwen3-32b", DisplayName: "Qwen3-32B", Description: "Qwen3 32B", Created: 1747094400},
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
@@ -690,8 +772,9 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
|
||||
"gemini-2.5-computer-use-preview-10-2025": {Name: "models/gemini-2.5-computer-use-preview-10-2025"},
|
||||
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-preview"},
|
||||
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-image-preview"},
|
||||
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-preview"},
|
||||
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-image-preview"},
|
||||
"gemini-3-flash-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, Name: "models/gemini-3-flash-preview"},
|
||||
"gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
}
|
||||
|
||||
@@ -90,6 +90,9 @@ type ModelRegistry struct {
|
||||
models map[string]*ModelRegistration
|
||||
// clientModels maps client ID to the models it provides
|
||||
clientModels map[string][]string
|
||||
// clientModelInfos maps client ID to a map of model ID -> ModelInfo
|
||||
// This preserves the original model info provided by each client
|
||||
clientModelInfos map[string]map[string]*ModelInfo
|
||||
// clientProviders maps client ID to its provider identifier
|
||||
clientProviders map[string]string
|
||||
// mutex ensures thread-safe access to the registry
|
||||
@@ -104,10 +107,11 @@ var registryOnce sync.Once
|
||||
func GetGlobalRegistry() *ModelRegistry {
|
||||
registryOnce.Do(func() {
|
||||
globalRegistry = &ModelRegistry{
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientProviders: make(map[string]string),
|
||||
mutex: &sync.RWMutex{},
|
||||
models: make(map[string]*ModelRegistration),
|
||||
clientModels: make(map[string][]string),
|
||||
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
||||
clientProviders: make(map[string]string),
|
||||
mutex: &sync.RWMutex{},
|
||||
}
|
||||
})
|
||||
return globalRegistry
|
||||
@@ -144,6 +148,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
// No models supplied; unregister existing client state if present.
|
||||
r.unregisterClientInternal(clientID)
|
||||
delete(r.clientModels, clientID)
|
||||
delete(r.clientModelInfos, clientID)
|
||||
delete(r.clientProviders, clientID)
|
||||
misc.LogCredentialSeparator()
|
||||
return
|
||||
@@ -152,7 +157,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
now := time.Now()
|
||||
|
||||
oldModels, hadExisting := r.clientModels[clientID]
|
||||
oldProvider, _ := r.clientProviders[clientID]
|
||||
oldProvider := r.clientProviders[clientID]
|
||||
providerChanged := oldProvider != provider
|
||||
if !hadExisting {
|
||||
// Pure addition path.
|
||||
@@ -161,6 +166,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
r.addModelRegistration(modelID, provider, model, now)
|
||||
}
|
||||
r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
|
||||
// Store client's own model infos
|
||||
clientInfos := make(map[string]*ModelInfo, len(newModels))
|
||||
for id, m := range newModels {
|
||||
clientInfos[id] = cloneModelInfo(m)
|
||||
}
|
||||
r.clientModelInfos[clientID] = clientInfos
|
||||
if provider != "" {
|
||||
r.clientProviders[clientID] = provider
|
||||
} else {
|
||||
@@ -287,6 +298,12 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
||||
if len(rawModelIDs) > 0 {
|
||||
r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
|
||||
}
|
||||
// Update client's own model infos
|
||||
clientInfos := make(map[string]*ModelInfo, len(newModels))
|
||||
for id, m := range newModels {
|
||||
clientInfos[id] = cloneModelInfo(m)
|
||||
}
|
||||
r.clientModelInfos[clientID] = clientInfos
|
||||
if provider != "" {
|
||||
r.clientProviders[clientID] = provider
|
||||
} else {
|
||||
@@ -436,6 +453,7 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) {
|
||||
}
|
||||
|
||||
delete(r.clientModels, clientID)
|
||||
delete(r.clientModelInfos, clientID)
|
||||
if hasProvider {
|
||||
delete(r.clientProviders, clientID)
|
||||
}
|
||||
@@ -871,3 +889,44 @@ func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, erro
|
||||
|
||||
return "", fmt.Errorf("no available clients for any model in handler type: %s", handlerType)
|
||||
}
|
||||
|
||||
// GetModelsForClient returns the models registered for a specific client.
|
||||
// Parameters:
|
||||
// - clientID: The client identifier (typically auth file name or auth ID)
|
||||
//
|
||||
// Returns:
|
||||
// - []*ModelInfo: List of models registered for this client, nil if client not found
|
||||
func (r *ModelRegistry) GetModelsForClient(clientID string) []*ModelInfo {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
modelIDs, exists := r.clientModels[clientID]
|
||||
if !exists || len(modelIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try to use client-specific model infos first
|
||||
clientInfos := r.clientModelInfos[clientID]
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
result := make([]*ModelInfo, 0, len(modelIDs))
|
||||
for _, modelID := range modelIDs {
|
||||
if _, dup := seen[modelID]; dup {
|
||||
continue
|
||||
}
|
||||
seen[modelID] = struct{}{}
|
||||
|
||||
// Prefer client's own model info to preserve original type/owned_by
|
||||
if clientInfos != nil {
|
||||
if info, ok := clientInfos[modelID]; ok && info != nil {
|
||||
result = append(result, info)
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Fallback to global registry (for backwards compatibility)
|
||||
if reg, ok := r.models[modelID]; ok && reg.Info != nil {
|
||||
result = append(result, reg.Info)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -322,10 +322,11 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
payload = applyThinkingMetadata(payload, req.Metadata, req.Model)
|
||||
payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload)
|
||||
payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload)
|
||||
payload = util.ConvertThinkingLevelToBudget(payload)
|
||||
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload)
|
||||
payload = util.ConvertThinkingLevelToBudget(payload, req.Model, true)
|
||||
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload, true)
|
||||
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
||||
payload = fixGeminiImageAspectRatio(req.Model, payload)
|
||||
payload = applyPayloadConfig(e.cfg, req.Model, payload)
|
||||
@@ -384,8 +385,16 @@ func ensureColonSpacedJSON(payload []byte) []byte {
|
||||
|
||||
for i := 0; i < len(indented); i++ {
|
||||
ch := indented[i]
|
||||
if ch == '"' && (i == 0 || indented[i-1] != '\\') {
|
||||
inString = !inString
|
||||
if ch == '"' {
|
||||
// A quote is escaped only when preceded by an odd number of consecutive backslashes.
|
||||
// For example: "\\\"" keeps the quote inside the string, but "\\\\" closes the string.
|
||||
backslashes := 0
|
||||
for j := i - 1; j >= 0 && indented[j] == '\\'; j-- {
|
||||
backslashes++
|
||||
}
|
||||
if backslashes%2 == 0 {
|
||||
inString = !inString
|
||||
}
|
||||
}
|
||||
|
||||
if !inString {
|
||||
|
||||
@@ -32,15 +32,16 @@ import (
|
||||
const (
|
||||
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
// antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com"
|
||||
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityCountTokensPath = "/v1internal:countTokens"
|
||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
)
|
||||
|
||||
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
@@ -69,6 +70,10 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au
|
||||
|
||||
// Execute performs a non-streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
if strings.Contains(req.Model, "claude") {
|
||||
return e.executeClaudeNonStream(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return resp, errToken
|
||||
@@ -85,8 +90,10 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -160,6 +167,338 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// executeClaudeNonStream performs a claude non-streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return resp, errToken
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
auth = updatedAuth
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
var lastErr error
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
|
||||
if errReq != nil {
|
||||
err = errReq
|
||||
return resp, err
|
||||
}
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = errDo
|
||||
return resp, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errRead
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = errRead
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), bodyBytes...)
|
||||
lastErr = nil
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(resp *http.Response) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(nil, streamScannerBuffer)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
|
||||
// Filter usage metadata for all models
|
||||
// Only retain usage statistics in the terminal chunk
|
||||
line = FilterSSEUsageMetadata(line)
|
||||
|
||||
payload := jsonPayload(line)
|
||||
if payload == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: payload}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
} else {
|
||||
reporter.ensurePublished(ctx)
|
||||
}
|
||||
}(httpResp)
|
||||
|
||||
var buffer bytes.Buffer
|
||||
for chunk := range out {
|
||||
if chunk.Err != nil {
|
||||
return resp, chunk.Err
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
_, _ = buffer.Write(chunk.Payload)
|
||||
_, _ = buffer.Write([]byte("\n"))
|
||||
}
|
||||
}
|
||||
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
err = statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
case lastErr != nil:
|
||||
err = lastErr
|
||||
default:
|
||||
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
|
||||
responseTemplate := ""
|
||||
var traceID string
|
||||
var finishReason string
|
||||
var modelVersion string
|
||||
var responseID string
|
||||
var role string
|
||||
var usageRaw string
|
||||
parts := make([]map[string]interface{}, 0)
|
||||
var pendingKind string
|
||||
var pendingText strings.Builder
|
||||
var pendingThoughtSig string
|
||||
|
||||
flushPending := func() {
|
||||
if pendingKind == "" {
|
||||
return
|
||||
}
|
||||
text := pendingText.String()
|
||||
switch pendingKind {
|
||||
case "text":
|
||||
if strings.TrimSpace(text) == "" {
|
||||
pendingKind = ""
|
||||
pendingText.Reset()
|
||||
pendingThoughtSig = ""
|
||||
return
|
||||
}
|
||||
parts = append(parts, map[string]interface{}{"text": text})
|
||||
case "thought":
|
||||
if strings.TrimSpace(text) == "" && pendingThoughtSig == "" {
|
||||
pendingKind = ""
|
||||
pendingText.Reset()
|
||||
pendingThoughtSig = ""
|
||||
return
|
||||
}
|
||||
part := map[string]interface{}{"thought": true}
|
||||
part["text"] = text
|
||||
if pendingThoughtSig != "" {
|
||||
part["thoughtSignature"] = pendingThoughtSig
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
pendingKind = ""
|
||||
pendingText.Reset()
|
||||
pendingThoughtSig = ""
|
||||
}
|
||||
|
||||
normalizePart := func(partResult gjson.Result) map[string]interface{} {
|
||||
var m map[string]interface{}
|
||||
_ = json.Unmarshal([]byte(partResult.Raw), &m)
|
||||
if m == nil {
|
||||
m = map[string]interface{}{}
|
||||
}
|
||||
sig := partResult.Get("thoughtSignature").String()
|
||||
if sig == "" {
|
||||
sig = partResult.Get("thought_signature").String()
|
||||
}
|
||||
if sig != "" {
|
||||
m["thoughtSignature"] = sig
|
||||
delete(m, "thought_signature")
|
||||
}
|
||||
if inlineData, ok := m["inline_data"]; ok {
|
||||
m["inlineData"] = inlineData
|
||||
delete(m, "inline_data")
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
for _, line := range bytes.Split(stream, []byte("\n")) {
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
if len(trimmed) == 0 || !gjson.ValidBytes(trimmed) {
|
||||
continue
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(trimmed)
|
||||
responseNode := root.Get("response")
|
||||
if !responseNode.Exists() {
|
||||
if root.Get("candidates").Exists() {
|
||||
responseNode = root
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
responseTemplate = responseNode.Raw
|
||||
|
||||
if traceResult := root.Get("traceId"); traceResult.Exists() && traceResult.String() != "" {
|
||||
traceID = traceResult.String()
|
||||
}
|
||||
|
||||
if roleResult := responseNode.Get("candidates.0.content.role"); roleResult.Exists() {
|
||||
role = roleResult.String()
|
||||
}
|
||||
|
||||
if finishResult := responseNode.Get("candidates.0.finishReason"); finishResult.Exists() && finishResult.String() != "" {
|
||||
finishReason = finishResult.String()
|
||||
}
|
||||
|
||||
if modelResult := responseNode.Get("modelVersion"); modelResult.Exists() && modelResult.String() != "" {
|
||||
modelVersion = modelResult.String()
|
||||
}
|
||||
if responseIDResult := responseNode.Get("responseId"); responseIDResult.Exists() && responseIDResult.String() != "" {
|
||||
responseID = responseIDResult.String()
|
||||
}
|
||||
if usageResult := responseNode.Get("usageMetadata"); usageResult.Exists() {
|
||||
usageRaw = usageResult.Raw
|
||||
} else if usageResult := root.Get("usageMetadata"); usageResult.Exists() {
|
||||
usageRaw = usageResult.Raw
|
||||
}
|
||||
|
||||
if partsResult := responseNode.Get("candidates.0.content.parts"); partsResult.IsArray() {
|
||||
for _, part := range partsResult.Array() {
|
||||
hasFunctionCall := part.Get("functionCall").Exists()
|
||||
hasInlineData := part.Get("inlineData").Exists() || part.Get("inline_data").Exists()
|
||||
sig := part.Get("thoughtSignature").String()
|
||||
if sig == "" {
|
||||
sig = part.Get("thought_signature").String()
|
||||
}
|
||||
text := part.Get("text").String()
|
||||
thought := part.Get("thought").Bool()
|
||||
|
||||
if hasFunctionCall || hasInlineData {
|
||||
flushPending()
|
||||
parts = append(parts, normalizePart(part))
|
||||
continue
|
||||
}
|
||||
|
||||
if thought || part.Get("text").Exists() {
|
||||
kind := "text"
|
||||
if thought {
|
||||
kind = "thought"
|
||||
}
|
||||
if pendingKind != "" && pendingKind != kind {
|
||||
flushPending()
|
||||
}
|
||||
pendingKind = kind
|
||||
pendingText.WriteString(text)
|
||||
if kind == "thought" && sig != "" {
|
||||
pendingThoughtSig = sig
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
flushPending()
|
||||
parts = append(parts, normalizePart(part))
|
||||
}
|
||||
}
|
||||
}
|
||||
flushPending()
|
||||
|
||||
if responseTemplate == "" {
|
||||
responseTemplate = `{"candidates":[{"content":{"role":"model","parts":[]}}]}`
|
||||
}
|
||||
|
||||
partsJSON, _ := json.Marshal(parts)
|
||||
responseTemplate, _ = sjson.SetRaw(responseTemplate, "candidates.0.content.parts", string(partsJSON))
|
||||
if role != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.content.role", role)
|
||||
}
|
||||
if finishReason != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.finishReason", finishReason)
|
||||
}
|
||||
if modelVersion != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "modelVersion", modelVersion)
|
||||
}
|
||||
if responseID != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "responseId", responseID)
|
||||
}
|
||||
if usageRaw != "" {
|
||||
responseTemplate, _ = sjson.SetRaw(responseTemplate, "usageMetadata", usageRaw)
|
||||
} else if !gjson.Get(responseTemplate, "usageMetadata").Exists() {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.promptTokenCount", 0)
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.candidatesTokenCount", 0)
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.totalTokenCount", 0)
|
||||
}
|
||||
|
||||
output := `{"response":{},"traceId":""}`
|
||||
output, _ = sjson.SetRaw(output, "response", responseTemplate)
|
||||
if traceID != "" {
|
||||
output, _ = sjson.Set(output, "traceId", traceID)
|
||||
}
|
||||
return []byte(output)
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
ctx = context.WithValue(ctx, "alt", "")
|
||||
@@ -180,8 +519,10 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -312,9 +653,131 @@ func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Au
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// CountTokens counts tokens for the given request (not supported for Antigravity).
|
||||
func (e *AntigravityExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported"}
|
||||
// CountTokens counts tokens for the given request using the Antigravity API.
|
||||
func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return cliproxyexecutor.Response{}, errToken
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
auth = updatedAuth
|
||||
}
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
var lastErr error
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
|
||||
payload = normalizeAntigravityThinking(req.Model, payload)
|
||||
payload = deleteJSONField(payload, "project")
|
||||
payload = deleteJSONField(payload, "model")
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
|
||||
base := strings.TrimSuffix(baseURL, "/")
|
||||
if base == "" {
|
||||
base = buildBaseURL(auth)
|
||||
}
|
||||
|
||||
var requestURL strings.Builder
|
||||
requestURL.WriteString(base)
|
||||
requestURL.WriteString(antigravityCountTokensPath)
|
||||
if opts.Alt != "" {
|
||||
requestURL.WriteString("?$alt=")
|
||||
requestURL.WriteString(url.QueryEscape(opts.Alt))
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return cliproxyexecutor.Response{}, errReq
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
if host := resolveHost(base); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: requestURL.String(),
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errDo
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return cliproxyexecutor.Response{}, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
}
|
||||
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), bodyBytes...)
|
||||
lastErr = nil
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
}
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
return cliproxyexecutor.Response{}, statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
case lastErr != nil:
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
default:
|
||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||
}
|
||||
}
|
||||
|
||||
// FetchAntigravityModels retrieves available models using the supplied auth.
|
||||
@@ -545,27 +1008,9 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
|
||||
strJSON = util.DeleteKey(strJSON, "$schema")
|
||||
strJSON = util.DeleteKey(strJSON, "maxItems")
|
||||
strJSON = util.DeleteKey(strJSON, "minItems")
|
||||
strJSON = util.DeleteKey(strJSON, "minLength")
|
||||
strJSON = util.DeleteKey(strJSON, "maxLength")
|
||||
strJSON = util.DeleteKey(strJSON, "exclusiveMinimum")
|
||||
strJSON = util.DeleteKey(strJSON, "exclusiveMaximum")
|
||||
strJSON = util.DeleteKey(strJSON, "$ref")
|
||||
strJSON = util.DeleteKey(strJSON, "$defs")
|
||||
|
||||
paths = make([]string, 0)
|
||||
util.Walk(gjson.Parse(strJSON), "", "anyOf", &paths)
|
||||
for _, p := range paths {
|
||||
anyOf := gjson.Get(strJSON, p)
|
||||
if anyOf.IsArray() {
|
||||
anyOfItems := anyOf.Array()
|
||||
if len(anyOfItems) > 0 {
|
||||
strJSON, _ = sjson.SetRaw(strJSON, p[:len(p)-len(".anyOf")], anyOfItems[0].Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Use the centralized schema cleaner to handle unsupported keywords,
|
||||
// const->enum conversion, and flattening of types/anyOf.
|
||||
strJSON = util.CleanJSONSchemaForGemini(strJSON)
|
||||
|
||||
payload = []byte(strJSON)
|
||||
}
|
||||
@@ -798,6 +1243,8 @@ func modelName2Alias(modelName string) string {
|
||||
return "gemini-3-pro-image-preview"
|
||||
case "gemini-3-pro-high":
|
||||
return "gemini-3-pro-preview"
|
||||
case "gemini-3-flash":
|
||||
return "gemini-3-flash-preview"
|
||||
case "claude-sonnet-4-5":
|
||||
return "gemini-claude-sonnet-4-5"
|
||||
case "claude-sonnet-4-5-thinking":
|
||||
@@ -819,6 +1266,8 @@ func alias2ModelName(modelName string) string {
|
||||
return "gemini-3-pro-image"
|
||||
case "gemini-3-pro-preview":
|
||||
return "gemini-3-pro-high"
|
||||
case "gemini-3-flash-preview":
|
||||
return "gemini-3-flash"
|
||||
case "gemini-claude-sonnet-4-5":
|
||||
return "claude-sonnet-4-5"
|
||||
case "gemini-claude-sonnet-4-5-thinking":
|
||||
|
||||
@@ -54,9 +54,9 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body = normalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -152,9 +152,9 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body = normalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -254,7 +254,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
modelForCounting := req.Model
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning.effort", false)
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
body, _ = sjson.DeleteBytes(body, "previous_response_id")
|
||||
body, _ = sjson.SetBytes(body, "stream", false)
|
||||
|
||||
@@ -11,6 +11,8 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -77,6 +79,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
|
||||
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
@@ -215,6 +218,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
|
||||
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
@@ -416,6 +420,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
for _, attemptModel := range models {
|
||||
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload)
|
||||
payload = deleteJSONField(payload, "project")
|
||||
payload = deleteJSONField(payload, "model")
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
@@ -784,20 +789,45 @@ func parseRetryDelay(errorBody []byte) (*time.Duration, error) {
|
||||
// Try to parse the retryDelay from the error response
|
||||
// Format: error.details[].retryDelay where @type == "type.googleapis.com/google.rpc.RetryInfo"
|
||||
details := gjson.GetBytes(errorBody, "error.details")
|
||||
if !details.Exists() || !details.IsArray() {
|
||||
return nil, fmt.Errorf("no error.details found")
|
||||
if details.Exists() && details.IsArray() {
|
||||
for _, detail := range details.Array() {
|
||||
typeVal := detail.Get("@type").String()
|
||||
if typeVal == "type.googleapis.com/google.rpc.RetryInfo" {
|
||||
retryDelay := detail.Get("retryDelay").String()
|
||||
if retryDelay != "" {
|
||||
// Parse duration string like "0.847655010s"
|
||||
duration, err := time.ParseDuration(retryDelay)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse duration")
|
||||
}
|
||||
return &duration, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback: try ErrorInfo.metadata.quotaResetDelay (e.g., "373.801628ms")
|
||||
for _, detail := range details.Array() {
|
||||
typeVal := detail.Get("@type").String()
|
||||
if typeVal == "type.googleapis.com/google.rpc.ErrorInfo" {
|
||||
quotaResetDelay := detail.Get("metadata.quotaResetDelay").String()
|
||||
if quotaResetDelay != "" {
|
||||
duration, err := time.ParseDuration(quotaResetDelay)
|
||||
if err == nil {
|
||||
return &duration, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, detail := range details.Array() {
|
||||
typeVal := detail.Get("@type").String()
|
||||
if typeVal == "type.googleapis.com/google.rpc.RetryInfo" {
|
||||
retryDelay := detail.Get("retryDelay").String()
|
||||
if retryDelay != "" {
|
||||
// Parse duration string like "0.847655010s"
|
||||
duration, err := time.ParseDuration(retryDelay)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse duration")
|
||||
}
|
||||
// Fallback: parse from error.message "Your quota will reset after Xs."
|
||||
message := gjson.GetBytes(errorBody, "error.message").String()
|
||||
if message != "" {
|
||||
re := regexp.MustCompile(`after\s+(\d+)s\.?`)
|
||||
if matches := re.FindStringSubmatch(message); len(matches) > 1 {
|
||||
seconds, err := strconv.Atoi(matches[1])
|
||||
if err == nil {
|
||||
duration := time.Duration(seconds) * time.Second
|
||||
return &duration, nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -83,7 +83,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
@@ -178,7 +178,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
body = applyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = ApplyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
||||
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
@@ -290,7 +290,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
translatedReq = applyThinkingMetadata(translatedReq, req.Metadata, req.Model)
|
||||
translatedReq = ApplyThinkingMetadata(translatedReq, req.Metadata, req.Model)
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
|
||||
@@ -57,15 +57,16 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
|
||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||
@@ -148,15 +149,16 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
|
||||
toolsResult := gjson.GetBytes(body, "tools")
|
||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||
@@ -442,3 +444,21 @@ func ensureToolsArray(body []byte) []byte {
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// applyIFlowThinkingConfig converts normalized reasoning_effort to iFlow chat_template_kwargs.enable_thinking.
|
||||
// This should be called after NormalizeThinkingConfig has processed the payload.
|
||||
// iFlow only supports boolean enable_thinking, so any non-"none" effort enables thinking.
|
||||
func applyIFlowThinkingConfig(body []byte) []byte {
|
||||
effort := gjson.GetBytes(body, "reasoning_effort")
|
||||
if !effort.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
val := strings.ToLower(strings.TrimSpace(effort.String()))
|
||||
enableThinking := val != "none" && val != ""
|
||||
|
||||
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
|
||||
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
@@ -60,13 +60,13 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
||||
translated = applyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" && modelOverride == "" {
|
||||
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
||||
}
|
||||
translated = normalizeThinkingConfig(translated, upstreamModel, allowCompat)
|
||||
if errValidate := validateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
|
||||
@@ -156,13 +156,13 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
}
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, to.String(), "", translated)
|
||||
allowCompat := e.allowCompatReasoningEffort(req.Model, auth)
|
||||
translated = applyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
translated = ApplyReasoningEffortMetadata(translated, req.Metadata, req.Model, "reasoning_effort", allowCompat)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" && modelOverride == "" {
|
||||
translated, _ = sjson.SetBytes(translated, "model", upstreamModel)
|
||||
}
|
||||
translated = normalizeThinkingConfig(translated, upstreamModel, allowCompat)
|
||||
if errValidate := validateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
translated = NormalizeThinkingConfig(translated, upstreamModel, allowCompat)
|
||||
if errValidate := ValidateThinkingConfig(translated, upstreamModel); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
|
||||
|
||||
@@ -11,9 +11,9 @@ import (
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// applyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
|
||||
// ApplyThinkingMetadata applies thinking config from model suffix metadata (e.g., (high), (8192))
|
||||
// for standard Gemini format payloads. It normalizes the budget when the model supports thinking.
|
||||
func applyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
|
||||
func ApplyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
|
||||
budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(model, metadata)
|
||||
if !ok || (budgetOverride == nil && includeOverride == nil) {
|
||||
return payload
|
||||
@@ -45,10 +45,10 @@ func applyThinkingMetadataCLI(payload []byte, metadata map[string]any, model str
|
||||
return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
|
||||
}
|
||||
|
||||
// applyReasoningEffortMetadata applies reasoning effort overrides from metadata to the given JSON path.
|
||||
// ApplyReasoningEffortMetadata applies reasoning effort overrides from metadata to the given JSON path.
|
||||
// Metadata values take precedence over any existing field when the model supports thinking, intentionally
|
||||
// overwriting caller-provided values to honor suffix/default metadata priority.
|
||||
func applyReasoningEffortMetadata(payload []byte, metadata map[string]any, model, field string, allowCompat bool) []byte {
|
||||
func ApplyReasoningEffortMetadata(payload []byte, metadata map[string]any, model, field string, allowCompat bool) []byte {
|
||||
if len(metadata) == 0 {
|
||||
return payload
|
||||
}
|
||||
@@ -72,7 +72,7 @@ func applyReasoningEffortMetadata(payload []byte, metadata map[string]any, model
|
||||
// Fallback: numeric thinking_budget suffix for level-based (OpenAI-style) models.
|
||||
if util.ModelUsesThinkingLevels(baseModel) || allowCompat {
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(baseModel, *budget); ok && effort != "" {
|
||||
if effort, ok := util.ThinkingBudgetToEffort(baseModel, *budget); ok && effort != "" {
|
||||
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
|
||||
return updated
|
||||
}
|
||||
@@ -232,12 +232,12 @@ func matchModelPattern(pattern, model string) bool {
|
||||
return pi == len(pattern)
|
||||
}
|
||||
|
||||
// normalizeThinkingConfig normalizes thinking-related fields in the payload
|
||||
// NormalizeThinkingConfig normalizes thinking-related fields in the payload
|
||||
// based on model capabilities. For models without thinking support, it strips
|
||||
// reasoning fields. For models with level-based thinking, it validates and
|
||||
// normalizes the reasoning effort level. For models with numeric budget thinking,
|
||||
// it strips the effort string fields.
|
||||
func normalizeThinkingConfig(payload []byte, model string, allowCompat bool) []byte {
|
||||
func NormalizeThinkingConfig(payload []byte, model string, allowCompat bool) []byte {
|
||||
if len(payload) == 0 || model == "" {
|
||||
return payload
|
||||
}
|
||||
@@ -246,28 +246,28 @@ func normalizeThinkingConfig(payload []byte, model string, allowCompat bool) []b
|
||||
if allowCompat {
|
||||
return payload
|
||||
}
|
||||
return stripThinkingFields(payload, false)
|
||||
return StripThinkingFields(payload, false)
|
||||
}
|
||||
|
||||
if util.ModelUsesThinkingLevels(model) {
|
||||
return normalizeReasoningEffortLevel(payload, model)
|
||||
return NormalizeReasoningEffortLevel(payload, model)
|
||||
}
|
||||
|
||||
// Model supports thinking but uses numeric budgets, not levels.
|
||||
// Strip effort string fields since they are not applicable.
|
||||
return stripThinkingFields(payload, true)
|
||||
return StripThinkingFields(payload, true)
|
||||
}
|
||||
|
||||
// stripThinkingFields removes thinking-related fields from the payload for
|
||||
// StripThinkingFields removes thinking-related fields from the payload for
|
||||
// models that do not support thinking. If effortOnly is true, only removes
|
||||
// effort string fields (for models using numeric budgets).
|
||||
func stripThinkingFields(payload []byte, effortOnly bool) []byte {
|
||||
func StripThinkingFields(payload []byte, effortOnly bool) []byte {
|
||||
fieldsToRemove := []string{
|
||||
"reasoning_effort",
|
||||
"reasoning.effort",
|
||||
}
|
||||
if !effortOnly {
|
||||
fieldsToRemove = append([]string{"reasoning"}, fieldsToRemove...)
|
||||
fieldsToRemove = append([]string{"reasoning", "thinking"}, fieldsToRemove...)
|
||||
}
|
||||
out := payload
|
||||
for _, field := range fieldsToRemove {
|
||||
@@ -278,9 +278,9 @@ func stripThinkingFields(payload []byte, effortOnly bool) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
// normalizeReasoningEffortLevel validates and normalizes the reasoning_effort
|
||||
// NormalizeReasoningEffortLevel validates and normalizes the reasoning_effort
|
||||
// or reasoning.effort field for level-based thinking models.
|
||||
func normalizeReasoningEffortLevel(payload []byte, model string) []byte {
|
||||
func NormalizeReasoningEffortLevel(payload []byte, model string) []byte {
|
||||
out := payload
|
||||
|
||||
if effort := gjson.GetBytes(out, "reasoning_effort"); effort.Exists() {
|
||||
@@ -298,10 +298,10 @@ func normalizeReasoningEffortLevel(payload []byte, model string) []byte {
|
||||
return out
|
||||
}
|
||||
|
||||
// validateThinkingConfig checks for unsupported reasoning levels on level-based models.
|
||||
// ValidateThinkingConfig checks for unsupported reasoning levels on level-based models.
|
||||
// Returns a statusErr with 400 when an unsupported level is supplied to avoid silently
|
||||
// downgrading requests.
|
||||
func validateThinkingConfig(payload []byte, model string) error {
|
||||
func ValidateThinkingConfig(payload []byte, model string) error {
|
||||
if len(payload) == 0 || model == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -51,13 +51,13 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -131,13 +131,13 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
body = applyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
body = ApplyReasoningEffortMetadata(body, req.Metadata, req.Model, "reasoning_effort", false)
|
||||
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
||||
if upstreamModel != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
||||
}
|
||||
body = normalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := validateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
body = NormalizeThinkingConfig(body, upstreamModel, false)
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
toolsResult := gjson.GetBytes(body, "tools")
|
||||
|
||||
@@ -7,10 +7,8 @@ package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
client "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -42,27 +40,30 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// system instruction
|
||||
var systemInstruction *client.Content
|
||||
systemInstructionJSON := ""
|
||||
hasSystemInstruction := false
|
||||
systemResult := gjson.GetBytes(rawJSON, "system")
|
||||
if systemResult.IsArray() {
|
||||
systemResults := systemResult.Array()
|
||||
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
|
||||
systemInstructionJSON = `{"role":"user","parts":[]}`
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemPromptResult := systemResults[i]
|
||||
systemTypePromptResult := systemPromptResult.Get("type")
|
||||
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
|
||||
systemPrompt := systemPromptResult.Get("text").String()
|
||||
systemPart := client.Part{Text: systemPrompt}
|
||||
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
|
||||
partJSON := `{}`
|
||||
if systemPrompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", systemPrompt)
|
||||
}
|
||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", partJSON)
|
||||
hasSystemInstruction = true
|
||||
}
|
||||
}
|
||||
if len(systemInstruction.Parts) == 0 {
|
||||
systemInstruction = nil
|
||||
}
|
||||
}
|
||||
|
||||
// contents
|
||||
contents := make([]client.Content, 0)
|
||||
contentsJSON := "[]"
|
||||
hasContents := false
|
||||
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messageResults := messagesResult.Array()
|
||||
@@ -76,7 +77,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
clientContent := client.Content{Role: role, Parts: []client.Part{}}
|
||||
clientContentJSON := `{"role":"","parts":[]}`
|
||||
clientContentJSON, _ = sjson.Set(clientContentJSON, "role", role)
|
||||
contentsResult := messageResult.Get("content")
|
||||
if contentsResult.IsArray() {
|
||||
contentResults := contentsResult.Array()
|
||||
@@ -90,25 +92,39 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if signatureResult.Exists() {
|
||||
signature = signatureResult.String()
|
||||
}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt, Thought: true, ThoughtSignature: signature})
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.Set(partJSON, "thought", true)
|
||||
if prompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
}
|
||||
if signature != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
prompt := contentResult.Get("text").String()
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
|
||||
partJSON := `{}`
|
||||
if prompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
functionID := contentResult.Get("id").String()
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||
if strings.Contains(modelName, "claude") {
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
|
||||
})
|
||||
} else {
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
|
||||
ThoughtSignature: geminiCLIClaudeThoughtSignature,
|
||||
})
|
||||
if gjson.Valid(functionArgs) {
|
||||
argsResult := gjson.Parse(functionArgs)
|
||||
if argsResult.IsObject() {
|
||||
partJSON := `{}`
|
||||
if !strings.Contains(modelName, "claude") {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", geminiCLIClaudeThoughtSignature)
|
||||
}
|
||||
if functionID != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID)
|
||||
}
|
||||
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsResult.Raw)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||
@@ -117,37 +133,74 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-2], "-")
|
||||
}
|
||||
responseData := contentResult.Get("content").Raw
|
||||
functionResponse := client.FunctionResponse{ID: toolCallID, Name: funcName, Response: map[string]interface{}{"result": responseData}}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
|
||||
functionResponseResult := contentResult.Get("content")
|
||||
|
||||
functionResponseJSON := `{}`
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "id", toolCallID)
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "name", funcName)
|
||||
|
||||
responseData := ""
|
||||
if functionResponseResult.Type == gjson.String {
|
||||
responseData = functionResponseResult.String()
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
||||
} else if functionResponseResult.IsArray() {
|
||||
frResults := functionResponseResult.Array()
|
||||
if len(frResults) == 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", frResults[0].Raw)
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
}
|
||||
|
||||
} else if functionResponseResult.IsObject() {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "functionResponse", functionResponseJSON)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" {
|
||||
sourceResult := contentResult.Get("source")
|
||||
if sourceResult.Get("type").String() == "base64" {
|
||||
inlineData := &client.InlineData{
|
||||
MimeType: sourceResult.Get("media_type").String(),
|
||||
Data: sourceResult.Get("data").String(),
|
||||
inlineDataJSON := `{}`
|
||||
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mime_type", mimeType)
|
||||
}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{InlineData: inlineData})
|
||||
if data := sourceResult.Get("data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "inlineData", inlineDataJSON)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
}
|
||||
}
|
||||
contents = append(contents, clientContent)
|
||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
||||
hasContents = true
|
||||
} else if contentsResult.Type == gjson.String {
|
||||
prompt := contentsResult.String()
|
||||
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
|
||||
partJSON := `{}`
|
||||
if prompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
||||
hasContents = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tools
|
||||
var tools []client.ToolDeclaration
|
||||
toolsJSON := ""
|
||||
toolDeclCount := 0
|
||||
toolsResult := gjson.GetBytes(rawJSON, "tools")
|
||||
if toolsResult.IsArray() {
|
||||
tools = make([]client.ToolDeclaration, 1)
|
||||
tools[0].FunctionDeclarations = make([]any, 0)
|
||||
toolsJSON = `[{"functionDeclarations":[]}]`
|
||||
toolsResults := toolsResult.Array()
|
||||
for i := 0; i < len(toolsResults); i++ {
|
||||
toolResult := toolsResults[i]
|
||||
@@ -158,30 +211,23 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
|
||||
tool, _ = sjson.Delete(tool, "strict")
|
||||
tool, _ = sjson.Delete(tool, "input_examples")
|
||||
var toolDeclaration any
|
||||
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
|
||||
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
|
||||
}
|
||||
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
|
||||
toolDeclCount++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tools = make([]client.ToolDeclaration, 0)
|
||||
}
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
out := `{"model":"","request":{"contents":[]}}`
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
if systemInstruction != nil {
|
||||
b, _ := json.Marshal(systemInstruction)
|
||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", string(b))
|
||||
if hasSystemInstruction {
|
||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON)
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
b, _ := json.Marshal(contents)
|
||||
out, _ = sjson.SetRaw(out, "request.contents", string(b))
|
||||
if hasContents {
|
||||
out, _ = sjson.SetRaw(out, "request.contents", contentsJSON)
|
||||
}
|
||||
if len(tools) > 0 && len(tools[0].FunctionDeclarations) > 0 {
|
||||
b, _ := json.Marshal(tools)
|
||||
out, _ = sjson.SetRaw(out, "request.tools", string(b))
|
||||
if toolDeclCount > 0 {
|
||||
out, _ = sjson.SetRaw(out, "request.tools", toolsJSON)
|
||||
}
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
|
||||
@@ -9,7 +9,6 @@ package claude
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -350,24 +349,25 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
}
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": root.Get("response.responseId").String(),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": root.Get("response.modelVersion").String(),
|
||||
"content": []interface{}{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": promptTokens,
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
responseJSON := `{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
responseJSON, _ = sjson.Set(responseJSON, "id", root.Get("response.responseId").String())
|
||||
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
|
||||
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
|
||||
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
|
||||
|
||||
contentArrayInitialized := false
|
||||
ensureContentArray := func() {
|
||||
if contentArrayInitialized {
|
||||
return
|
||||
}
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content", "[]")
|
||||
contentArrayInitialized = true
|
||||
}
|
||||
|
||||
parts := root.Get("response.candidates.0.content.parts")
|
||||
var contentBlocks []interface{}
|
||||
textBuilder := strings.Builder{}
|
||||
thinkingBuilder := strings.Builder{}
|
||||
thinkingSignature := ""
|
||||
toolIDCounter := 0
|
||||
hasToolCall := false
|
||||
|
||||
@@ -375,28 +375,43 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
if textBuilder.Len() == 0 {
|
||||
return
|
||||
}
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": textBuilder.String(),
|
||||
})
|
||||
ensureContentArray()
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", textBuilder.String())
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||
textBuilder.Reset()
|
||||
}
|
||||
|
||||
flushThinking := func() {
|
||||
if thinkingBuilder.Len() == 0 {
|
||||
if thinkingBuilder.Len() == 0 && thinkingSignature == "" {
|
||||
return
|
||||
}
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": thinkingBuilder.String(),
|
||||
})
|
||||
ensureContentArray()
|
||||
block := `{"type":"thinking","thinking":""}`
|
||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||
if thinkingSignature != "" {
|
||||
block, _ = sjson.Set(block, "signature", thinkingSignature)
|
||||
}
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||
thinkingBuilder.Reset()
|
||||
thinkingSignature = ""
|
||||
}
|
||||
|
||||
if parts.IsArray() {
|
||||
for _, part := range parts.Array() {
|
||||
isThought := part.Get("thought").Bool()
|
||||
if isThought {
|
||||
sig := part.Get("thoughtSignature")
|
||||
if !sig.Exists() {
|
||||
sig = part.Get("thought_signature")
|
||||
}
|
||||
if sig.Exists() && sig.String() != "" {
|
||||
thinkingSignature = sig.String()
|
||||
}
|
||||
}
|
||||
|
||||
if text := part.Get("text"); text.Exists() && text.String() != "" {
|
||||
if part.Get("thought").Bool() {
|
||||
if isThought {
|
||||
flushText()
|
||||
thinkingBuilder.WriteString(text.String())
|
||||
continue
|
||||
@@ -413,21 +428,16 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
|
||||
name := functionCall.Get("name").String()
|
||||
toolIDCounter++
|
||||
toolBlock := map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("tool_%d", toolIDCounter),
|
||||
"name": name,
|
||||
"input": map[string]interface{}{},
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||
|
||||
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) {
|
||||
toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw)
|
||||
}
|
||||
|
||||
if args := functionCall.Get("args"); args.Exists() {
|
||||
var parsed interface{}
|
||||
if err := json.Unmarshal([]byte(args.Raw), &parsed); err == nil {
|
||||
toolBlock["input"] = parsed
|
||||
}
|
||||
}
|
||||
|
||||
contentBlocks = append(contentBlocks, toolBlock)
|
||||
ensureContentArray()
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", toolBlock)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -436,8 +446,6 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
flushThinking()
|
||||
flushText()
|
||||
|
||||
response["content"] = contentBlocks
|
||||
|
||||
stopReason := "end_turn"
|
||||
if hasToolCall {
|
||||
stopReason = "tool_use"
|
||||
@@ -453,19 +461,15 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
}
|
||||
}
|
||||
}
|
||||
response["stop_reason"] = stopReason
|
||||
responseJSON, _ = sjson.Set(responseJSON, "stop_reason", stopReason)
|
||||
|
||||
if usage := response["usage"].(map[string]interface{}); usage["input_tokens"] == int64(0) && usage["output_tokens"] == int64(0) {
|
||||
if promptTokens == 0 && outputTokens == 0 {
|
||||
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
|
||||
delete(response, "usage")
|
||||
responseJSON, _ = sjson.Delete(responseJSON, "usage")
|
||||
}
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(encoded)
|
||||
return responseJSON
|
||||
}
|
||||
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||
|
||||
@@ -122,6 +122,38 @@ type FunctionCallGroup struct {
|
||||
ResponsesNeeded int
|
||||
}
|
||||
|
||||
// parseFunctionResponse attempts to unmarshal a function response part.
|
||||
// Falls back to gjson extraction if standard json.Unmarshal fails.
|
||||
func parseFunctionResponse(response gjson.Result) map[string]interface{} {
|
||||
var responseMap map[string]interface{}
|
||||
err := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if err == nil {
|
||||
return responseMap
|
||||
}
|
||||
|
||||
log.Debugf("unmarshal function response failed, using fallback: %v", err)
|
||||
funcResp := response.Get("functionResponse")
|
||||
if funcResp.Exists() {
|
||||
fr := map[string]interface{}{
|
||||
"name": funcResp.Get("name").String(),
|
||||
"response": map[string]interface{}{
|
||||
"result": funcResp.Get("response").String(),
|
||||
},
|
||||
}
|
||||
if id := funcResp.Get("id").String(); id != "" {
|
||||
fr["id"] = id
|
||||
}
|
||||
return map[string]interface{}{"functionResponse": fr}
|
||||
}
|
||||
|
||||
return map[string]interface{}{
|
||||
"functionResponse": map[string]interface{}{
|
||||
"name": "unknown",
|
||||
"response": map[string]interface{}{"result": response.String()},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// fixCLIToolResponse performs sophisticated tool response format conversion and grouping.
|
||||
// This function transforms the CLI tool response format by intelligently grouping function calls
|
||||
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
|
||||
@@ -180,13 +212,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
// Create merged function response content
|
||||
var responseParts []interface{}
|
||||
for _, response := range groupResponses {
|
||||
var responseMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||
continue
|
||||
}
|
||||
responseParts = append(responseParts, responseMap)
|
||||
responseParts = append(responseParts, parseFunctionResponse(response))
|
||||
}
|
||||
|
||||
if len(responseParts) > 0 {
|
||||
@@ -265,13 +291,7 @@ func fixCLIToolResponse(input string) (string, error) {
|
||||
|
||||
var responseParts []interface{}
|
||||
for _, response := range groupResponses {
|
||||
var responseMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||
continue
|
||||
}
|
||||
responseParts = append(responseParts, responseMap)
|
||||
responseParts = append(responseParts, parseFunctionResponse(response))
|
||||
}
|
||||
|
||||
if len(responseParts) > 0 {
|
||||
|
||||
@@ -40,30 +40,27 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
switch re.String() {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
case "auto":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "low":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "medium":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "high":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
default:
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
effort := strings.ToLower(strings.TrimSpace(re.String()))
|
||||
if util.IsGemini3Model(modelName) {
|
||||
switch effort {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig")
|
||||
case "auto":
|
||||
includeThoughts := true
|
||||
out = util.ApplyGeminiCLIThinkingLevel(out, "", &includeThoughts)
|
||||
default:
|
||||
if level, ok := util.ValidateGemini3ThinkingLevel(modelName, effort); ok {
|
||||
out = util.ApplyGeminiCLIThinkingLevel(out, level, nil)
|
||||
}
|
||||
}
|
||||
} else if !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGeminiCLI(out, effort)
|
||||
}
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||
var setBudget bool
|
||||
var budget int
|
||||
@@ -240,62 +237,61 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if role == "assistant" {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
if content.Type == gjson.String {
|
||||
// Assistant text -> single model content
|
||||
node := []byte(`{"role":"model","parts":[{"text":""}]}`)
|
||||
node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if !content.Exists() || content.Type == gjson.Null {
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
p++
|
||||
}
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"user","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
// Handle non-JSON output gracefully (matches dev branch approach)
|
||||
if resp != "null" {
|
||||
parsed := gjson.Parse(resp)
|
||||
if parsed.Type == gjson.JSON {
|
||||
toolNode, _ = sjson.SetRawBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(parsed.Raw))
|
||||
} else {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", resp)
|
||||
}
|
||||
}
|
||||
pp++
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"user","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
// Handle non-JSON output gracefully (matches dev branch approach)
|
||||
if resp != "null" {
|
||||
parsed := gjson.Parse(resp)
|
||||
if parsed.Type == gjson.JSON {
|
||||
toolNode, _ = sjson.SetRawBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(parsed.Raw))
|
||||
} else {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", resp)
|
||||
}
|
||||
}
|
||||
pp++
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -379,18 +375,3 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
// itoa converts int to string without strconv import for few usages.
|
||||
func itoa(i int) string { return fmt.Sprintf("%d", i) }
|
||||
|
||||
// quoteIfNeeded ensures a string is valid JSON value (quotes plain text), pass-through for JSON objects/arrays.
|
||||
func quoteIfNeeded(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "\"\""
|
||||
}
|
||||
if len(s) > 0 && (s[0] == '{' || s[0] == '[') {
|
||||
return s
|
||||
}
|
||||
// escape quotes minimally
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
return "\"" + s + "\""
|
||||
}
|
||||
|
||||
@@ -114,14 +114,16 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
}
|
||||
// Include thoughts configuration for reasoning process visibility
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() {
|
||||
if includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", thinkingBudget.Int())
|
||||
}
|
||||
}
|
||||
// Only apply for models that support thinking and use numeric budgets, not discrete levels.
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
// Check for thinkingBudget first - if present, enable thinking with budget
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() && thinkingBudget.Int() > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
normalizedBudget := util.NormalizeThinkingBudget(modelName, int(thinkingBudget.Int()))
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", normalizedBudget)
|
||||
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
// Fallback to include_thoughts if no budget specified
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -65,18 +66,23 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
if v := root.Get("reasoning_effort"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
|
||||
switch v.String() {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case "low":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 1024)
|
||||
case "medium":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 8192)
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 24576)
|
||||
if v := root.Get("reasoning_effort"); v.Exists() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
effort := strings.ToLower(strings.TrimSpace(v.String()))
|
||||
if effort != "" {
|
||||
budget, ok := util.ThinkingEffortToBudget(modelName, effort)
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
default:
|
||||
if budget > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -52,20 +53,23 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
if v := root.Get("reasoning.effort"); v.Exists() {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
|
||||
switch v.String() {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case "minimal":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 1024)
|
||||
case "low":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 4096)
|
||||
case "medium":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 8192)
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", 24576)
|
||||
if v := root.Get("reasoning.effort"); v.Exists() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
effort := strings.ToLower(strings.TrimSpace(v.String()))
|
||||
if effort != "" {
|
||||
budget, ok := util.ThinkingEffortToBudget(modelName, effort)
|
||||
if ok {
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
default:
|
||||
if budget > 0 {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
}
|
||||
}
|
||||
// response.created
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"instructions":""}}`
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
|
||||
created, _ = sjson.Set(created, "sequence_number", nextSeq())
|
||||
created, _ = sjson.Set(created, "response.id", st.ResponseID)
|
||||
created, _ = sjson.Set(created, "response.created_at", st.CreatedAt)
|
||||
@@ -197,11 +197,11 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
if st.ReasoningActive {
|
||||
if t := d.Get("thinking"); t.Exists() {
|
||||
st.ReasoningBuf.WriteString(t.String())
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID)
|
||||
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
|
||||
msg, _ = sjson.Set(msg, "text", t.String())
|
||||
msg, _ = sjson.Set(msg, "delta", t.String())
|
||||
out = append(out, emitEvent("response.reasoning_summary_text.delta", msg))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -214,7 +215,27 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// Add additional configuration parameters for the Codex API.
|
||||
template, _ = sjson.Set(template, "parallel_tool_calls", true)
|
||||
template, _ = sjson.Set(template, "reasoning.effort", "medium")
|
||||
|
||||
// Convert thinking.budget_tokens to reasoning.effort for level-based models
|
||||
reasoningEffort := "medium" // default
|
||||
if thinking := rootResult.Get("thinking"); thinking.Exists() && thinking.IsObject() {
|
||||
switch thinking.Get("type").String() {
|
||||
case "enabled":
|
||||
if util.ModelUsesThinkingLevels(modelName) {
|
||||
if budgetTokens := thinking.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
case "disabled":
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, 0); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
template, _ = sjson.Set(template, "reasoning.effort", reasoningEffort)
|
||||
template, _ = sjson.Set(template, "reasoning.summary", "auto")
|
||||
template, _ = sjson.Set(template, "stream", true)
|
||||
template, _ = sjson.Set(template, "store", false)
|
||||
|
||||
@@ -245,7 +245,22 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// Fixed flags aligning with Codex expectations
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", true)
|
||||
out, _ = sjson.Set(out, "reasoning.effort", "medium")
|
||||
|
||||
// Convert thinkingBudget to reasoning.effort for level-based models
|
||||
reasoningEffort := "medium" // default
|
||||
if genConfig := root.Get("generationConfig"); genConfig.Exists() {
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if util.ModelUsesThinkingLevels(modelName) {
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out, _ = sjson.Set(out, "reasoning.effort", reasoningEffort)
|
||||
out, _ = sjson.Set(out, "reasoning.summary", "auto")
|
||||
out, _ = sjson.Set(out, "stream", true)
|
||||
out, _ = sjson.Set(out, "store", false)
|
||||
|
||||
@@ -39,31 +39,13 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
switch re.String() {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
case "auto":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "low":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "medium":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "high":
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
default:
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
}
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGeminiCLI(out, re.String())
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||
var setBudget bool
|
||||
var budget int
|
||||
@@ -223,52 +205,52 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if role == "assistant" {
|
||||
p := 0
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
if content.Type == gjson.String {
|
||||
// Assistant text -> single model content
|
||||
node := []byte(`{"role":"model","parts":[{"text":""}]}`)
|
||||
node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if !content.Exists() || content.Type == gjson.Null {
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
p++
|
||||
}
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
pp++
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
pp++
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -352,18 +334,3 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
|
||||
// itoa converts int to string without strconv import for few usages.
|
||||
func itoa(i int) string { return fmt.Sprintf("%d", i) }
|
||||
|
||||
// quoteIfNeeded ensures a string is valid JSON value (quotes plain text), pass-through for JSON objects/arrays.
|
||||
func quoteIfNeeded(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "\"\""
|
||||
}
|
||||
if len(s) > 0 && (s[0] == '{' || s[0] == '[') {
|
||||
return s
|
||||
}
|
||||
// escape quotes minimally
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
return "\"" + s + "\""
|
||||
}
|
||||
|
||||
@@ -154,7 +154,8 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when enabled
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if t.Get("type").String() == "enabled" {
|
||||
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
|
||||
budget := int(b.Int())
|
||||
|
||||
@@ -179,6 +179,18 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
usedTool = true
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
|
||||
// FIX: Handle streaming split/delta where name might be empty in subsequent chunks.
|
||||
// If we are already in tool use mode and name is empty, treat as continuation (delta).
|
||||
if (*param).(*Params).ResponseType == 3 && fcName == "" {
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
}
|
||||
// Continue to next part without closing/opening logic
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle state transitions when switching to function calls
|
||||
// Close any existing function call block first
|
||||
if (*param).(*Params).ResponseType == 3 {
|
||||
|
||||
@@ -37,33 +37,33 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// Reasoning effort -> thinkingBudget/include_thoughts
|
||||
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
|
||||
// Only apply numeric budgets for models that use budgets (not discrete levels) to avoid
|
||||
// incorrectly applying thinkingBudget for level-based models like gpt-5. Gemini 3 models
|
||||
// use thinkingLevel/includeThoughts instead.
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
switch re.String() {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "generationConfig.thinkingConfig.include_thoughts")
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
case "auto":
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "low":
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "medium":
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "high":
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
default:
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.SetBytes(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
effort := strings.ToLower(strings.TrimSpace(re.String()))
|
||||
if util.IsGemini3Model(modelName) {
|
||||
switch effort {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "generationConfig.thinkingConfig")
|
||||
case "auto":
|
||||
includeThoughts := true
|
||||
out = util.ApplyGeminiThinkingLevel(out, "", &includeThoughts)
|
||||
default:
|
||||
if level, ok := util.ValidateGemini3ThinkingLevel(modelName, effort); ok {
|
||||
out = util.ApplyGeminiThinkingLevel(out, level, nil)
|
||||
}
|
||||
}
|
||||
} else if !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGemini(out, effort)
|
||||
}
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if tc := gjson.GetBytes(rawJSON, "extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||
var setBudget bool
|
||||
var budget int
|
||||
@@ -223,15 +223,16 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
} else if role == "assistant" {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
|
||||
if content.Type == gjson.String {
|
||||
// Assistant text -> single model content
|
||||
node := []byte(`{"role":"model","parts":[{"text":""}]}`)
|
||||
node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
p++
|
||||
} else if content.IsArray() {
|
||||
// Assistant multimodal content (e.g. text + image) -> single model content with parts
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
for _, item := range content.Array() {
|
||||
switch item.Get("type").String() {
|
||||
case "text":
|
||||
@@ -253,47 +254,45 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
} else if !content.Exists() || content.Type == gjson.Null {
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
}
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
pp++
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
pp++
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -379,18 +378,3 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// itoa converts int to string without strconv import for few usages.
|
||||
func itoa(i int) string { return fmt.Sprintf("%d", i) }
|
||||
|
||||
// quoteIfNeeded ensures a string is valid JSON value (quotes plain text), pass-through for JSON objects/arrays.
|
||||
func quoteIfNeeded(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "\"\""
|
||||
}
|
||||
if len(s) > 0 && (s[0] == '{' || s[0] == '[') {
|
||||
return s
|
||||
}
|
||||
// escape quotes minimally
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
return "\"" + s + "\""
|
||||
}
|
||||
|
||||
@@ -389,36 +389,16 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
}
|
||||
|
||||
// OpenAI official reasoning fields take precedence
|
||||
// Only convert for models that use numeric budgets (not discrete levels).
|
||||
hasOfficialThinking := root.Get("reasoning.effort").Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
reasoningEffort := root.Get("reasoning.effort")
|
||||
switch reasoningEffort.String() {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", false)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "minimal":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "low":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 4096)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "medium":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", 32768)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
default:
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
out, _ = sjson.Set(out, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
}
|
||||
out = string(util.ApplyReasoningEffortToGemini([]byte(out), reasoningEffort.String()))
|
||||
}
|
||||
|
||||
// Cherry Studio extension (applies only when official fields are missing)
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
// Only apply for models that use numeric budgets, not discrete levels.
|
||||
if !hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
if tc := root.Get("extra_body.google.thinking_config"); tc.Exists() && tc.IsObject() {
|
||||
var setBudget bool
|
||||
var budget int
|
||||
|
||||
@@ -117,7 +117,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
st.CreatedAt = time.Now().Unix()
|
||||
}
|
||||
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null}}`
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
|
||||
created, _ = sjson.Set(created, "sequence_number", nextSeq())
|
||||
created, _ = sjson.Set(created, "response.id", st.ResponseID)
|
||||
created, _ = sjson.Set(created, "response.created_at", st.CreatedAt)
|
||||
@@ -160,11 +160,11 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
}
|
||||
if t := part.Get("text"); t.Exists() && t.String() != "" {
|
||||
st.ReasoningBuf.WriteString(t.String())
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID)
|
||||
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
|
||||
msg, _ = sjson.Set(msg, "text", t.String())
|
||||
msg, _ = sjson.Set(msg, "delta", t.String())
|
||||
out = append(out, emitEvent("response.reasoning_summary_text.delta", msg))
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -60,6 +61,30 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
// Stream
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
|
||||
// Thinking: Convert Claude thinking.budget_tokens to OpenAI reasoning_effort
|
||||
if thinking := root.Get("thinking"); thinking.Exists() && thinking.IsObject() {
|
||||
if thinkingType := thinking.Get("type"); thinkingType.Exists() {
|
||||
switch thinkingType.String() {
|
||||
case "enabled":
|
||||
if budgetTokens := thinking.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
} else {
|
||||
// No budget_tokens specified, default to "auto" for enabled thinking
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, -1); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
case "disabled":
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, 0); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process messages and system
|
||||
var messagesJSON = "[]"
|
||||
|
||||
|
||||
@@ -128,9 +128,10 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
param.CreatedAt = root.Get("created").Int()
|
||||
}
|
||||
|
||||
// Check if this is the first chunk (has role)
|
||||
// Emit message_start on the very first chunk, regardless of whether it has a role field.
|
||||
// Some providers (like Copilot) may send tool_calls in the first chunk without a role field.
|
||||
if delta := root.Get("choices.0.delta"); delta.Exists() {
|
||||
if role := delta.Get("role"); role.Exists() && role.String() == "assistant" && !param.MessageStarted {
|
||||
if !param.MessageStarted {
|
||||
// Send message_start event
|
||||
messageStart := map[string]interface{}{
|
||||
"type": "message_start",
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -76,6 +77,17 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
out, _ = sjson.Set(out, "stop", stops)
|
||||
}
|
||||
}
|
||||
|
||||
// Convert thinkingBudget to reasoning_effort
|
||||
// Always perform conversion to support allowCompat models that may not be in registry
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream parameter
|
||||
|
||||
@@ -2,6 +2,7 @@ package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -64,7 +65,7 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
}
|
||||
|
||||
switch itemType {
|
||||
case "message":
|
||||
case "message", "":
|
||||
// Handle regular message conversion
|
||||
role := item.Get("role").String()
|
||||
message := `{"role":"","content":""}`
|
||||
@@ -106,6 +107,8 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
if len(toolCalls) > 0 {
|
||||
message, _ = sjson.Set(message, "tool_calls", toolCalls)
|
||||
}
|
||||
} else if content.Type == gjson.String {
|
||||
message, _ = sjson.Set(message, "content", content.String())
|
||||
}
|
||||
|
||||
out, _ = sjson.SetRaw(out, "messages.-1", message)
|
||||
@@ -189,23 +192,9 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
}
|
||||
|
||||
if reasoningEffort := root.Get("reasoning.effort"); reasoningEffort.Exists() {
|
||||
switch reasoningEffort.String() {
|
||||
case "none":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "none")
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "auto")
|
||||
case "minimal":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "low")
|
||||
case "low":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "low")
|
||||
case "medium":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "medium")
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "high")
|
||||
case "xhigh":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "xhigh")
|
||||
default:
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "auto")
|
||||
effort := strings.ToLower(strings.TrimSpace(reasoningEffort.String()))
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -143,7 +143,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
st.ReasoningTokens = 0
|
||||
st.UsageSeen = false
|
||||
// response.created
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null}}`
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
|
||||
created, _ = sjson.Set(created, "sequence_number", nextSeq())
|
||||
created, _ = sjson.Set(created, "response.id", st.ResponseID)
|
||||
created, _ = sjson.Set(created, "response.created_at", st.Created)
|
||||
@@ -216,11 +216,11 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
}
|
||||
// Append incremental text to reasoning buffer
|
||||
st.ReasoningBuf.WriteString(rc.String())
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", st.ReasoningID)
|
||||
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
|
||||
msg, _ = sjson.Set(msg, "text", rc.String())
|
||||
msg, _ = sjson.Set(msg, "delta", rc.String())
|
||||
out = append(out, emitRespEvent("response.reasoning_summary_text.delta", msg))
|
||||
}
|
||||
|
||||
|
||||
497
internal/util/gemini_schema.go
Normal file
497
internal/util/gemini_schema.go
Normal file
@@ -0,0 +1,497 @@
|
||||
// Package util provides utility functions for the CLI Proxy API server.
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
||||
|
||||
// CleanJSONSchemaForGemini transforms a JSON schema to be compatible with Gemini/Antigravity API.
|
||||
// It handles unsupported keywords, type flattening, and schema simplification while preserving
|
||||
// semantic information as description hints.
|
||||
func CleanJSONSchemaForGemini(jsonStr string) string {
|
||||
// Phase 1: Convert and add hints
|
||||
jsonStr = convertRefsToHints(jsonStr)
|
||||
jsonStr = convertConstToEnum(jsonStr)
|
||||
jsonStr = addEnumHints(jsonStr)
|
||||
jsonStr = addAdditionalPropertiesHints(jsonStr)
|
||||
jsonStr = moveConstraintsToDescription(jsonStr)
|
||||
|
||||
// Phase 2: Flatten complex structures
|
||||
jsonStr = mergeAllOf(jsonStr)
|
||||
jsonStr = flattenAnyOfOneOf(jsonStr)
|
||||
jsonStr = flattenTypeArrays(jsonStr)
|
||||
|
||||
// Phase 3: Cleanup
|
||||
jsonStr = removeUnsupportedKeywords(jsonStr)
|
||||
jsonStr = cleanupRequiredFields(jsonStr)
|
||||
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// convertRefsToHints converts $ref to description hints (Lazy Hint strategy).
|
||||
func convertRefsToHints(jsonStr string) string {
|
||||
paths := findPaths(jsonStr, "$ref")
|
||||
sortByDepth(paths)
|
||||
|
||||
for _, p := range paths {
|
||||
refVal := gjson.Get(jsonStr, p).String()
|
||||
defName := refVal
|
||||
if idx := strings.LastIndex(refVal, "/"); idx >= 0 {
|
||||
defName = refVal[idx+1:]
|
||||
}
|
||||
|
||||
parentPath := trimSuffix(p, ".$ref")
|
||||
hint := fmt.Sprintf("See: %s", defName)
|
||||
if existing := gjson.Get(jsonStr, descriptionPath(parentPath)).String(); existing != "" {
|
||||
hint = fmt.Sprintf("%s (%s)", existing, hint)
|
||||
}
|
||||
|
||||
replacement := `{"type":"object","description":""}`
|
||||
replacement, _ = sjson.Set(replacement, "description", hint)
|
||||
jsonStr = setRawAt(jsonStr, parentPath, replacement)
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func convertConstToEnum(jsonStr string) string {
|
||||
for _, p := range findPaths(jsonStr, "const") {
|
||||
val := gjson.Get(jsonStr, p)
|
||||
if !val.Exists() {
|
||||
continue
|
||||
}
|
||||
enumPath := trimSuffix(p, ".const") + ".enum"
|
||||
if !gjson.Get(jsonStr, enumPath).Exists() {
|
||||
jsonStr, _ = sjson.Set(jsonStr, enumPath, []interface{}{val.Value()})
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func addEnumHints(jsonStr string) string {
|
||||
for _, p := range findPaths(jsonStr, "enum") {
|
||||
arr := gjson.Get(jsonStr, p)
|
||||
if !arr.IsArray() {
|
||||
continue
|
||||
}
|
||||
items := arr.Array()
|
||||
if len(items) <= 1 || len(items) > 10 {
|
||||
continue
|
||||
}
|
||||
|
||||
var vals []string
|
||||
for _, item := range items {
|
||||
vals = append(vals, item.String())
|
||||
}
|
||||
jsonStr = appendHint(jsonStr, trimSuffix(p, ".enum"), "Allowed: "+strings.Join(vals, ", "))
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func addAdditionalPropertiesHints(jsonStr string) string {
|
||||
for _, p := range findPaths(jsonStr, "additionalProperties") {
|
||||
if gjson.Get(jsonStr, p).Type == gjson.False {
|
||||
jsonStr = appendHint(jsonStr, trimSuffix(p, ".additionalProperties"), "No extra properties allowed")
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
var unsupportedConstraints = []string{
|
||||
"minLength", "maxLength", "exclusiveMinimum", "exclusiveMaximum",
|
||||
"pattern", "minItems", "maxItems",
|
||||
}
|
||||
|
||||
func moveConstraintsToDescription(jsonStr string) string {
|
||||
for _, key := range unsupportedConstraints {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
val := gjson.Get(jsonStr, p)
|
||||
if !val.Exists() || val.IsObject() || val.IsArray() {
|
||||
continue
|
||||
}
|
||||
parentPath := trimSuffix(p, "."+key)
|
||||
if isPropertyDefinition(parentPath) {
|
||||
continue
|
||||
}
|
||||
jsonStr = appendHint(jsonStr, parentPath, fmt.Sprintf("%s: %s", key, val.String()))
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func mergeAllOf(jsonStr string) string {
|
||||
paths := findPaths(jsonStr, "allOf")
|
||||
sortByDepth(paths)
|
||||
|
||||
for _, p := range paths {
|
||||
allOf := gjson.Get(jsonStr, p)
|
||||
if !allOf.IsArray() {
|
||||
continue
|
||||
}
|
||||
parentPath := trimSuffix(p, ".allOf")
|
||||
|
||||
for _, item := range allOf.Array() {
|
||||
if props := item.Get("properties"); props.IsObject() {
|
||||
props.ForEach(func(key, value gjson.Result) bool {
|
||||
destPath := joinPath(parentPath, "properties."+escapeGJSONPathKey(key.String()))
|
||||
jsonStr, _ = sjson.SetRaw(jsonStr, destPath, value.Raw)
|
||||
return true
|
||||
})
|
||||
}
|
||||
if req := item.Get("required"); req.IsArray() {
|
||||
reqPath := joinPath(parentPath, "required")
|
||||
current := getStrings(jsonStr, reqPath)
|
||||
for _, r := range req.Array() {
|
||||
if s := r.String(); !contains(current, s) {
|
||||
current = append(current, s)
|
||||
}
|
||||
}
|
||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, current)
|
||||
}
|
||||
}
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func flattenAnyOfOneOf(jsonStr string) string {
|
||||
for _, key := range []string{"anyOf", "oneOf"} {
|
||||
paths := findPaths(jsonStr, key)
|
||||
sortByDepth(paths)
|
||||
|
||||
for _, p := range paths {
|
||||
arr := gjson.Get(jsonStr, p)
|
||||
if !arr.IsArray() || len(arr.Array()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
parentPath := trimSuffix(p, "."+key)
|
||||
parentDesc := gjson.Get(jsonStr, descriptionPath(parentPath)).String()
|
||||
|
||||
items := arr.Array()
|
||||
bestIdx, allTypes := selectBest(items)
|
||||
selected := items[bestIdx].Raw
|
||||
|
||||
if parentDesc != "" {
|
||||
selected = mergeDescriptionRaw(selected, parentDesc)
|
||||
}
|
||||
|
||||
if len(allTypes) > 1 {
|
||||
hint := "Accepts: " + strings.Join(allTypes, " | ")
|
||||
selected = appendHintRaw(selected, hint)
|
||||
}
|
||||
|
||||
jsonStr = setRawAt(jsonStr, parentPath, selected)
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func selectBest(items []gjson.Result) (bestIdx int, types []string) {
|
||||
bestScore := -1
|
||||
for i, item := range items {
|
||||
t := item.Get("type").String()
|
||||
score := 0
|
||||
|
||||
switch {
|
||||
case t == "object" || item.Get("properties").Exists():
|
||||
score, t = 3, orDefault(t, "object")
|
||||
case t == "array" || item.Get("items").Exists():
|
||||
score, t = 2, orDefault(t, "array")
|
||||
case t != "" && t != "null":
|
||||
score = 1
|
||||
default:
|
||||
t = orDefault(t, "null")
|
||||
}
|
||||
|
||||
if t != "" {
|
||||
types = append(types, t)
|
||||
}
|
||||
if score > bestScore {
|
||||
bestScore, bestIdx = score, i
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func flattenTypeArrays(jsonStr string) string {
|
||||
paths := findPaths(jsonStr, "type")
|
||||
sortByDepth(paths)
|
||||
|
||||
nullableFields := make(map[string][]string)
|
||||
|
||||
for _, p := range paths {
|
||||
res := gjson.Get(jsonStr, p)
|
||||
if !res.IsArray() || len(res.Array()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
hasNull := false
|
||||
var nonNullTypes []string
|
||||
for _, item := range res.Array() {
|
||||
s := item.String()
|
||||
if s == "null" {
|
||||
hasNull = true
|
||||
} else if s != "" {
|
||||
nonNullTypes = append(nonNullTypes, s)
|
||||
}
|
||||
}
|
||||
|
||||
firstType := "string"
|
||||
if len(nonNullTypes) > 0 {
|
||||
firstType = nonNullTypes[0]
|
||||
}
|
||||
|
||||
jsonStr, _ = sjson.Set(jsonStr, p, firstType)
|
||||
|
||||
parentPath := trimSuffix(p, ".type")
|
||||
if len(nonNullTypes) > 1 {
|
||||
hint := "Accepts: " + strings.Join(nonNullTypes, " | ")
|
||||
jsonStr = appendHint(jsonStr, parentPath, hint)
|
||||
}
|
||||
|
||||
if hasNull {
|
||||
parts := splitGJSONPath(p)
|
||||
if len(parts) >= 3 && parts[len(parts)-3] == "properties" {
|
||||
fieldNameEscaped := parts[len(parts)-2]
|
||||
fieldName := unescapeGJSONPathKey(fieldNameEscaped)
|
||||
objectPath := strings.Join(parts[:len(parts)-3], ".")
|
||||
nullableFields[objectPath] = append(nullableFields[objectPath], fieldName)
|
||||
|
||||
propPath := joinPath(objectPath, "properties."+fieldNameEscaped)
|
||||
jsonStr = appendHint(jsonStr, propPath, "(nullable)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for objectPath, fields := range nullableFields {
|
||||
reqPath := joinPath(objectPath, "required")
|
||||
req := gjson.Get(jsonStr, reqPath)
|
||||
if !req.IsArray() {
|
||||
continue
|
||||
}
|
||||
|
||||
var filtered []string
|
||||
for _, r := range req.Array() {
|
||||
if !contains(fields, r.String()) {
|
||||
filtered = append(filtered, r.String())
|
||||
}
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, reqPath)
|
||||
} else {
|
||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func removeUnsupportedKeywords(jsonStr string) string {
|
||||
keywords := append(unsupportedConstraints,
|
||||
"$schema", "$defs", "definitions", "const", "$ref", "additionalProperties",
|
||||
"propertyNames", // Gemini doesn't support property name validation
|
||||
)
|
||||
for _, key := range keywords {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
if isPropertyDefinition(trimSuffix(p, "."+key)) {
|
||||
continue
|
||||
}
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func cleanupRequiredFields(jsonStr string) string {
|
||||
for _, p := range findPaths(jsonStr, "required") {
|
||||
parentPath := trimSuffix(p, ".required")
|
||||
propsPath := joinPath(parentPath, "properties")
|
||||
|
||||
req := gjson.Get(jsonStr, p)
|
||||
props := gjson.Get(jsonStr, propsPath)
|
||||
if !req.IsArray() || !props.IsObject() {
|
||||
continue
|
||||
}
|
||||
|
||||
var valid []string
|
||||
for _, r := range req.Array() {
|
||||
key := r.String()
|
||||
if props.Get(escapeGJSONPathKey(key)).Exists() {
|
||||
valid = append(valid, key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(valid) != len(req.Array()) {
|
||||
if len(valid) == 0 {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
} else {
|
||||
jsonStr, _ = sjson.Set(jsonStr, p, valid)
|
||||
}
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func findPaths(jsonStr, field string) []string {
|
||||
var paths []string
|
||||
Walk(gjson.Parse(jsonStr), "", field, &paths)
|
||||
return paths
|
||||
}
|
||||
|
||||
func sortByDepth(paths []string) {
|
||||
sort.Slice(paths, func(i, j int) bool { return len(paths[i]) > len(paths[j]) })
|
||||
}
|
||||
|
||||
func trimSuffix(path, suffix string) string {
|
||||
if path == strings.TrimPrefix(suffix, ".") {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSuffix(path, suffix)
|
||||
}
|
||||
|
||||
func joinPath(base, suffix string) string {
|
||||
if base == "" {
|
||||
return suffix
|
||||
}
|
||||
return base + "." + suffix
|
||||
}
|
||||
|
||||
func setRawAt(jsonStr, path, value string) string {
|
||||
if path == "" {
|
||||
return value
|
||||
}
|
||||
result, _ := sjson.SetRaw(jsonStr, path, value)
|
||||
return result
|
||||
}
|
||||
|
||||
func isPropertyDefinition(path string) bool {
|
||||
return path == "properties" || strings.HasSuffix(path, ".properties")
|
||||
}
|
||||
|
||||
func descriptionPath(parentPath string) string {
|
||||
if parentPath == "" || parentPath == "@this" {
|
||||
return "description"
|
||||
}
|
||||
return parentPath + ".description"
|
||||
}
|
||||
|
||||
func appendHint(jsonStr, parentPath, hint string) string {
|
||||
descPath := parentPath + ".description"
|
||||
if parentPath == "" || parentPath == "@this" {
|
||||
descPath = "description"
|
||||
}
|
||||
existing := gjson.Get(jsonStr, descPath).String()
|
||||
if existing != "" {
|
||||
hint = fmt.Sprintf("%s (%s)", existing, hint)
|
||||
}
|
||||
jsonStr, _ = sjson.Set(jsonStr, descPath, hint)
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func appendHintRaw(jsonRaw, hint string) string {
|
||||
existing := gjson.Get(jsonRaw, "description").String()
|
||||
if existing != "" {
|
||||
hint = fmt.Sprintf("%s (%s)", existing, hint)
|
||||
}
|
||||
jsonRaw, _ = sjson.Set(jsonRaw, "description", hint)
|
||||
return jsonRaw
|
||||
}
|
||||
|
||||
func getStrings(jsonStr, path string) []string {
|
||||
var result []string
|
||||
if arr := gjson.Get(jsonStr, path); arr.IsArray() {
|
||||
for _, r := range arr.Array() {
|
||||
result = append(result, r.String())
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func contains(slice []string, item string) bool {
|
||||
for _, s := range slice {
|
||||
if s == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func orDefault(val, def string) string {
|
||||
if val == "" {
|
||||
return def
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func escapeGJSONPathKey(key string) string {
|
||||
return gjsonPathKeyReplacer.Replace(key)
|
||||
}
|
||||
|
||||
func unescapeGJSONPathKey(key string) string {
|
||||
if !strings.Contains(key, "\\") {
|
||||
return key
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(key))
|
||||
for i := 0; i < len(key); i++ {
|
||||
if key[i] == '\\' && i+1 < len(key) {
|
||||
i++
|
||||
b.WriteByte(key[i])
|
||||
continue
|
||||
}
|
||||
b.WriteByte(key[i])
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func splitGJSONPath(path string) []string {
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := make([]string, 0, strings.Count(path, ".")+1)
|
||||
var b strings.Builder
|
||||
b.Grow(len(path))
|
||||
|
||||
for i := 0; i < len(path); i++ {
|
||||
c := path[i]
|
||||
if c == '\\' && i+1 < len(path) {
|
||||
b.WriteByte('\\')
|
||||
i++
|
||||
b.WriteByte(path[i])
|
||||
continue
|
||||
}
|
||||
if c == '.' {
|
||||
parts = append(parts, b.String())
|
||||
b.Reset()
|
||||
continue
|
||||
}
|
||||
b.WriteByte(c)
|
||||
}
|
||||
parts = append(parts, b.String())
|
||||
return parts
|
||||
}
|
||||
|
||||
func mergeDescriptionRaw(schemaRaw, parentDesc string) string {
|
||||
childDesc := gjson.Get(schemaRaw, "description").String()
|
||||
switch {
|
||||
case childDesc == "":
|
||||
schemaRaw, _ = sjson.Set(schemaRaw, "description", parentDesc)
|
||||
return schemaRaw
|
||||
case childDesc == parentDesc:
|
||||
return schemaRaw
|
||||
default:
|
||||
combined := fmt.Sprintf("%s (%s)", parentDesc, childDesc)
|
||||
schemaRaw, _ = sjson.Set(schemaRaw, "description", combined)
|
||||
return schemaRaw
|
||||
}
|
||||
}
|
||||
678
internal/util/gemini_schema_test.go
Normal file
678
internal/util/gemini_schema_test.go
Normal file
@@ -0,0 +1,678 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCleanJSONSchemaForGemini_ConstToEnum(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kind": {
|
||||
"type": "string",
|
||||
"const": "InsightVizNode"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kind": {
|
||||
"type": "string",
|
||||
"enum": ["InsightVizNode"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": ["string", "null"]
|
||||
},
|
||||
"other": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["name", "other"]
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "(nullable)"
|
||||
},
|
||||
"other": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["other"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"description": "List of tags",
|
||||
"minItems": 1
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "User name",
|
||||
"minLength": 3
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
// minItems should be REMOVED and moved to description
|
||||
if strings.Contains(result, `"minItems"`) {
|
||||
t.Errorf("minItems keyword should be removed")
|
||||
}
|
||||
if !strings.Contains(result, "minItems: 1") {
|
||||
t.Errorf("minItems hint missing in description")
|
||||
}
|
||||
|
||||
// minLength should be moved to description
|
||||
if !strings.Contains(result, "minLength: 3") {
|
||||
t.Errorf("minLength hint missing in description")
|
||||
}
|
||||
if strings.Contains(result, `"minLength":`) || strings.Contains(result, `"minLength" :`) {
|
||||
t.Errorf("minLength keyword should be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AnyOfFlattening_SmartSelection(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"anyOf": [
|
||||
{ "type": "null" },
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kind": { "type": "string" }
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "object",
|
||||
"description": "Accepts: null | object",
|
||||
"properties": {
|
||||
"kind": { "type": "string" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_OneOfFlattening(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"oneOf": [
|
||||
{ "type": "string" },
|
||||
{ "type": "integer" }
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "string",
|
||||
"description": "Accepts: string | integer"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AllOfMerging(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"allOf": [
|
||||
{
|
||||
"properties": {
|
||||
"a": { "type": "string" }
|
||||
},
|
||||
"required": ["a"]
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"b": { "type": "integer" }
|
||||
},
|
||||
"required": ["b"]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": { "type": "string" },
|
||||
"b": { "type": "integer" }
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_RefHandling(t *testing.T) {
|
||||
input := `{
|
||||
"definitions": {
|
||||
"User": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"customer": { "$ref": "#/definitions/User" }
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"customer": {
|
||||
"type": "object",
|
||||
"description": "See: User"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_RefHandling_DescriptionEscaping(t *testing.T) {
|
||||
input := `{
|
||||
"definitions": {
|
||||
"User": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"customer": {
|
||||
"description": "He said \"hi\"\\nsecond line",
|
||||
"$ref": "#/definitions/User"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"customer": {
|
||||
"type": "object",
|
||||
"description": "He said \"hi\"\\nsecond line (See: User)"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) {
|
||||
input := `{
|
||||
"definitions": {
|
||||
"Node": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"child": { "$ref": "#/definitions/Node" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"$ref": "#/definitions/Node"
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
var resMap map[string]interface{}
|
||||
json.Unmarshal([]byte(result), &resMap)
|
||||
|
||||
if resMap["type"] != "object" {
|
||||
t.Errorf("Expected type: object, got: %v", resMap["type"])
|
||||
}
|
||||
|
||||
desc, ok := resMap["description"].(string)
|
||||
if !ok || !strings.Contains(desc, "Node") {
|
||||
t.Errorf("Expected description hint containing 'Node', got: %v", resMap["description"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_RequiredCleanup(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "string"},
|
||||
"b": {"type": "string"}
|
||||
},
|
||||
"required": ["a", "b", "c"]
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "string"},
|
||||
"b": {"type": "string"}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AllOfMerging_DotKeys(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"allOf": [
|
||||
{
|
||||
"properties": {
|
||||
"my.param": { "type": "string" }
|
||||
},
|
||||
"required": ["my.param"]
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"b": { "type": "integer" }
|
||||
},
|
||||
"required": ["b"]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"my.param": { "type": "string" },
|
||||
"b": { "type": "integer" }
|
||||
},
|
||||
"required": ["my.param", "b"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) {
|
||||
// A tool has an argument named "pattern" - should NOT be treated as a constraint
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "The regex pattern"
|
||||
}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "The regex pattern"
|
||||
}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
|
||||
var resMap map[string]interface{}
|
||||
json.Unmarshal([]byte(result), &resMap)
|
||||
props, _ := resMap["properties"].(map[string]interface{})
|
||||
if _, ok := props["description"]; ok {
|
||||
t.Errorf("Invalid 'description' property injected into properties map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"my.param": {
|
||||
"type": "string",
|
||||
"$ref": "#/definitions/MyType"
|
||||
}
|
||||
},
|
||||
"definitions": {
|
||||
"MyType": { "type": "string" }
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
var resMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(result), &resMap); err != nil {
|
||||
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||
}
|
||||
|
||||
props, ok := resMap["properties"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("properties missing")
|
||||
}
|
||||
|
||||
if val, ok := props["my.param"]; !ok {
|
||||
t.Fatalf("Key 'my.param' is missing. Result: %s", result)
|
||||
} else {
|
||||
valMap, _ := val.(map[string]interface{})
|
||||
if _, hasRef := valMap["$ref"]; hasRef {
|
||||
t.Errorf("Key 'my.param' still contains $ref")
|
||||
}
|
||||
if _, ok := props["my"]; ok {
|
||||
t.Errorf("Artifact key 'my' created by sjson splitting")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"anyOf": [
|
||||
{ "type": "string" },
|
||||
{ "type": "integer" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if !strings.Contains(result, "Accepts:") {
|
||||
t.Errorf("Expected alternative types hint, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "string") || !strings.Contains(result, "integer") {
|
||||
t.Errorf("Expected all alternative types in hint, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": ["string", "null"],
|
||||
"description": "User name"
|
||||
}
|
||||
},
|
||||
"required": ["name"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if !strings.Contains(result, "(nullable)") {
|
||||
t.Errorf("Expected nullable hint, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "User name") {
|
||||
t.Errorf("Expected original description to be preserved, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable_DotKey(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"my.param": {
|
||||
"type": ["string", "null"]
|
||||
},
|
||||
"other": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["my.param", "other"]
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"my.param": {
|
||||
"type": "string",
|
||||
"description": "(nullable)"
|
||||
},
|
||||
"other": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["other"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["active", "inactive", "pending"],
|
||||
"description": "Current status"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if !strings.Contains(result, "Allowed:") {
|
||||
t.Errorf("Expected enum values hint, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "active") || !strings.Contains(result, "inactive") {
|
||||
t.Errorf("Expected enum values in hint, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AdditionalPropertiesHint(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" }
|
||||
},
|
||||
"additionalProperties": false
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if !strings.Contains(result, "No extra properties allowed") {
|
||||
t.Errorf("Expected additionalProperties hint, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AnyOfFlattening_PreservesDescription(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"description": "Parent desc",
|
||||
"anyOf": [
|
||||
{ "type": "string", "description": "Child desc" },
|
||||
{ "type": "integer" }
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "string",
|
||||
"description": "Parent desc (Child desc) (Accepts: string | integer)"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_SingleEnumNoHint(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kind": {
|
||||
"type": "string",
|
||||
"enum": ["fixed"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if strings.Contains(result, "Allowed:") {
|
||||
t.Errorf("Single value enum should not add Allowed hint, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"type": ["string", "integer", "boolean"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if !strings.Contains(result, "Accepts:") {
|
||||
t.Errorf("Expected multiple types hint, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "string") || !strings.Contains(result, "integer") || !strings.Contains(result, "boolean") {
|
||||
t.Errorf("Expected all types in hint, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval(t *testing.T) {
|
||||
// propertyNames is used to validate object property names (e.g., must match a pattern)
|
||||
// Gemini doesn't support this keyword and will reject requests containing it
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"propertyNames": {
|
||||
"pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$"
|
||||
},
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metadata": {
|
||||
"type": "object"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
|
||||
// Verify propertyNames is completely removed
|
||||
if strings.Contains(result, "propertyNames") {
|
||||
t.Errorf("propertyNames keyword should be removed, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested(t *testing.T) {
|
||||
// Test deeply nested propertyNames (as seen in real Claude tool schemas)
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"propertyNames": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if strings.Contains(result, "propertyNames") {
|
||||
t.Errorf("Nested propertyNames should be removed, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
|
||||
var expMap, actMap map[string]interface{}
|
||||
errExp := json.Unmarshal([]byte(expectedJSON), &expMap)
|
||||
errAct := json.Unmarshal([]byte(actualJSON), &actMap)
|
||||
|
||||
if errExp != nil || errAct != nil {
|
||||
t.Fatalf("JSON Unmarshal error. Exp: %v, Act: %v", errExp, errAct)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expMap, actMap) {
|
||||
expBytes, _ := json.MarshalIndent(expMap, "", " ")
|
||||
actBytes, _ := json.MarshalIndent(actMap, "", " ")
|
||||
t.Errorf("JSON mismatch:\nExpected:\n%s\n\nActual:\n%s", string(expBytes), string(actBytes))
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -13,6 +14,44 @@ const (
|
||||
GeminiOriginalModelMetadataKey = "gemini_original_model"
|
||||
)
|
||||
|
||||
// Gemini model family detection patterns
|
||||
var (
|
||||
gemini3Pattern = regexp.MustCompile(`(?i)^gemini[_-]?3[_-]`)
|
||||
gemini3ProPattern = regexp.MustCompile(`(?i)^gemini[_-]?3[_-]pro`)
|
||||
gemini3FlashPattern = regexp.MustCompile(`(?i)^gemini[_-]?3[_-]flash`)
|
||||
gemini25Pattern = regexp.MustCompile(`(?i)^gemini[_-]?2\.5[_-]`)
|
||||
)
|
||||
|
||||
// IsGemini3Model returns true if the model is a Gemini 3 family model.
|
||||
// Gemini 3 models should use thinkingLevel (string) instead of thinkingBudget (number).
|
||||
func IsGemini3Model(model string) bool {
|
||||
return gemini3Pattern.MatchString(model)
|
||||
}
|
||||
|
||||
// IsGemini3ProModel returns true if the model is a Gemini 3 Pro variant.
|
||||
// Gemini 3 Pro supports thinkingLevel: "low", "high" (default: "high")
|
||||
func IsGemini3ProModel(model string) bool {
|
||||
return gemini3ProPattern.MatchString(model)
|
||||
}
|
||||
|
||||
// IsGemini3FlashModel returns true if the model is a Gemini 3 Flash variant.
|
||||
// Gemini 3 Flash supports thinkingLevel: "minimal", "low", "medium", "high" (default: "high")
|
||||
func IsGemini3FlashModel(model string) bool {
|
||||
return gemini3FlashPattern.MatchString(model)
|
||||
}
|
||||
|
||||
// IsGemini25Model returns true if the model is a Gemini 2.5 family model.
|
||||
// Gemini 2.5 models should use thinkingBudget (number).
|
||||
func IsGemini25Model(model string) bool {
|
||||
return gemini25Pattern.MatchString(model)
|
||||
}
|
||||
|
||||
// Gemini3ProThinkingLevels are the valid thinkingLevel values for Gemini 3 Pro models.
|
||||
var Gemini3ProThinkingLevels = []string{"low", "high"}
|
||||
|
||||
// Gemini3FlashThinkingLevels are the valid thinkingLevel values for Gemini 3 Flash models.
|
||||
var Gemini3FlashThinkingLevels = []string{"minimal", "low", "medium", "high"}
|
||||
|
||||
func ApplyGeminiThinkingConfig(body []byte, budget *int, includeThoughts *bool) []byte {
|
||||
if budget == nil && includeThoughts == nil {
|
||||
return body
|
||||
@@ -69,10 +108,153 @@ func ApplyGeminiCLIThinkingConfig(body []byte, budget *int, includeThoughts *boo
|
||||
return updated
|
||||
}
|
||||
|
||||
// ApplyGeminiThinkingLevel applies thinkingLevel config for Gemini 3 models.
|
||||
// For standard Gemini API format (generationConfig.thinkingConfig path).
|
||||
// Per Google's documentation, Gemini 3 models should use thinkingLevel instead of thinkingBudget.
|
||||
func ApplyGeminiThinkingLevel(body []byte, level string, includeThoughts *bool) []byte {
|
||||
if level == "" && includeThoughts == nil {
|
||||
return body
|
||||
}
|
||||
updated := body
|
||||
if level != "" {
|
||||
valuePath := "generationConfig.thinkingConfig.thinkingLevel"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, level)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
// Default to including thoughts when a level is set but no explicit include flag is provided.
|
||||
incl := includeThoughts
|
||||
if incl == nil && level != "" {
|
||||
defaultInclude := true
|
||||
incl = &defaultInclude
|
||||
}
|
||||
if incl != nil {
|
||||
valuePath := "generationConfig.thinkingConfig.includeThoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
if it := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
|
||||
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.include_thoughts")
|
||||
}
|
||||
if tb := gjson.GetBytes(body, "generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
|
||||
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.thinkingBudget")
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// ApplyGeminiCLIThinkingLevel applies thinkingLevel config for Gemini 3 models.
|
||||
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// Per Google's documentation, Gemini 3 models should use thinkingLevel instead of thinkingBudget.
|
||||
func ApplyGeminiCLIThinkingLevel(body []byte, level string, includeThoughts *bool) []byte {
|
||||
if level == "" && includeThoughts == nil {
|
||||
return body
|
||||
}
|
||||
updated := body
|
||||
if level != "" {
|
||||
valuePath := "request.generationConfig.thinkingConfig.thinkingLevel"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, level)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
// Default to including thoughts when a level is set but no explicit include flag is provided.
|
||||
incl := includeThoughts
|
||||
if incl == nil && level != "" {
|
||||
defaultInclude := true
|
||||
incl = &defaultInclude
|
||||
}
|
||||
if incl != nil {
|
||||
valuePath := "request.generationConfig.thinkingConfig.includeThoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
if it := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
|
||||
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
}
|
||||
if tb := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
|
||||
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// ValidateGemini3ThinkingLevel validates that the thinkingLevel is valid for the Gemini 3 model variant.
|
||||
// Returns the validated level (normalized to lowercase) and true if valid, or empty string and false if invalid.
|
||||
func ValidateGemini3ThinkingLevel(model, level string) (string, bool) {
|
||||
if level == "" {
|
||||
return "", false
|
||||
}
|
||||
normalized := strings.ToLower(strings.TrimSpace(level))
|
||||
|
||||
var validLevels []string
|
||||
if IsGemini3ProModel(model) {
|
||||
validLevels = Gemini3ProThinkingLevels
|
||||
} else if IsGemini3FlashModel(model) {
|
||||
validLevels = Gemini3FlashThinkingLevels
|
||||
} else if IsGemini3Model(model) {
|
||||
// Unknown Gemini 3 variant - allow all levels as fallback
|
||||
validLevels = Gemini3FlashThinkingLevels
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
|
||||
for _, valid := range validLevels {
|
||||
if normalized == valid {
|
||||
return normalized, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// ThinkingBudgetToGemini3Level converts a thinkingBudget to a thinkingLevel for Gemini 3 models.
|
||||
// This provides backward compatibility when thinkingBudget is provided for Gemini 3 models.
|
||||
// Returns the appropriate thinkingLevel and true if conversion is possible.
|
||||
func ThinkingBudgetToGemini3Level(model string, budget int) (string, bool) {
|
||||
if !IsGemini3Model(model) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Map budget to level based on Google's documentation
|
||||
// Gemini 3 Pro: "low", "high" (default: "high")
|
||||
// Gemini 3 Flash: "minimal", "low", "medium", "high" (default: "high")
|
||||
switch {
|
||||
case budget == -1:
|
||||
// Dynamic budget maps to "high" (API default)
|
||||
return "high", true
|
||||
case budget == 0:
|
||||
// Zero budget - Gemini 3 doesn't support disabling thinking
|
||||
// Map to lowest available level
|
||||
if IsGemini3FlashModel(model) {
|
||||
return "minimal", true
|
||||
}
|
||||
return "low", true
|
||||
case budget > 0 && budget <= 512:
|
||||
if IsGemini3FlashModel(model) {
|
||||
return "minimal", true
|
||||
}
|
||||
return "low", true
|
||||
case budget <= 1024:
|
||||
return "low", true
|
||||
case budget <= 8192:
|
||||
if IsGemini3FlashModel(model) {
|
||||
return "medium", true
|
||||
}
|
||||
return "low", true // Pro doesn't have medium, use low
|
||||
default:
|
||||
return "high", true
|
||||
}
|
||||
}
|
||||
|
||||
// modelsWithDefaultThinking lists models that should have thinking enabled by default
|
||||
// when no explicit thinkingConfig is provided.
|
||||
var modelsWithDefaultThinking = map[string]bool{
|
||||
"gemini-3-pro-preview": true,
|
||||
"gemini-3-pro-preview": true,
|
||||
"gemini-3-pro-image-preview": true,
|
||||
// "gemini-3-flash-preview": true,
|
||||
}
|
||||
|
||||
// ModelHasDefaultThinking returns true if the model should have thinking enabled by default.
|
||||
@@ -83,6 +265,7 @@ func ModelHasDefaultThinking(model string) bool {
|
||||
// ApplyDefaultThinkingIfNeeded injects default thinkingConfig for models that require it.
|
||||
// For standard Gemini API format (generationConfig.thinkingConfig path).
|
||||
// Returns the modified body if thinkingConfig was added, otherwise returns the original.
|
||||
// For Gemini 3 models, uses thinkingLevel instead of thinkingBudget per Google's documentation.
|
||||
func ApplyDefaultThinkingIfNeeded(model string, body []byte) []byte {
|
||||
if !ModelHasDefaultThinking(model) {
|
||||
return body
|
||||
@@ -90,14 +273,59 @@ func ApplyDefaultThinkingIfNeeded(model string, body []byte) []byte {
|
||||
if gjson.GetBytes(body, "generationConfig.thinkingConfig").Exists() {
|
||||
return body
|
||||
}
|
||||
// Gemini 3 models use thinkingLevel instead of thinkingBudget
|
||||
if IsGemini3Model(model) {
|
||||
// Don't set a default - let the API use its dynamic default ("high")
|
||||
// Only set includeThoughts
|
||||
updated, _ := sjson.SetBytes(body, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
return updated
|
||||
}
|
||||
// Gemini 2.5 and other models use thinkingBudget
|
||||
updated, _ := sjson.SetBytes(body, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
updated, _ = sjson.SetBytes(updated, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
return updated
|
||||
}
|
||||
|
||||
// ApplyGemini3ThinkingLevelFromMetadata applies thinkingLevel from metadata for Gemini 3 models.
|
||||
// For standard Gemini API format (generationConfig.thinkingConfig path).
|
||||
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal)).
|
||||
func ApplyGemini3ThinkingLevelFromMetadata(model string, metadata map[string]any, body []byte) []byte {
|
||||
if !IsGemini3Model(model) {
|
||||
return body
|
||||
}
|
||||
effort, ok := ReasoningEffortFromMetadata(metadata)
|
||||
if !ok || effort == "" {
|
||||
return body
|
||||
}
|
||||
// Validate and apply the thinkingLevel
|
||||
if level, valid := ValidateGemini3ThinkingLevel(model, effort); valid {
|
||||
return ApplyGeminiThinkingLevel(body, level, nil)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// ApplyGemini3ThinkingLevelFromMetadataCLI applies thinkingLevel from metadata for Gemini 3 models.
|
||||
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal)).
|
||||
func ApplyGemini3ThinkingLevelFromMetadataCLI(model string, metadata map[string]any, body []byte) []byte {
|
||||
if !IsGemini3Model(model) {
|
||||
return body
|
||||
}
|
||||
effort, ok := ReasoningEffortFromMetadata(metadata)
|
||||
if !ok || effort == "" {
|
||||
return body
|
||||
}
|
||||
// Validate and apply the thinkingLevel
|
||||
if level, valid := ValidateGemini3ThinkingLevel(model, effort); valid {
|
||||
return ApplyGeminiCLIThinkingLevel(body, level, nil)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// ApplyDefaultThinkingIfNeededCLI injects default thinkingConfig for models that require it.
|
||||
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// Returns the modified body if thinkingConfig was added, otherwise returns the original.
|
||||
// For Gemini 3 models, uses thinkingLevel instead of thinkingBudget per Google's documentation.
|
||||
func ApplyDefaultThinkingIfNeededCLI(model string, body []byte) []byte {
|
||||
if !ModelHasDefaultThinking(model) {
|
||||
return body
|
||||
@@ -105,6 +333,14 @@ func ApplyDefaultThinkingIfNeededCLI(model string, body []byte) []byte {
|
||||
if gjson.GetBytes(body, "request.generationConfig.thinkingConfig").Exists() {
|
||||
return body
|
||||
}
|
||||
// Gemini 3 models use thinkingLevel instead of thinkingBudget
|
||||
if IsGemini3Model(model) {
|
||||
// Don't set a default - let the API use its dynamic default ("high")
|
||||
// Only set includeThoughts
|
||||
updated, _ := sjson.SetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
return updated
|
||||
}
|
||||
// Gemini 2.5 and other models use thinkingBudget
|
||||
updated, _ := sjson.SetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
updated, _ = sjson.SetBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
return updated
|
||||
@@ -128,12 +364,31 @@ func StripThinkingConfigIfUnsupported(model string, body []byte) []byte {
|
||||
|
||||
// NormalizeGeminiThinkingBudget normalizes the thinkingBudget value in a standard Gemini
|
||||
// request body (generationConfig.thinkingConfig.thinkingBudget path).
|
||||
func NormalizeGeminiThinkingBudget(model string, body []byte) []byte {
|
||||
// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation,
|
||||
// unless skipGemini3Check is provided and true.
|
||||
func NormalizeGeminiThinkingBudget(model string, body []byte, skipGemini3Check ...bool) []byte {
|
||||
const budgetPath = "generationConfig.thinkingConfig.thinkingBudget"
|
||||
const levelPath = "generationConfig.thinkingConfig.thinkingLevel"
|
||||
|
||||
budget := gjson.GetBytes(body, budgetPath)
|
||||
if !budget.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
// For Gemini 3 models, convert thinkingBudget to thinkingLevel
|
||||
skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0]
|
||||
if IsGemini3Model(model) && !skipGemini3 {
|
||||
if level, ok := ThinkingBudgetToGemini3Level(model, int(budget.Int())); ok {
|
||||
updated, _ := sjson.SetBytes(body, levelPath, level)
|
||||
updated, _ = sjson.DeleteBytes(updated, budgetPath)
|
||||
return updated
|
||||
}
|
||||
// If conversion fails, just remove the budget (let API use default)
|
||||
updated, _ := sjson.DeleteBytes(body, budgetPath)
|
||||
return updated
|
||||
}
|
||||
|
||||
// For Gemini 2.5 and other models, normalize the budget value
|
||||
normalized := NormalizeThinkingBudget(model, int(budget.Int()))
|
||||
updated, _ := sjson.SetBytes(body, budgetPath, normalized)
|
||||
return updated
|
||||
@@ -141,56 +396,170 @@ func NormalizeGeminiThinkingBudget(model string, body []byte) []byte {
|
||||
|
||||
// NormalizeGeminiCLIThinkingBudget normalizes the thinkingBudget value in a Gemini CLI
|
||||
// request body (request.generationConfig.thinkingConfig.thinkingBudget path).
|
||||
func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte {
|
||||
// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation,
|
||||
// unless skipGemini3Check is provided and true.
|
||||
func NormalizeGeminiCLIThinkingBudget(model string, body []byte, skipGemini3Check ...bool) []byte {
|
||||
const budgetPath = "request.generationConfig.thinkingConfig.thinkingBudget"
|
||||
const levelPath = "request.generationConfig.thinkingConfig.thinkingLevel"
|
||||
|
||||
budget := gjson.GetBytes(body, budgetPath)
|
||||
if !budget.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
// For Gemini 3 models, convert thinkingBudget to thinkingLevel
|
||||
skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0]
|
||||
if IsGemini3Model(model) && !skipGemini3 {
|
||||
if level, ok := ThinkingBudgetToGemini3Level(model, int(budget.Int())); ok {
|
||||
updated, _ := sjson.SetBytes(body, levelPath, level)
|
||||
updated, _ = sjson.DeleteBytes(updated, budgetPath)
|
||||
return updated
|
||||
}
|
||||
// If conversion fails, just remove the budget (let API use default)
|
||||
updated, _ := sjson.DeleteBytes(body, budgetPath)
|
||||
return updated
|
||||
}
|
||||
|
||||
// For Gemini 2.5 and other models, normalize the budget value
|
||||
normalized := NormalizeThinkingBudget(model, int(budget.Int()))
|
||||
updated, _ := sjson.SetBytes(body, budgetPath, normalized)
|
||||
return updated
|
||||
}
|
||||
|
||||
// ReasoningEffortBudgetMapping defines the thinkingBudget values for each reasoning effort level.
|
||||
var ReasoningEffortBudgetMapping = map[string]int{
|
||||
"none": 0,
|
||||
"auto": -1,
|
||||
"minimal": 512,
|
||||
"low": 1024,
|
||||
"medium": 8192,
|
||||
"high": 24576,
|
||||
"xhigh": 32768,
|
||||
}
|
||||
|
||||
// ApplyReasoningEffortToGemini applies OpenAI reasoning_effort to Gemini thinkingConfig
|
||||
// for standard Gemini API format (generationConfig.thinkingConfig path).
|
||||
// Returns the modified body with thinkingBudget and include_thoughts set.
|
||||
func ApplyReasoningEffortToGemini(body []byte, effort string) []byte {
|
||||
normalized := strings.ToLower(strings.TrimSpace(effort))
|
||||
if normalized == "" {
|
||||
return body
|
||||
}
|
||||
|
||||
budgetPath := "generationConfig.thinkingConfig.thinkingBudget"
|
||||
includePath := "generationConfig.thinkingConfig.include_thoughts"
|
||||
|
||||
if normalized == "none" {
|
||||
body, _ = sjson.DeleteBytes(body, "generationConfig.thinkingConfig")
|
||||
return body
|
||||
}
|
||||
|
||||
budget, ok := ReasoningEffortBudgetMapping[normalized]
|
||||
if !ok {
|
||||
return body
|
||||
}
|
||||
|
||||
body, _ = sjson.SetBytes(body, budgetPath, budget)
|
||||
body, _ = sjson.SetBytes(body, includePath, true)
|
||||
return body
|
||||
}
|
||||
|
||||
// ApplyReasoningEffortToGeminiCLI applies OpenAI reasoning_effort to Gemini CLI thinkingConfig
|
||||
// for Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// Returns the modified body with thinkingBudget and include_thoughts set.
|
||||
func ApplyReasoningEffortToGeminiCLI(body []byte, effort string) []byte {
|
||||
normalized := strings.ToLower(strings.TrimSpace(effort))
|
||||
if normalized == "" {
|
||||
return body
|
||||
}
|
||||
|
||||
budgetPath := "request.generationConfig.thinkingConfig.thinkingBudget"
|
||||
includePath := "request.generationConfig.thinkingConfig.include_thoughts"
|
||||
|
||||
if normalized == "none" {
|
||||
body, _ = sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig")
|
||||
return body
|
||||
}
|
||||
|
||||
budget, ok := ReasoningEffortBudgetMapping[normalized]
|
||||
if !ok {
|
||||
return body
|
||||
}
|
||||
|
||||
body, _ = sjson.SetBytes(body, budgetPath, budget)
|
||||
body, _ = sjson.SetBytes(body, includePath, true)
|
||||
return body
|
||||
}
|
||||
|
||||
// ConvertThinkingLevelToBudget checks for "generationConfig.thinkingConfig.thinkingLevel"
|
||||
// and converts it to "thinkingBudget".
|
||||
// "high" -> 32768
|
||||
// "low" -> 128
|
||||
// It removes "thinkingLevel" after conversion.
|
||||
func ConvertThinkingLevelToBudget(body []byte) []byte {
|
||||
// and converts it to "thinkingBudget" for Gemini 2.5 models.
|
||||
// For Gemini 3 models, preserves thinkingLevel unless skipGemini3Check is provided and true.
|
||||
// Mappings for Gemini 2.5:
|
||||
// - "high" -> 32768
|
||||
// - "medium" -> 8192
|
||||
// - "low" -> 1024
|
||||
// - "minimal" -> 512
|
||||
//
|
||||
// It removes "thinkingLevel" after conversion (for Gemini 2.5 only).
|
||||
func ConvertThinkingLevelToBudget(body []byte, model string, skipGemini3Check ...bool) []byte {
|
||||
levelPath := "generationConfig.thinkingConfig.thinkingLevel"
|
||||
res := gjson.GetBytes(body, levelPath)
|
||||
if !res.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
level := strings.ToLower(res.String())
|
||||
var budget int
|
||||
switch level {
|
||||
case "high":
|
||||
budget = 32768
|
||||
case "low":
|
||||
budget = 128
|
||||
default:
|
||||
// If unknown level, we might just leave it or default.
|
||||
// User only specified high and low. We'll assume we shouldn't touch it if it's something else,
|
||||
// or maybe we should just remove the invalid level?
|
||||
// For safety adhering to strict instructions: "If high... if low...".
|
||||
// If it's something else, the upstream might fail anyway if we leave it,
|
||||
// but let's just delete the level if we processed it.
|
||||
// Actually, let's check if we need to do anything for other values.
|
||||
// For now, only handle high/low.
|
||||
// For Gemini 3 models, preserve thinkingLevel unless explicitly skipped
|
||||
skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0]
|
||||
if IsGemini3Model(model) && !skipGemini3 {
|
||||
return body
|
||||
}
|
||||
|
||||
// Set budget
|
||||
budget, ok := ThinkingLevelToBudget(res.String())
|
||||
if !ok {
|
||||
updated, _ := sjson.DeleteBytes(body, levelPath)
|
||||
return updated
|
||||
}
|
||||
|
||||
budgetPath := "generationConfig.thinkingConfig.thinkingBudget"
|
||||
updated, err := sjson.SetBytes(body, budgetPath, budget)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
// Remove level
|
||||
updated, err = sjson.DeleteBytes(updated, levelPath)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// ConvertThinkingLevelToBudgetCLI checks for "request.generationConfig.thinkingConfig.thinkingLevel"
|
||||
// and converts it to "thinkingBudget" for Gemini 2.5 models.
|
||||
// For Gemini 3 models, preserves thinkingLevel as-is (does not convert).
|
||||
func ConvertThinkingLevelToBudgetCLI(body []byte, model string) []byte {
|
||||
levelPath := "request.generationConfig.thinkingConfig.thinkingLevel"
|
||||
res := gjson.GetBytes(body, levelPath)
|
||||
if !res.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
// For Gemini 3 models, preserve thinkingLevel - don't convert to budget
|
||||
if IsGemini3Model(model) {
|
||||
return body
|
||||
}
|
||||
|
||||
budget, ok := ThinkingLevelToBudget(res.String())
|
||||
if !ok {
|
||||
updated, _ := sjson.DeleteBytes(body, levelPath)
|
||||
return updated
|
||||
}
|
||||
|
||||
budgetPath := "request.generationConfig.thinkingConfig.thinkingBudget"
|
||||
updated, err := sjson.SetBytes(body, budgetPath, budget)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
updated, err = sjson.DeleteBytes(updated, levelPath)
|
||||
if err != nil {
|
||||
return body
|
||||
|
||||
@@ -1,34 +0,0 @@
|
||||
package util
|
||||
|
||||
// OpenAIThinkingBudgetToEffort maps a numeric thinking budget (tokens)
|
||||
// into an OpenAI-style reasoning effort level for level-based models.
|
||||
//
|
||||
// Ranges:
|
||||
// - 0 -> "none"
|
||||
// - 1..1024 -> "low"
|
||||
// - 1025..8192 -> "medium"
|
||||
// - 8193..24576 -> "high"
|
||||
// - 24577.. -> highest supported level for the model (defaults to "xhigh")
|
||||
//
|
||||
// Negative values (except the dynamic -1 handled elsewhere) are treated as unsupported.
|
||||
func OpenAIThinkingBudgetToEffort(model string, budget int) (string, bool) {
|
||||
switch {
|
||||
case budget < 0:
|
||||
return "", false
|
||||
case budget == 0:
|
||||
return "none", true
|
||||
case budget > 0 && budget <= 1024:
|
||||
return "low", true
|
||||
case budget <= 8192:
|
||||
return "medium", true
|
||||
case budget <= 24576:
|
||||
return "high", true
|
||||
case budget > 24576:
|
||||
if levels := GetModelThinkingLevels(model); len(levels) > 0 {
|
||||
return levels[len(levels)-1], true
|
||||
}
|
||||
return "xhigh", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
@@ -118,3 +118,111 @@ func IsOpenAICompatibilityModel(model string) bool {
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(info.Type), "openai-compatibility")
|
||||
}
|
||||
|
||||
// ThinkingEffortToBudget maps a reasoning effort level to a numeric thinking budget (tokens),
|
||||
// clamping the result to the model's supported range.
|
||||
//
|
||||
// Mappings (values are normalized to model's supported range):
|
||||
// - "none" -> 0
|
||||
// - "auto" -> -1
|
||||
// - "minimal" -> 512
|
||||
// - "low" -> 1024
|
||||
// - "medium" -> 8192
|
||||
// - "high" -> 24576
|
||||
// - "xhigh" -> 32768
|
||||
//
|
||||
// Returns false when the effort level is empty or unsupported.
|
||||
func ThinkingEffortToBudget(model, effort string) (int, bool) {
|
||||
if effort == "" {
|
||||
return 0, false
|
||||
}
|
||||
normalized, ok := NormalizeReasoningEffortLevel(model, effort)
|
||||
if !ok {
|
||||
normalized = strings.ToLower(strings.TrimSpace(effort))
|
||||
}
|
||||
switch normalized {
|
||||
case "none":
|
||||
return 0, true
|
||||
case "auto":
|
||||
return NormalizeThinkingBudget(model, -1), true
|
||||
case "minimal":
|
||||
return NormalizeThinkingBudget(model, 512), true
|
||||
case "low":
|
||||
return NormalizeThinkingBudget(model, 1024), true
|
||||
case "medium":
|
||||
return NormalizeThinkingBudget(model, 8192), true
|
||||
case "high":
|
||||
return NormalizeThinkingBudget(model, 24576), true
|
||||
case "xhigh":
|
||||
return NormalizeThinkingBudget(model, 32768), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// ThinkingLevelToBudget maps a Gemini thinkingLevel to a numeric thinking budget (tokens).
|
||||
//
|
||||
// Mappings:
|
||||
// - "minimal" -> 512
|
||||
// - "low" -> 1024
|
||||
// - "medium" -> 8192
|
||||
// - "high" -> 32768
|
||||
//
|
||||
// Returns false when the level is empty or unsupported.
|
||||
func ThinkingLevelToBudget(level string) (int, bool) {
|
||||
if level == "" {
|
||||
return 0, false
|
||||
}
|
||||
normalized := strings.ToLower(strings.TrimSpace(level))
|
||||
switch normalized {
|
||||
case "minimal":
|
||||
return 512, true
|
||||
case "low":
|
||||
return 1024, true
|
||||
case "medium":
|
||||
return 8192, true
|
||||
case "high":
|
||||
return 32768, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// ThinkingBudgetToEffort maps a numeric thinking budget (tokens)
|
||||
// to a reasoning effort level for level-based models.
|
||||
//
|
||||
// Mappings:
|
||||
// - 0 -> "none" (or lowest supported level if model doesn't support "none")
|
||||
// - -1 -> "auto"
|
||||
// - 1..1024 -> "low"
|
||||
// - 1025..8192 -> "medium"
|
||||
// - 8193..24576 -> "high"
|
||||
// - 24577.. -> highest supported level for the model (defaults to "xhigh")
|
||||
//
|
||||
// Returns false when the budget is unsupported (negative values other than -1).
|
||||
func ThinkingBudgetToEffort(model string, budget int) (string, bool) {
|
||||
switch {
|
||||
case budget == -1:
|
||||
return "auto", true
|
||||
case budget < -1:
|
||||
return "", false
|
||||
case budget == 0:
|
||||
if levels := GetModelThinkingLevels(model); len(levels) > 0 {
|
||||
return levels[0], true
|
||||
}
|
||||
return "none", true
|
||||
case budget > 0 && budget <= 1024:
|
||||
return "low", true
|
||||
case budget <= 8192:
|
||||
return "medium", true
|
||||
case budget <= 24576:
|
||||
return "high", true
|
||||
case budget > 24576:
|
||||
if levels := GetModelThinkingLevels(model); len(levels) > 0 {
|
||||
return levels[len(levels)-1], true
|
||||
}
|
||||
return "xhigh", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,36 +201,6 @@ func ReasoningEffortFromMetadata(metadata map[string]any) (string, bool) {
|
||||
return "", true
|
||||
}
|
||||
|
||||
// ThinkingEffortToBudget maps reasoning effort levels to approximate budgets,
|
||||
// clamping the result to the model's supported range.
|
||||
func ThinkingEffortToBudget(model, effort string) (int, bool) {
|
||||
if effort == "" {
|
||||
return 0, false
|
||||
}
|
||||
normalized, ok := NormalizeReasoningEffortLevel(model, effort)
|
||||
if !ok {
|
||||
normalized = strings.ToLower(strings.TrimSpace(effort))
|
||||
}
|
||||
switch normalized {
|
||||
case "none":
|
||||
return 0, true
|
||||
case "auto":
|
||||
return NormalizeThinkingBudget(model, -1), true
|
||||
case "minimal":
|
||||
return NormalizeThinkingBudget(model, 512), true
|
||||
case "low":
|
||||
return NormalizeThinkingBudget(model, 1024), true
|
||||
case "medium":
|
||||
return NormalizeThinkingBudget(model, 8192), true
|
||||
case "high":
|
||||
return NormalizeThinkingBudget(model, 24576), true
|
||||
case "xhigh":
|
||||
return NormalizeThinkingBudget(model, 32768), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveOriginalModel returns the original model name stored in metadata (if present),
|
||||
// otherwise falls back to the provided model.
|
||||
func ResolveOriginalModel(model string, metadata map[string]any) string {
|
||||
|
||||
@@ -6,6 +6,7 @@ package util
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -28,10 +29,17 @@ func Walk(value gjson.Result, path, field string, paths *[]string) {
|
||||
// For JSON objects and arrays, iterate through each child
|
||||
value.ForEach(func(key, val gjson.Result) bool {
|
||||
var childPath string
|
||||
// Escape special characters for gjson/sjson path syntax
|
||||
// . -> \.
|
||||
// * -> \*
|
||||
// ? -> \?
|
||||
var keyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
||||
safeKey := keyReplacer.Replace(key.String())
|
||||
|
||||
if path == "" {
|
||||
childPath = key.String()
|
||||
childPath = safeKey
|
||||
} else {
|
||||
childPath = path + "." + key.String()
|
||||
childPath = path + "." + safeKey
|
||||
}
|
||||
if key.String() == field {
|
||||
*paths = append(*paths, childPath)
|
||||
|
||||
270
internal/watcher/clients.go
Normal file
270
internal/watcher/clients.go
Normal file
@@ -0,0 +1,270 @@
|
||||
// clients.go implements watcher client lifecycle logic and persistence helpers.
|
||||
// It reloads clients, handles incremental auth file changes, and persists updates when supported.
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string, forceAuthRefresh bool) {
|
||||
log.Debugf("starting full client load process")
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
cfg := w.config
|
||||
w.clientsMutex.RUnlock()
|
||||
|
||||
if cfg == nil {
|
||||
log.Error("config is nil, cannot reload clients")
|
||||
return
|
||||
}
|
||||
|
||||
if len(affectedOAuthProviders) > 0 {
|
||||
w.clientsMutex.Lock()
|
||||
if w.currentAuths != nil {
|
||||
filtered := make(map[string]*coreauth.Auth, len(w.currentAuths))
|
||||
for id, auth := range w.currentAuths {
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if _, match := matchProvider(provider, affectedOAuthProviders); match {
|
||||
continue
|
||||
}
|
||||
filtered[id] = auth
|
||||
}
|
||||
w.currentAuths = filtered
|
||||
log.Debugf("applying oauth-excluded-models to providers %v", affectedOAuthProviders)
|
||||
} else {
|
||||
w.currentAuths = nil
|
||||
}
|
||||
w.clientsMutex.Unlock()
|
||||
}
|
||||
|
||||
geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg)
|
||||
totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
|
||||
log.Debugf("loaded %d API key clients", totalAPIKeyClients)
|
||||
|
||||
var authFileCount int
|
||||
if rescanAuth {
|
||||
authFileCount = w.loadFileClients(cfg)
|
||||
log.Debugf("loaded %d file-based clients", authFileCount)
|
||||
} else {
|
||||
w.clientsMutex.RLock()
|
||||
authFileCount = len(w.lastAuthHashes)
|
||||
w.clientsMutex.RUnlock()
|
||||
log.Debugf("skipping auth directory rescan; retaining %d existing auth files", authFileCount)
|
||||
}
|
||||
|
||||
if rescanAuth {
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
w.lastAuthHashes = make(map[string]string)
|
||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
||||
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
|
||||
} else if resolvedAuthDir != "" {
|
||||
_ = filepath.Walk(resolvedAuthDir, func(path string, info fs.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
|
||||
if data, errReadFile := os.ReadFile(path); errReadFile == nil && len(data) > 0 {
|
||||
sum := sha256.Sum256(data)
|
||||
normalizedPath := w.normalizeAuthPath(path)
|
||||
w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
w.clientsMutex.Unlock()
|
||||
}
|
||||
|
||||
totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
|
||||
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback before auth refresh")
|
||||
w.reloadCallback(cfg)
|
||||
}
|
||||
|
||||
w.refreshAuthState(forceAuthRefresh)
|
||||
|
||||
log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
|
||||
totalNewClients,
|
||||
authFileCount,
|
||||
geminiAPIKeyCount,
|
||||
vertexCompatAPIKeyCount,
|
||||
claudeAPIKeyCount,
|
||||
codexAPIKeyCount,
|
||||
openAICompatCount,
|
||||
)
|
||||
}
|
||||
|
||||
func (w *Watcher) addOrUpdateClient(path string) {
|
||||
data, errRead := os.ReadFile(path)
|
||||
if errRead != nil {
|
||||
log.Errorf("failed to read auth file %s: %v", filepath.Base(path), errRead)
|
||||
return
|
||||
}
|
||||
if len(data) == 0 {
|
||||
log.Debugf("ignoring empty auth file: %s", filepath.Base(path))
|
||||
return
|
||||
}
|
||||
|
||||
sum := sha256.Sum256(data)
|
||||
curHash := hex.EncodeToString(sum[:])
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
cfg := w.config
|
||||
if cfg == nil {
|
||||
log.Error("config is nil, cannot add or update client")
|
||||
w.clientsMutex.Unlock()
|
||||
return
|
||||
}
|
||||
if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
|
||||
w.clientsMutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
w.lastAuthHashes[normalized] = curHash
|
||||
|
||||
w.clientsMutex.Unlock() // Unlock before the callback
|
||||
|
||||
w.refreshAuthState(false)
|
||||
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback after add/update")
|
||||
w.reloadCallback(cfg)
|
||||
}
|
||||
w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
|
||||
}
|
||||
|
||||
func (w *Watcher) removeClient(path string) {
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
cfg := w.config
|
||||
delete(w.lastAuthHashes, normalized)
|
||||
|
||||
w.clientsMutex.Unlock() // Release the lock before the callback
|
||||
|
||||
w.refreshAuthState(false)
|
||||
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback after removal")
|
||||
w.reloadCallback(cfg)
|
||||
}
|
||||
w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
|
||||
}
|
||||
|
||||
func (w *Watcher) loadFileClients(cfg *config.Config) int {
|
||||
authFileCount := 0
|
||||
successfulAuthCount := 0
|
||||
|
||||
authDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir)
|
||||
if errResolveAuthDir != nil {
|
||||
log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
|
||||
return 0
|
||||
}
|
||||
if authDir == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
errWalk := filepath.Walk(authDir, func(path string, info fs.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
log.Debugf("error accessing path %s: %v", path, err)
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
|
||||
authFileCount++
|
||||
log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path))
|
||||
if data, errCreate := os.ReadFile(path); errCreate == nil && len(data) > 0 {
|
||||
successfulAuthCount++
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if errWalk != nil {
|
||||
log.Errorf("error walking auth directory: %v", errWalk)
|
||||
}
|
||||
log.Debugf("auth directory scan complete - found %d .json files, %d readable", authFileCount, successfulAuthCount)
|
||||
return authFileCount
|
||||
}
|
||||
|
||||
func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) {
|
||||
geminiAPIKeyCount := 0
|
||||
vertexCompatAPIKeyCount := 0
|
||||
claudeAPIKeyCount := 0
|
||||
codexAPIKeyCount := 0
|
||||
openAICompatCount := 0
|
||||
|
||||
if len(cfg.GeminiKey) > 0 {
|
||||
geminiAPIKeyCount += len(cfg.GeminiKey)
|
||||
}
|
||||
if len(cfg.VertexCompatAPIKey) > 0 {
|
||||
vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey)
|
||||
}
|
||||
if len(cfg.ClaudeKey) > 0 {
|
||||
claudeAPIKeyCount += len(cfg.ClaudeKey)
|
||||
}
|
||||
if len(cfg.CodexKey) > 0 {
|
||||
codexAPIKeyCount += len(cfg.CodexKey)
|
||||
}
|
||||
if len(cfg.OpenAICompatibility) > 0 {
|
||||
for _, compatConfig := range cfg.OpenAICompatibility {
|
||||
openAICompatCount += len(compatConfig.APIKeyEntries)
|
||||
}
|
||||
}
|
||||
return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount
|
||||
}
|
||||
|
||||
func (w *Watcher) persistConfigAsync() {
|
||||
if w == nil || w.storePersister == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := w.storePersister.PersistConfig(ctx); err != nil {
|
||||
log.Errorf("failed to persist config change: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (w *Watcher) persistAuthAsync(message string, paths ...string) {
|
||||
if w == nil || w.storePersister == nil {
|
||||
return
|
||||
}
|
||||
filtered := make([]string, 0, len(paths))
|
||||
for _, p := range paths {
|
||||
if trimmed := strings.TrimSpace(p); trimmed != "" {
|
||||
filtered = append(filtered, trimmed)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := w.storePersister.PersistAuthFiles(ctx, message, filtered...); err != nil {
|
||||
log.Errorf("failed to persist auth changes: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
134
internal/watcher/config_reload.go
Normal file
134
internal/watcher/config_reload.go
Normal file
@@ -0,0 +1,134 @@
|
||||
// config_reload.go implements debounced configuration hot reload.
|
||||
// It detects material changes and reloads clients when the config changes.
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (w *Watcher) stopConfigReloadTimer() {
|
||||
w.configReloadMu.Lock()
|
||||
if w.configReloadTimer != nil {
|
||||
w.configReloadTimer.Stop()
|
||||
w.configReloadTimer = nil
|
||||
}
|
||||
w.configReloadMu.Unlock()
|
||||
}
|
||||
|
||||
func (w *Watcher) scheduleConfigReload() {
|
||||
w.configReloadMu.Lock()
|
||||
defer w.configReloadMu.Unlock()
|
||||
if w.configReloadTimer != nil {
|
||||
w.configReloadTimer.Stop()
|
||||
}
|
||||
w.configReloadTimer = time.AfterFunc(configReloadDebounce, func() {
|
||||
w.configReloadMu.Lock()
|
||||
w.configReloadTimer = nil
|
||||
w.configReloadMu.Unlock()
|
||||
w.reloadConfigIfChanged()
|
||||
})
|
||||
}
|
||||
|
||||
func (w *Watcher) reloadConfigIfChanged() {
|
||||
data, err := os.ReadFile(w.configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read config file for hash check: %v", err)
|
||||
return
|
||||
}
|
||||
if len(data) == 0 {
|
||||
log.Debugf("ignoring empty config file write event")
|
||||
return
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
newHash := hex.EncodeToString(sum[:])
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
currentHash := w.lastConfigHash
|
||||
w.clientsMutex.RUnlock()
|
||||
|
||||
if currentHash != "" && currentHash == newHash {
|
||||
log.Debugf("config file content unchanged (hash match), skipping reload")
|
||||
return
|
||||
}
|
||||
log.Infof("config file changed, reloading: %s", w.configPath)
|
||||
if w.reloadConfig() {
|
||||
finalHash := newHash
|
||||
if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 {
|
||||
sumUpdated := sha256.Sum256(updatedData)
|
||||
finalHash = hex.EncodeToString(sumUpdated[:])
|
||||
} else if errRead != nil {
|
||||
log.WithError(errRead).Debug("failed to compute updated config hash after reload")
|
||||
}
|
||||
w.clientsMutex.Lock()
|
||||
w.lastConfigHash = finalHash
|
||||
w.clientsMutex.Unlock()
|
||||
w.persistConfigAsync()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) reloadConfig() bool {
|
||||
log.Debug("=========================== CONFIG RELOAD ============================")
|
||||
log.Debugf("starting config reload from: %s", w.configPath)
|
||||
|
||||
newConfig, errLoadConfig := config.LoadConfig(w.configPath)
|
||||
if errLoadConfig != nil {
|
||||
log.Errorf("failed to reload config: %v", errLoadConfig)
|
||||
return false
|
||||
}
|
||||
|
||||
if w.mirroredAuthDir != "" {
|
||||
newConfig.AuthDir = w.mirroredAuthDir
|
||||
} else {
|
||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(newConfig.AuthDir); errResolveAuthDir != nil {
|
||||
log.Errorf("failed to resolve auth directory from config: %v", errResolveAuthDir)
|
||||
} else {
|
||||
newConfig.AuthDir = resolvedAuthDir
|
||||
}
|
||||
}
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
var oldConfig *config.Config
|
||||
_ = yaml.Unmarshal(w.oldConfigYaml, &oldConfig)
|
||||
w.oldConfigYaml, _ = yaml.Marshal(newConfig)
|
||||
w.config = newConfig
|
||||
w.clientsMutex.Unlock()
|
||||
|
||||
var affectedOAuthProviders []string
|
||||
if oldConfig != nil {
|
||||
_, affectedOAuthProviders = diff.DiffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels)
|
||||
}
|
||||
|
||||
util.SetLogLevel(newConfig)
|
||||
if oldConfig != nil && oldConfig.Debug != newConfig.Debug {
|
||||
log.Debugf("log level updated - debug mode changed from %t to %t", oldConfig.Debug, newConfig.Debug)
|
||||
}
|
||||
|
||||
if oldConfig != nil {
|
||||
details := diff.BuildConfigChangeDetails(oldConfig, newConfig)
|
||||
if len(details) > 0 {
|
||||
log.Debugf("config changes detected:")
|
||||
for _, d := range details {
|
||||
log.Debugf(" %s", d)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("no material config field changes detected")
|
||||
}
|
||||
}
|
||||
|
||||
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
|
||||
forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix
|
||||
|
||||
log.Infof("config successfully reloaded, triggering client reload")
|
||||
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)
|
||||
return true
|
||||
}
|
||||
303
internal/watcher/diff/config_diff.go
Normal file
303
internal/watcher/diff/config_diff.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// BuildConfigChangeDetails computes a redacted, human-readable list of config changes.
|
||||
// Secrets are never printed; only structural or non-sensitive fields are surfaced.
|
||||
func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
changes := make([]string, 0, 16)
|
||||
if oldCfg == nil || newCfg == nil {
|
||||
return changes
|
||||
}
|
||||
|
||||
// Simple scalars
|
||||
if oldCfg.Port != newCfg.Port {
|
||||
changes = append(changes, fmt.Sprintf("port: %d -> %d", oldCfg.Port, newCfg.Port))
|
||||
}
|
||||
if oldCfg.AuthDir != newCfg.AuthDir {
|
||||
changes = append(changes, fmt.Sprintf("auth-dir: %s -> %s", oldCfg.AuthDir, newCfg.AuthDir))
|
||||
}
|
||||
if oldCfg.Debug != newCfg.Debug {
|
||||
changes = append(changes, fmt.Sprintf("debug: %t -> %t", oldCfg.Debug, newCfg.Debug))
|
||||
}
|
||||
if oldCfg.LoggingToFile != newCfg.LoggingToFile {
|
||||
changes = append(changes, fmt.Sprintf("logging-to-file: %t -> %t", oldCfg.LoggingToFile, newCfg.LoggingToFile))
|
||||
}
|
||||
if oldCfg.UsageStatisticsEnabled != newCfg.UsageStatisticsEnabled {
|
||||
changes = append(changes, fmt.Sprintf("usage-statistics-enabled: %t -> %t", oldCfg.UsageStatisticsEnabled, newCfg.UsageStatisticsEnabled))
|
||||
}
|
||||
if oldCfg.DisableCooling != newCfg.DisableCooling {
|
||||
changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
|
||||
}
|
||||
if oldCfg.RequestLog != newCfg.RequestLog {
|
||||
changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog))
|
||||
}
|
||||
if oldCfg.RequestRetry != newCfg.RequestRetry {
|
||||
changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry))
|
||||
}
|
||||
if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval {
|
||||
changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval))
|
||||
}
|
||||
if oldCfg.ProxyURL != newCfg.ProxyURL {
|
||||
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", formatProxyURL(oldCfg.ProxyURL), formatProxyURL(newCfg.ProxyURL)))
|
||||
}
|
||||
if oldCfg.WebsocketAuth != newCfg.WebsocketAuth {
|
||||
changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth))
|
||||
}
|
||||
if oldCfg.ForceModelPrefix != newCfg.ForceModelPrefix {
|
||||
changes = append(changes, fmt.Sprintf("force-model-prefix: %t -> %t", oldCfg.ForceModelPrefix, newCfg.ForceModelPrefix))
|
||||
}
|
||||
|
||||
// Quota-exceeded behavior
|
||||
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
|
||||
changes = append(changes, fmt.Sprintf("quota-exceeded.switch-project: %t -> %t", oldCfg.QuotaExceeded.SwitchProject, newCfg.QuotaExceeded.SwitchProject))
|
||||
}
|
||||
if oldCfg.QuotaExceeded.SwitchPreviewModel != newCfg.QuotaExceeded.SwitchPreviewModel {
|
||||
changes = append(changes, fmt.Sprintf("quota-exceeded.switch-preview-model: %t -> %t", oldCfg.QuotaExceeded.SwitchPreviewModel, newCfg.QuotaExceeded.SwitchPreviewModel))
|
||||
}
|
||||
|
||||
// API keys (redacted) and counts
|
||||
if len(oldCfg.APIKeys) != len(newCfg.APIKeys) {
|
||||
changes = append(changes, fmt.Sprintf("api-keys count: %d -> %d", len(oldCfg.APIKeys), len(newCfg.APIKeys)))
|
||||
} else if !reflect.DeepEqual(trimStrings(oldCfg.APIKeys), trimStrings(newCfg.APIKeys)) {
|
||||
changes = append(changes, "api-keys: values updated (count unchanged, redacted)")
|
||||
}
|
||||
if len(oldCfg.GeminiKey) != len(newCfg.GeminiKey) {
|
||||
changes = append(changes, fmt.Sprintf("gemini-api-key count: %d -> %d", len(oldCfg.GeminiKey), len(newCfg.GeminiKey)))
|
||||
} else {
|
||||
for i := range oldCfg.GeminiKey {
|
||||
o := oldCfg.GeminiKey[i]
|
||||
n := newCfg.GeminiKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Claude keys (do not print key material)
|
||||
if len(oldCfg.ClaudeKey) != len(newCfg.ClaudeKey) {
|
||||
changes = append(changes, fmt.Sprintf("claude-api-key count: %d -> %d", len(oldCfg.ClaudeKey), len(newCfg.ClaudeKey)))
|
||||
} else {
|
||||
for i := range oldCfg.ClaudeKey {
|
||||
o := oldCfg.ClaudeKey[i]
|
||||
n := newCfg.ClaudeKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Codex keys (do not print key material)
|
||||
if len(oldCfg.CodexKey) != len(newCfg.CodexKey) {
|
||||
changes = append(changes, fmt.Sprintf("codex-api-key count: %d -> %d", len(oldCfg.CodexKey), len(newCfg.CodexKey)))
|
||||
} else {
|
||||
for i := range oldCfg.CodexKey {
|
||||
o := oldCfg.CodexKey[i]
|
||||
n := newCfg.CodexKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AmpCode settings (redacted where needed)
|
||||
oldAmpURL := strings.TrimSpace(oldCfg.AmpCode.UpstreamURL)
|
||||
newAmpURL := strings.TrimSpace(newCfg.AmpCode.UpstreamURL)
|
||||
if oldAmpURL != newAmpURL {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.upstream-url: %s -> %s", oldAmpURL, newAmpURL))
|
||||
}
|
||||
oldAmpKey := strings.TrimSpace(oldCfg.AmpCode.UpstreamAPIKey)
|
||||
newAmpKey := strings.TrimSpace(newCfg.AmpCode.UpstreamAPIKey)
|
||||
switch {
|
||||
case oldAmpKey == "" && newAmpKey != "":
|
||||
changes = append(changes, "ampcode.upstream-api-key: added")
|
||||
case oldAmpKey != "" && newAmpKey == "":
|
||||
changes = append(changes, "ampcode.upstream-api-key: removed")
|
||||
case oldAmpKey != newAmpKey:
|
||||
changes = append(changes, "ampcode.upstream-api-key: updated")
|
||||
}
|
||||
if oldCfg.AmpCode.RestrictManagementToLocalhost != newCfg.AmpCode.RestrictManagementToLocalhost {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.restrict-management-to-localhost: %t -> %t", oldCfg.AmpCode.RestrictManagementToLocalhost, newCfg.AmpCode.RestrictManagementToLocalhost))
|
||||
}
|
||||
oldMappings := SummarizeAmpModelMappings(oldCfg.AmpCode.ModelMappings)
|
||||
newMappings := SummarizeAmpModelMappings(newCfg.AmpCode.ModelMappings)
|
||||
if oldMappings.hash != newMappings.hash {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.model-mappings: updated (%d -> %d entries)", oldMappings.count, newMappings.count))
|
||||
}
|
||||
if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings))
|
||||
}
|
||||
|
||||
if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
|
||||
changes = append(changes, entries...)
|
||||
}
|
||||
|
||||
// Remote management (never print the key)
|
||||
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote))
|
||||
}
|
||||
if oldCfg.RemoteManagement.DisableControlPanel != newCfg.RemoteManagement.DisableControlPanel {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.disable-control-panel: %t -> %t", oldCfg.RemoteManagement.DisableControlPanel, newCfg.RemoteManagement.DisableControlPanel))
|
||||
}
|
||||
oldPanelRepo := strings.TrimSpace(oldCfg.RemoteManagement.PanelGitHubRepository)
|
||||
newPanelRepo := strings.TrimSpace(newCfg.RemoteManagement.PanelGitHubRepository)
|
||||
if oldPanelRepo != newPanelRepo {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.panel-github-repository: %s -> %s", oldPanelRepo, newPanelRepo))
|
||||
}
|
||||
if oldCfg.RemoteManagement.SecretKey != newCfg.RemoteManagement.SecretKey {
|
||||
switch {
|
||||
case oldCfg.RemoteManagement.SecretKey == "" && newCfg.RemoteManagement.SecretKey != "":
|
||||
changes = append(changes, "remote-management.secret-key: created")
|
||||
case oldCfg.RemoteManagement.SecretKey != "" && newCfg.RemoteManagement.SecretKey == "":
|
||||
changes = append(changes, "remote-management.secret-key: deleted")
|
||||
default:
|
||||
changes = append(changes, "remote-management.secret-key: updated")
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI compatibility providers (summarized)
|
||||
if compat := DiffOpenAICompatibility(oldCfg.OpenAICompatibility, newCfg.OpenAICompatibility); len(compat) > 0 {
|
||||
changes = append(changes, "openai-compatibility:")
|
||||
for _, c := range compat {
|
||||
changes = append(changes, " "+c)
|
||||
}
|
||||
}
|
||||
|
||||
// Vertex-compatible API keys
|
||||
if len(oldCfg.VertexCompatAPIKey) != len(newCfg.VertexCompatAPIKey) {
|
||||
changes = append(changes, fmt.Sprintf("vertex-api-key count: %d -> %d", len(oldCfg.VertexCompatAPIKey), len(newCfg.VertexCompatAPIKey)))
|
||||
} else {
|
||||
for i := range oldCfg.VertexCompatAPIKey {
|
||||
o := oldCfg.VertexCompatAPIKey[i]
|
||||
n := newCfg.VertexCompatAPIKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].api-key: updated", i))
|
||||
}
|
||||
oldModels := SummarizeVertexModels(o.Models)
|
||||
newModels := SummarizeVertexModels(n.Models)
|
||||
if oldModels.hash != newModels.hash {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].headers: updated", i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return changes
|
||||
}
|
||||
|
||||
func trimStrings(in []string) []string {
|
||||
out := make([]string, len(in))
|
||||
for i := range in {
|
||||
out[i] = strings.TrimSpace(in[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func equalStringMap(a, b map[string]string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for k, v := range a {
|
||||
if b[k] != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func formatProxyURL(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "<none>"
|
||||
}
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return "<redacted>"
|
||||
}
|
||||
host := strings.TrimSpace(parsed.Host)
|
||||
scheme := strings.TrimSpace(parsed.Scheme)
|
||||
if host == "" {
|
||||
// Allow host:port style without scheme.
|
||||
parsed2, err2 := url.Parse("http://" + trimmed)
|
||||
if err2 == nil {
|
||||
host = strings.TrimSpace(parsed2.Host)
|
||||
}
|
||||
scheme = ""
|
||||
}
|
||||
if host == "" {
|
||||
return "<redacted>"
|
||||
}
|
||||
if scheme == "" {
|
||||
return host
|
||||
}
|
||||
return scheme + "://" + host
|
||||
}
|
||||
529
internal/watcher/diff/config_diff_test.go
Normal file
529
internal/watcher/diff/config_diff_test.go
Normal file
@@ -0,0 +1,529 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestBuildConfigChangeDetails(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
Port: 8080,
|
||||
AuthDir: "/tmp/auth-old",
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "old", BaseURL: "http://old", ExcludedModels: []string{"old-model"}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://old-upstream",
|
||||
ModelMappings: []config.AmpModelMapping{{From: "from-old", To: "to-old"}},
|
||||
RestrictManagementToLocalhost: false,
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: false,
|
||||
SecretKey: "old",
|
||||
DisableControlPanel: false,
|
||||
PanelGitHubRepository: "repo-old",
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"providerA": {"m1"},
|
||||
},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "compat-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
newCfg := &config.Config{
|
||||
Port: 9090,
|
||||
AuthDir: "/tmp/auth-new",
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "old", BaseURL: "http://old", ExcludedModels: []string{"old-model", "extra"}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://new-upstream",
|
||||
RestrictManagementToLocalhost: true,
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "from-old", To: "to-old"},
|
||||
{From: "from-new", To: "to-new"},
|
||||
},
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: true,
|
||||
SecretKey: "new",
|
||||
DisableControlPanel: true,
|
||||
PanelGitHubRepository: "repo-new",
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"providerA": {"m1", "m2"},
|
||||
"providerB": {"x"},
|
||||
},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "compat-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}, {Name: "m2"}},
|
||||
},
|
||||
{
|
||||
Name: "compat-b",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
|
||||
expectContains(t, details, "port: 8080 -> 9090")
|
||||
expectContains(t, details, "auth-dir: /tmp/auth-old -> /tmp/auth-new")
|
||||
expectContains(t, details, "gemini[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "ampcode.upstream-url: http://old-upstream -> http://new-upstream")
|
||||
expectContains(t, details, "ampcode.model-mappings: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "remote-management.allow-remote: false -> true")
|
||||
expectContains(t, details, "remote-management.secret-key: updated")
|
||||
expectContains(t, details, "oauth-excluded-models[providera]: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "oauth-excluded-models[providerb]: added (1 entries)")
|
||||
expectContains(t, details, "openai-compatibility:")
|
||||
expectContains(t, details, " provider added: compat-b (api-keys=1, models=0)")
|
||||
expectContains(t, details, " provider updated: compat-a (models 1 -> 2)")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_NoChanges(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Port: 8080,
|
||||
}
|
||||
if details := BuildConfigChangeDetails(cfg, cfg); len(details) != 0 {
|
||||
t.Fatalf("expected no change entries, got %v", details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_GeminiVertexHeadersAndForceMappings(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", BaseURL: "http://v-old", Models: []config.VertexCompatModel{{Name: "m1"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}},
|
||||
ForceModelMappings: false,
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"a", "b"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", BaseURL: "http://v-new", Models: []config.VertexCompatModel{{Name: "m1"}, {Name: "m2"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "c"}},
|
||||
ForceModelMappings: true,
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, details, "gemini[0].headers: updated")
|
||||
expectContains(t, details, "gemini[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "ampcode.model-mappings: updated (1 -> 1 entries)")
|
||||
expectContains(t, details, "ampcode.force-model-mappings: false -> true")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_ModelPrefixes(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Prefix: "old-g", BaseURL: "http://g", ProxyURL: "http://gp"},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c1", Prefix: "old-c", BaseURL: "http://c", ProxyURL: "http://cp"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x1", Prefix: "old-x", BaseURL: "http://x", ProxyURL: "http://xp"},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", Prefix: "old-v", BaseURL: "http://v", ProxyURL: "http://vp"},
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Prefix: "new-g", BaseURL: "http://g", ProxyURL: "http://gp"},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c1", Prefix: "new-c", BaseURL: "http://c", ProxyURL: "http://cp"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x1", Prefix: "new-x", BaseURL: "http://x", ProxyURL: "http://xp"},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", Prefix: "new-v", BaseURL: "http://v", ProxyURL: "http://vp"},
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "gemini[0].prefix: old-g -> new-g")
|
||||
expectContains(t, changes, "claude[0].prefix: old-c -> new-c")
|
||||
expectContains(t, changes, "codex[0].prefix: old-x -> new-x")
|
||||
expectContains(t, changes, "vertex[0].prefix: old-v -> new-v")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_NilSafe(t *testing.T) {
|
||||
if details := BuildConfigChangeDetails(nil, &config.Config{}); len(details) != 0 {
|
||||
t.Fatalf("expected empty change list when old nil, got %v", details)
|
||||
}
|
||||
if details := BuildConfigChangeDetails(&config.Config{}, nil); len(details) != 0 {
|
||||
t.Fatalf("expected empty change list when new nil, got %v", details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_SecretsAndCounts(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
APIKeys: []string{"a"},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "",
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
APIKeys: []string{"a", "b", "c"},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "new-key",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "new-secret",
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, details, "api-keys count: 1 -> 3")
|
||||
expectContains(t, details, "ampcode.upstream-api-key: added")
|
||||
expectContains(t, details, "remote-management.secret-key: created")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
Port: 1000,
|
||||
AuthDir: "/old",
|
||||
Debug: false,
|
||||
LoggingToFile: false,
|
||||
UsageStatisticsEnabled: false,
|
||||
DisableCooling: false,
|
||||
RequestRetry: 1,
|
||||
MaxRetryInterval: 1,
|
||||
WebsocketAuth: false,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
|
||||
ClaudeKey: []config.ClaudeKey{{APIKey: "c1"}},
|
||||
CodexKey: []config.CodexKey{{APIKey: "x1"}},
|
||||
AmpCode: config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false},
|
||||
RemoteManagement: config.RemoteManagement{DisableControlPanel: false, PanelGitHubRepository: "old/repo", SecretKey: "keep"},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: false,
|
||||
ProxyURL: "http://old-proxy",
|
||||
APIKeys: []string{"key-1"},
|
||||
ForceModelPrefix: false,
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
Port: 2000,
|
||||
AuthDir: "/new",
|
||||
Debug: true,
|
||||
LoggingToFile: true,
|
||||
UsageStatisticsEnabled: true,
|
||||
DisableCooling: true,
|
||||
RequestRetry: 2,
|
||||
MaxRetryInterval: 3,
|
||||
WebsocketAuth: true,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c1", BaseURL: "http://new", ProxyURL: "http://p", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}},
|
||||
{APIKey: "c2"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x1", BaseURL: "http://x", ProxyURL: "http://px", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"b"}},
|
||||
{APIKey: "x2"},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "",
|
||||
RestrictManagementToLocalhost: true,
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}},
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
DisableControlPanel: true,
|
||||
PanelGitHubRepository: "new/repo",
|
||||
SecretKey: "",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: true,
|
||||
ProxyURL: "http://new-proxy",
|
||||
APIKeys: []string{" key-1 ", "key-2"},
|
||||
ForceModelPrefix: true,
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, details, "debug: false -> true")
|
||||
expectContains(t, details, "logging-to-file: false -> true")
|
||||
expectContains(t, details, "usage-statistics-enabled: false -> true")
|
||||
expectContains(t, details, "disable-cooling: false -> true")
|
||||
expectContains(t, details, "request-log: false -> true")
|
||||
expectContains(t, details, "request-retry: 1 -> 2")
|
||||
expectContains(t, details, "max-retry-interval: 1 -> 3")
|
||||
expectContains(t, details, "proxy-url: http://old-proxy -> http://new-proxy")
|
||||
expectContains(t, details, "ws-auth: false -> true")
|
||||
expectContains(t, details, "force-model-prefix: false -> true")
|
||||
expectContains(t, details, "quota-exceeded.switch-project: false -> true")
|
||||
expectContains(t, details, "quota-exceeded.switch-preview-model: false -> true")
|
||||
expectContains(t, details, "api-keys count: 1 -> 2")
|
||||
expectContains(t, details, "claude-api-key count: 1 -> 2")
|
||||
expectContains(t, details, "codex-api-key count: 1 -> 2")
|
||||
expectContains(t, details, "ampcode.restrict-management-to-localhost: false -> true")
|
||||
expectContains(t, details, "ampcode.upstream-api-key: removed")
|
||||
expectContains(t, details, "remote-management.disable-control-panel: false -> true")
|
||||
expectContains(t, details, "remote-management.panel-github-repository: old/repo -> new/repo")
|
||||
expectContains(t, details, "remote-management.secret-key: deleted")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
Port: 1,
|
||||
AuthDir: "/a",
|
||||
Debug: false,
|
||||
LoggingToFile: false,
|
||||
UsageStatisticsEnabled: false,
|
||||
DisableCooling: false,
|
||||
RequestRetry: 1,
|
||||
MaxRetryInterval: 1,
|
||||
WebsocketAuth: false,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g-old", BaseURL: "http://g-old", ProxyURL: "http://gp-old", Headers: map[string]string{"A": "1"}},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c-old", BaseURL: "http://c-old", ProxyURL: "http://cp-old", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"x"}},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x-old", BaseURL: "http://x-old", ProxyURL: "http://xp-old", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"x"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v-old", BaseURL: "http://v-old", ProxyURL: "http://vp-old", Headers: map[string]string{"H": "1"}, Models: []config.VertexCompatModel{{Name: "m1"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://amp-old",
|
||||
UpstreamAPIKey: "old-key",
|
||||
RestrictManagementToLocalhost: false,
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}},
|
||||
ForceModelMappings: false,
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: false,
|
||||
DisableControlPanel: false,
|
||||
PanelGitHubRepository: "old/repo",
|
||||
SecretKey: "old",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: false,
|
||||
ProxyURL: "http://old-proxy",
|
||||
APIKeys: []string{" keyA "},
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{"p1": {"a"}},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "prov-old",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
Port: 2,
|
||||
AuthDir: "/b",
|
||||
Debug: true,
|
||||
LoggingToFile: true,
|
||||
UsageStatisticsEnabled: true,
|
||||
DisableCooling: true,
|
||||
RequestRetry: 2,
|
||||
MaxRetryInterval: 3,
|
||||
WebsocketAuth: true,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g-new", BaseURL: "http://g-new", ProxyURL: "http://gp-new", Headers: map[string]string{"A": "2"}, ExcludedModels: []string{"x", "y"}},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c-new", BaseURL: "http://c-new", ProxyURL: "http://cp-new", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"x", "y"}},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x-new", BaseURL: "http://x-new", ProxyURL: "http://xp-new", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"x", "y"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v-new", BaseURL: "http://v-new", ProxyURL: "http://vp-new", Headers: map[string]string{"H": "2"}, Models: []config.VertexCompatModel{{Name: "m1"}, {Name: "m2"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://amp-new",
|
||||
UpstreamAPIKey: "",
|
||||
RestrictManagementToLocalhost: true,
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "c"}},
|
||||
ForceModelMappings: true,
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: true,
|
||||
DisableControlPanel: true,
|
||||
PanelGitHubRepository: "new/repo",
|
||||
SecretKey: "",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: true,
|
||||
ProxyURL: "http://new-proxy",
|
||||
APIKeys: []string{"keyB"},
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{"p1": {"b", "c"}, "p2": {"d"}},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "prov-old",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
{APIKey: "k2"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}, {Name: "m2"}},
|
||||
},
|
||||
{
|
||||
Name: "prov-new",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k3"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "port: 1 -> 2")
|
||||
expectContains(t, changes, "auth-dir: /a -> /b")
|
||||
expectContains(t, changes, "debug: false -> true")
|
||||
expectContains(t, changes, "logging-to-file: false -> true")
|
||||
expectContains(t, changes, "usage-statistics-enabled: false -> true")
|
||||
expectContains(t, changes, "disable-cooling: false -> true")
|
||||
expectContains(t, changes, "request-retry: 1 -> 2")
|
||||
expectContains(t, changes, "max-retry-interval: 1 -> 3")
|
||||
expectContains(t, changes, "proxy-url: http://old-proxy -> http://new-proxy")
|
||||
expectContains(t, changes, "ws-auth: false -> true")
|
||||
expectContains(t, changes, "quota-exceeded.switch-project: false -> true")
|
||||
expectContains(t, changes, "quota-exceeded.switch-preview-model: false -> true")
|
||||
expectContains(t, changes, "api-keys: values updated (count unchanged, redacted)")
|
||||
expectContains(t, changes, "gemini[0].base-url: http://g-old -> http://g-new")
|
||||
expectContains(t, changes, "gemini[0].proxy-url: http://gp-old -> http://gp-new")
|
||||
expectContains(t, changes, "gemini[0].api-key: updated")
|
||||
expectContains(t, changes, "gemini[0].headers: updated")
|
||||
expectContains(t, changes, "gemini[0].excluded-models: updated (0 -> 2 entries)")
|
||||
expectContains(t, changes, "claude[0].base-url: http://c-old -> http://c-new")
|
||||
expectContains(t, changes, "claude[0].proxy-url: http://cp-old -> http://cp-new")
|
||||
expectContains(t, changes, "claude[0].api-key: updated")
|
||||
expectContains(t, changes, "claude[0].headers: updated")
|
||||
expectContains(t, changes, "claude[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "codex[0].base-url: http://x-old -> http://x-new")
|
||||
expectContains(t, changes, "codex[0].proxy-url: http://xp-old -> http://xp-new")
|
||||
expectContains(t, changes, "codex[0].api-key: updated")
|
||||
expectContains(t, changes, "codex[0].headers: updated")
|
||||
expectContains(t, changes, "codex[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "vertex[0].base-url: http://v-old -> http://v-new")
|
||||
expectContains(t, changes, "vertex[0].proxy-url: http://vp-old -> http://vp-new")
|
||||
expectContains(t, changes, "vertex[0].api-key: updated")
|
||||
expectContains(t, changes, "vertex[0].models: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "vertex[0].headers: updated")
|
||||
expectContains(t, changes, "ampcode.upstream-url: http://amp-old -> http://amp-new")
|
||||
expectContains(t, changes, "ampcode.upstream-api-key: removed")
|
||||
expectContains(t, changes, "ampcode.restrict-management-to-localhost: false -> true")
|
||||
expectContains(t, changes, "ampcode.model-mappings: updated (1 -> 1 entries)")
|
||||
expectContains(t, changes, "ampcode.force-model-mappings: false -> true")
|
||||
expectContains(t, changes, "oauth-excluded-models[p1]: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "oauth-excluded-models[p2]: added (1 entries)")
|
||||
expectContains(t, changes, "remote-management.allow-remote: false -> true")
|
||||
expectContains(t, changes, "remote-management.disable-control-panel: false -> true")
|
||||
expectContains(t, changes, "remote-management.panel-github-repository: old/repo -> new/repo")
|
||||
expectContains(t, changes, "remote-management.secret-key: deleted")
|
||||
expectContains(t, changes, "openai-compatibility:")
|
||||
}
|
||||
|
||||
func TestFormatProxyURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "empty", in: "", want: "<none>"},
|
||||
{name: "invalid", in: "http://[::1", want: "<redacted>"},
|
||||
{name: "fullURLRedactsUserinfoAndPath", in: "http://user:pass@example.com:8080/path?x=1#frag", want: "http://example.com:8080"},
|
||||
{name: "socks5RedactsUserinfoAndPath", in: "socks5://user:pass@192.168.1.1:1080/path?x=1", want: "socks5://192.168.1.1:1080"},
|
||||
{name: "socks5HostPort", in: "socks5://proxy.example.com:1080/", want: "socks5://proxy.example.com:1080"},
|
||||
{name: "hostPortNoScheme", in: "example.com:1234/path?x=1", want: "example.com:1234"},
|
||||
{name: "relativePathRedacted", in: "/just/path", want: "<redacted>"},
|
||||
{name: "schemeAndHost", in: "https://example.com", want: "https://example.com"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := formatProxyURL(tt.in); got != tt.want {
|
||||
t.Fatalf("expected %q, got %q", tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_SecretAndUpstreamUpdates(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "old",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "old",
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "new",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "new",
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "ampcode.upstream-api-key: updated")
|
||||
expectContains(t, changes, "remote-management.secret-key: updated")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_CountBranches(t *testing.T) {
|
||||
oldCfg := &config.Config{}
|
||||
newCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{{APIKey: "g"}},
|
||||
ClaudeKey: []config.ClaudeKey{{APIKey: "c"}},
|
||||
CodexKey: []config.CodexKey{{APIKey: "x"}},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v", BaseURL: "http://v"},
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "gemini-api-key count: 0 -> 1")
|
||||
expectContains(t, changes, "claude-api-key count: 0 -> 1")
|
||||
expectContains(t, changes, "codex-api-key count: 0 -> 1")
|
||||
expectContains(t, changes, "vertex-api-key count: 0 -> 1")
|
||||
}
|
||||
|
||||
func TestTrimStrings(t *testing.T) {
|
||||
out := trimStrings([]string{" a ", "b", " c"})
|
||||
if len(out) != 3 || out[0] != "a" || out[1] != "b" || out[2] != "c" {
|
||||
t.Fatalf("unexpected trimmed strings: %v", out)
|
||||
}
|
||||
}
|
||||
102
internal/watcher/diff/model_hash.go
Normal file
102
internal/watcher/diff/model_hash.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// ComputeOpenAICompatModelsHash returns a stable hash for OpenAI-compat models.
|
||||
// Used to detect model list changes during hot reload.
|
||||
func ComputeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeVertexCompatModelsHash returns a stable hash for Vertex-compatible models.
|
||||
func ComputeVertexCompatModelsHash(models []config.VertexCompatModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeClaudeModelsHash returns a stable hash for Claude model aliases.
|
||||
func ComputeClaudeModelsHash(models []config.ClaudeModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeExcludedModelsHash returns a normalized hash for excluded model lists.
|
||||
func ComputeExcludedModelsHash(excluded []string) string {
|
||||
if len(excluded) == 0 {
|
||||
return ""
|
||||
}
|
||||
normalized := make([]string, 0, len(excluded))
|
||||
for _, entry := range excluded {
|
||||
if trimmed := strings.TrimSpace(entry); trimmed != "" {
|
||||
normalized = append(normalized, strings.ToLower(trimmed))
|
||||
}
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return ""
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
data, _ := json.Marshal(normalized)
|
||||
sum := sha256.Sum256(data)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func normalizeModelPairs(collect func(out func(key string))) []string {
|
||||
seen := make(map[string]struct{})
|
||||
keys := make([]string, 0)
|
||||
collect(func(key string) {
|
||||
if _, exists := seen[key]; exists {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
keys = append(keys, key)
|
||||
})
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func hashJoined(keys []string) string {
|
||||
if len(keys) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(strings.Join(keys, "\n")))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
159
internal/watcher/diff/model_hash_test.go
Normal file
159
internal/watcher/diff/model_hash_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestComputeOpenAICompatModelsHash_Deterministic(t *testing.T) {
|
||||
models := []config.OpenAICompatibilityModel{
|
||||
{Name: "gpt-4", Alias: "gpt4"},
|
||||
{Name: "gpt-3.5-turbo"},
|
||||
}
|
||||
hash1 := ComputeOpenAICompatModelsHash(models)
|
||||
hash2 := ComputeOpenAICompatModelsHash(models)
|
||||
if hash1 == "" {
|
||||
t.Fatal("hash should not be empty")
|
||||
}
|
||||
if hash1 != hash2 {
|
||||
t.Fatalf("hash should be deterministic, got %s vs %s", hash1, hash2)
|
||||
}
|
||||
changed := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{{Name: "gpt-4"}, {Name: "gpt-4.1"}})
|
||||
if hash1 == changed {
|
||||
t.Fatal("hash should change when model list changes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeOpenAICompatModelsHash_NormalizesAndDedups(t *testing.T) {
|
||||
a := []config.OpenAICompatibilityModel{
|
||||
{Name: "gpt-4", Alias: "gpt4"},
|
||||
{Name: " "},
|
||||
{Name: "GPT-4", Alias: "GPT4"},
|
||||
{Alias: "a1"},
|
||||
}
|
||||
b := []config.OpenAICompatibilityModel{
|
||||
{Alias: "A1"},
|
||||
{Name: "gpt-4", Alias: "gpt4"},
|
||||
}
|
||||
h1 := ComputeOpenAICompatModelsHash(a)
|
||||
h2 := ComputeOpenAICompatModelsHash(b)
|
||||
if h1 == "" || h2 == "" {
|
||||
t.Fatal("expected non-empty hashes for non-empty model sets")
|
||||
}
|
||||
if h1 != h2 {
|
||||
t.Fatalf("expected normalized hashes to match, got %s / %s", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeVertexCompatModelsHash_DifferentInputs(t *testing.T) {
|
||||
models := []config.VertexCompatModel{{Name: "gemini-pro", Alias: "pro"}}
|
||||
hash1 := ComputeVertexCompatModelsHash(models)
|
||||
hash2 := ComputeVertexCompatModelsHash([]config.VertexCompatModel{{Name: "gemini-1.5-pro", Alias: "pro"}})
|
||||
if hash1 == "" || hash2 == "" {
|
||||
t.Fatal("hashes should not be empty for non-empty models")
|
||||
}
|
||||
if hash1 == hash2 {
|
||||
t.Fatal("hash should differ when model content differs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeVertexCompatModelsHash_IgnoresBlankAndOrder(t *testing.T) {
|
||||
a := []config.VertexCompatModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
{Name: " "},
|
||||
{Name: "M1", Alias: "A1"},
|
||||
}
|
||||
b := []config.VertexCompatModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
}
|
||||
if h1, h2 := ComputeVertexCompatModelsHash(a), ComputeVertexCompatModelsHash(b); h1 == "" || h1 != h2 {
|
||||
t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeClaudeModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil models, got %q", got)
|
||||
}
|
||||
if got := ComputeClaudeModelsHash([]config.ClaudeModel{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
|
||||
a := []config.ClaudeModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
{Name: " "},
|
||||
{Name: "M1", Alias: "A1"},
|
||||
}
|
||||
b := []config.ClaudeModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
}
|
||||
if h1, h2 := ComputeClaudeModelsHash(a), ComputeClaudeModelsHash(b); h1 == "" || h1 != h2 {
|
||||
t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeExcludedModelsHash_Normalizes(t *testing.T) {
|
||||
hash1 := ComputeExcludedModelsHash([]string{" A ", "b", "a"})
|
||||
hash2 := ComputeExcludedModelsHash([]string{"a", " b", "A"})
|
||||
if hash1 == "" || hash2 == "" {
|
||||
t.Fatal("hash should not be empty for non-empty input")
|
||||
}
|
||||
if hash1 != hash2 {
|
||||
t.Fatalf("hash should be order/space insensitive for same multiset, got %s vs %s", hash1, hash2)
|
||||
}
|
||||
hash3 := ComputeExcludedModelsHash([]string{"c"})
|
||||
if hash1 == hash3 {
|
||||
t.Fatal("hash should differ for different normalized sets")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeOpenAICompatModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeOpenAICompatModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil input, got %q", got)
|
||||
}
|
||||
if got := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
if got := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{{Name: " "}, {Alias: ""}}); got != "" {
|
||||
t.Fatalf("expected empty hash for blank models, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeVertexCompatModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeVertexCompatModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil input, got %q", got)
|
||||
}
|
||||
if got := ComputeVertexCompatModelsHash([]config.VertexCompatModel{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
if got := ComputeVertexCompatModelsHash([]config.VertexCompatModel{{Name: " "}}); got != "" {
|
||||
t.Fatalf("expected empty hash for blank models, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeExcludedModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeExcludedModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil input, got %q", got)
|
||||
}
|
||||
if got := ComputeExcludedModelsHash([]string{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
if got := ComputeExcludedModelsHash([]string{" ", ""}); got != "" {
|
||||
t.Fatalf("expected empty hash for whitespace-only entries, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeModelsHash_Deterministic(t *testing.T) {
|
||||
models := []config.ClaudeModel{{Name: "a", Alias: "A"}, {Name: "b"}}
|
||||
h1 := ComputeClaudeModelsHash(models)
|
||||
h2 := ComputeClaudeModelsHash(models)
|
||||
if h1 == "" || h1 != h2 {
|
||||
t.Fatalf("expected deterministic hash, got %s / %s", h1, h2)
|
||||
}
|
||||
if h3 := ComputeClaudeModelsHash([]config.ClaudeModel{{Name: "a"}}); h3 == h1 {
|
||||
t.Fatalf("expected different hash when models change, got %s", h3)
|
||||
}
|
||||
}
|
||||
151
internal/watcher/diff/oauth_excluded.go
Normal file
151
internal/watcher/diff/oauth_excluded.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
type ExcludedModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeExcludedModels normalizes and hashes an excluded-model list.
|
||||
func SummarizeExcludedModels(list []string) ExcludedModelsSummary {
|
||||
if len(list) == 0 {
|
||||
return ExcludedModelsSummary{}
|
||||
}
|
||||
seen := make(map[string]struct{}, len(list))
|
||||
normalized := make([]string, 0, len(list))
|
||||
for _, entry := range list {
|
||||
if trimmed := strings.ToLower(strings.TrimSpace(entry)); trimmed != "" {
|
||||
if _, exists := seen[trimmed]; exists {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
return ExcludedModelsSummary{
|
||||
hash: ComputeExcludedModelsHash(normalized),
|
||||
count: len(normalized),
|
||||
}
|
||||
}
|
||||
|
||||
// SummarizeOAuthExcludedModels summarizes OAuth excluded models per provider.
|
||||
func SummarizeOAuthExcludedModels(entries map[string][]string) map[string]ExcludedModelsSummary {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]ExcludedModelsSummary, len(entries))
|
||||
for k, v := range entries {
|
||||
key := strings.ToLower(strings.TrimSpace(k))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = SummarizeExcludedModels(v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// DiffOAuthExcludedModelChanges compares OAuth excluded models maps.
|
||||
func DiffOAuthExcludedModelChanges(oldMap, newMap map[string][]string) ([]string, []string) {
|
||||
oldSummary := SummarizeOAuthExcludedModels(oldMap)
|
||||
newSummary := SummarizeOAuthExcludedModels(newMap)
|
||||
keys := make(map[string]struct{}, len(oldSummary)+len(newSummary))
|
||||
for k := range oldSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
for k := range newSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
changes := make([]string, 0, len(keys))
|
||||
affected := make([]string, 0, len(keys))
|
||||
for key := range keys {
|
||||
oldInfo, okOld := oldSummary[key]
|
||||
newInfo, okNew := newSummary[key]
|
||||
switch {
|
||||
case okOld && !okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: removed", key))
|
||||
affected = append(affected, key)
|
||||
case !okOld && okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: added (%d entries)", key, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
case okOld && okNew && oldInfo.hash != newInfo.hash:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(changes)
|
||||
sort.Strings(affected)
|
||||
return changes, affected
|
||||
}
|
||||
|
||||
type AmpModelMappingsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeAmpModelMappings hashes Amp model mappings for change detection.
|
||||
func SummarizeAmpModelMappings(mappings []config.AmpModelMapping) AmpModelMappingsSummary {
|
||||
if len(mappings) == 0 {
|
||||
return AmpModelMappingsSummary{}
|
||||
}
|
||||
entries := make([]string, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
to := strings.TrimSpace(mapping.To)
|
||||
if from == "" && to == "" {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, from+"->"+to)
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return AmpModelMappingsSummary{}
|
||||
}
|
||||
sort.Strings(entries)
|
||||
sum := sha256.Sum256([]byte(strings.Join(entries, "|")))
|
||||
return AmpModelMappingsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(entries),
|
||||
}
|
||||
}
|
||||
|
||||
type VertexModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeVertexModels hashes vertex-compatible models for change detection.
|
||||
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
|
||||
if len(models) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
names := make([]string, 0, len(models))
|
||||
for _, m := range models {
|
||||
name := strings.TrimSpace(m.Name)
|
||||
alias := strings.TrimSpace(m.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
if alias != "" {
|
||||
name = alias
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
if len(names) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
sort.Strings(names)
|
||||
sum := sha256.Sum256([]byte(strings.Join(names, "|")))
|
||||
return VertexModelsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(names),
|
||||
}
|
||||
}
|
||||
109
internal/watcher/diff/oauth_excluded_test.go
Normal file
109
internal/watcher/diff/oauth_excluded_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestSummarizeExcludedModels_NormalizesAndDedupes(t *testing.T) {
|
||||
summary := SummarizeExcludedModels([]string{"A", " a ", "B", "b"})
|
||||
if summary.count != 2 {
|
||||
t.Fatalf("expected 2 unique entries, got %d", summary.count)
|
||||
}
|
||||
if summary.hash == "" {
|
||||
t.Fatal("expected non-empty hash")
|
||||
}
|
||||
if empty := SummarizeExcludedModels(nil); empty.count != 0 || empty.hash != "" {
|
||||
t.Fatalf("expected empty summary for nil input, got %+v", empty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiffOAuthExcludedModelChanges(t *testing.T) {
|
||||
oldMap := map[string][]string{
|
||||
"ProviderA": {"model-1", "model-2"},
|
||||
"providerB": {"x"},
|
||||
}
|
||||
newMap := map[string][]string{
|
||||
"providerA": {"model-1", "model-3"},
|
||||
"providerC": {"y"},
|
||||
}
|
||||
|
||||
changes, affected := DiffOAuthExcludedModelChanges(oldMap, newMap)
|
||||
expectContains(t, changes, "oauth-excluded-models[providera]: updated (2 -> 2 entries)")
|
||||
expectContains(t, changes, "oauth-excluded-models[providerb]: removed")
|
||||
expectContains(t, changes, "oauth-excluded-models[providerc]: added (1 entries)")
|
||||
|
||||
if len(affected) != 3 {
|
||||
t.Fatalf("expected 3 affected providers, got %d", len(affected))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeAmpModelMappings(t *testing.T) {
|
||||
summary := SummarizeAmpModelMappings([]config.AmpModelMapping{
|
||||
{From: "a", To: "A"},
|
||||
{From: "b", To: "B"},
|
||||
{From: " ", To: " "}, // ignored
|
||||
})
|
||||
if summary.count != 2 {
|
||||
t.Fatalf("expected 2 entries, got %d", summary.count)
|
||||
}
|
||||
if summary.hash == "" {
|
||||
t.Fatal("expected non-empty hash")
|
||||
}
|
||||
if empty := SummarizeAmpModelMappings(nil); empty.count != 0 || empty.hash != "" {
|
||||
t.Fatalf("expected empty summary for nil input, got %+v", empty)
|
||||
}
|
||||
if blank := SummarizeAmpModelMappings([]config.AmpModelMapping{{From: " ", To: " "}}); blank.count != 0 || blank.hash != "" {
|
||||
t.Fatalf("expected blank mappings ignored, got %+v", blank)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeOAuthExcludedModels_NormalizesKeys(t *testing.T) {
|
||||
out := SummarizeOAuthExcludedModels(map[string][]string{
|
||||
"ProvA": {"X"},
|
||||
"": {"ignored"},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected only non-empty key summary, got %d", len(out))
|
||||
}
|
||||
if _, ok := out["prova"]; !ok {
|
||||
t.Fatalf("expected normalized key 'prova', got keys %v", out)
|
||||
}
|
||||
if out["prova"].count != 1 || out["prova"].hash == "" {
|
||||
t.Fatalf("unexpected summary %+v", out["prova"])
|
||||
}
|
||||
if outEmpty := SummarizeOAuthExcludedModels(nil); outEmpty != nil {
|
||||
t.Fatalf("expected nil map for nil input, got %v", outEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeVertexModels(t *testing.T) {
|
||||
summary := SummarizeVertexModels([]config.VertexCompatModel{
|
||||
{Name: "m1"},
|
||||
{Name: " ", Alias: "alias"},
|
||||
{}, // ignored
|
||||
})
|
||||
if summary.count != 2 {
|
||||
t.Fatalf("expected 2 vertex models, got %d", summary.count)
|
||||
}
|
||||
if summary.hash == "" {
|
||||
t.Fatal("expected non-empty hash")
|
||||
}
|
||||
if empty := SummarizeVertexModels(nil); empty.count != 0 || empty.hash != "" {
|
||||
t.Fatalf("expected empty summary for nil input, got %+v", empty)
|
||||
}
|
||||
if blank := SummarizeVertexModels([]config.VertexCompatModel{{Name: " "}}); blank.count != 0 || blank.hash != "" {
|
||||
t.Fatalf("expected blank model ignored, got %+v", blank)
|
||||
}
|
||||
}
|
||||
|
||||
func expectContains(t *testing.T, list []string, target string) {
|
||||
t.Helper()
|
||||
for _, entry := range list {
|
||||
if entry == target {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("expected list to contain %q, got %#v", target, list)
|
||||
}
|
||||
183
internal/watcher/diff/openai_compat.go
Normal file
183
internal/watcher/diff/openai_compat.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// DiffOpenAICompatibility produces human-readable change descriptions.
|
||||
func DiffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string {
|
||||
changes := make([]string, 0)
|
||||
oldMap := make(map[string]config.OpenAICompatibility, len(oldList))
|
||||
oldLabels := make(map[string]string, len(oldList))
|
||||
for idx, entry := range oldList {
|
||||
key, label := openAICompatKey(entry, idx)
|
||||
oldMap[key] = entry
|
||||
oldLabels[key] = label
|
||||
}
|
||||
newMap := make(map[string]config.OpenAICompatibility, len(newList))
|
||||
newLabels := make(map[string]string, len(newList))
|
||||
for idx, entry := range newList {
|
||||
key, label := openAICompatKey(entry, idx)
|
||||
newMap[key] = entry
|
||||
newLabels[key] = label
|
||||
}
|
||||
keySet := make(map[string]struct{}, len(oldMap)+len(newMap))
|
||||
for key := range oldMap {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
for key := range newMap {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
orderedKeys := make([]string, 0, len(keySet))
|
||||
for key := range keySet {
|
||||
orderedKeys = append(orderedKeys, key)
|
||||
}
|
||||
sort.Strings(orderedKeys)
|
||||
for _, key := range orderedKeys {
|
||||
oldEntry, oldOk := oldMap[key]
|
||||
newEntry, newOk := newMap[key]
|
||||
label := oldLabels[key]
|
||||
if label == "" {
|
||||
label = newLabels[key]
|
||||
}
|
||||
switch {
|
||||
case !oldOk:
|
||||
changes = append(changes, fmt.Sprintf("provider added: %s (api-keys=%d, models=%d)", label, countAPIKeys(newEntry), countOpenAIModels(newEntry.Models)))
|
||||
case !newOk:
|
||||
changes = append(changes, fmt.Sprintf("provider removed: %s (api-keys=%d, models=%d)", label, countAPIKeys(oldEntry), countOpenAIModels(oldEntry.Models)))
|
||||
default:
|
||||
if detail := describeOpenAICompatibilityUpdate(oldEntry, newEntry); detail != "" {
|
||||
changes = append(changes, fmt.Sprintf("provider updated: %s %s", label, detail))
|
||||
}
|
||||
}
|
||||
}
|
||||
return changes
|
||||
}
|
||||
|
||||
func describeOpenAICompatibilityUpdate(oldEntry, newEntry config.OpenAICompatibility) string {
|
||||
oldKeyCount := countAPIKeys(oldEntry)
|
||||
newKeyCount := countAPIKeys(newEntry)
|
||||
oldModelCount := countOpenAIModels(oldEntry.Models)
|
||||
newModelCount := countOpenAIModels(newEntry.Models)
|
||||
details := make([]string, 0, 3)
|
||||
if oldKeyCount != newKeyCount {
|
||||
details = append(details, fmt.Sprintf("api-keys %d -> %d", oldKeyCount, newKeyCount))
|
||||
}
|
||||
if oldModelCount != newModelCount {
|
||||
details = append(details, fmt.Sprintf("models %d -> %d", oldModelCount, newModelCount))
|
||||
}
|
||||
if !equalStringMap(oldEntry.Headers, newEntry.Headers) {
|
||||
details = append(details, "headers updated")
|
||||
}
|
||||
if len(details) == 0 {
|
||||
return ""
|
||||
}
|
||||
return "(" + strings.Join(details, ", ") + ")"
|
||||
}
|
||||
|
||||
func countAPIKeys(entry config.OpenAICompatibility) int {
|
||||
count := 0
|
||||
for _, keyEntry := range entry.APIKeyEntries {
|
||||
if strings.TrimSpace(keyEntry.APIKey) != "" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func countOpenAIModels(models []config.OpenAICompatibilityModel) int {
|
||||
count := 0
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func openAICompatKey(entry config.OpenAICompatibility, index int) (string, string) {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
if name != "" {
|
||||
return "name:" + name, name
|
||||
}
|
||||
base := strings.TrimSpace(entry.BaseURL)
|
||||
if base != "" {
|
||||
return "base:" + base, base
|
||||
}
|
||||
for _, model := range entry.Models {
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if alias == "" {
|
||||
alias = strings.TrimSpace(model.Name)
|
||||
}
|
||||
if alias != "" {
|
||||
return "alias:" + alias, alias
|
||||
}
|
||||
}
|
||||
sig := openAICompatSignature(entry)
|
||||
if sig == "" {
|
||||
return fmt.Sprintf("index:%d", index), fmt.Sprintf("entry-%d", index+1)
|
||||
}
|
||||
short := sig
|
||||
if len(short) > 8 {
|
||||
short = short[:8]
|
||||
}
|
||||
return "sig:" + sig, "compat-" + short
|
||||
}
|
||||
|
||||
func openAICompatSignature(entry config.OpenAICompatibility) string {
|
||||
var parts []string
|
||||
|
||||
if v := strings.TrimSpace(entry.Name); v != "" {
|
||||
parts = append(parts, "name="+strings.ToLower(v))
|
||||
}
|
||||
if v := strings.TrimSpace(entry.BaseURL); v != "" {
|
||||
parts = append(parts, "base="+v)
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(entry.Models))
|
||||
for _, model := range entry.Models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
models = append(models, strings.ToLower(name)+"|"+strings.ToLower(alias))
|
||||
}
|
||||
if len(models) > 0 {
|
||||
sort.Strings(models)
|
||||
parts = append(parts, "models="+strings.Join(models, ","))
|
||||
}
|
||||
|
||||
if len(entry.Headers) > 0 {
|
||||
keys := make([]string, 0, len(entry.Headers))
|
||||
for k := range entry.Headers {
|
||||
if trimmed := strings.TrimSpace(k); trimmed != "" {
|
||||
keys = append(keys, strings.ToLower(trimmed))
|
||||
}
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
sort.Strings(keys)
|
||||
parts = append(parts, "headers="+strings.Join(keys, ","))
|
||||
}
|
||||
}
|
||||
|
||||
// Intentionally exclude API key material; only count non-empty entries.
|
||||
if count := countAPIKeys(entry); count > 0 {
|
||||
parts = append(parts, fmt.Sprintf("api_keys=%d", count))
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(strings.Join(parts, "|")))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
187
internal/watcher/diff/openai_compat_test.go
Normal file
187
internal/watcher/diff/openai_compat_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestDiffOpenAICompatibility(t *testing.T) {
|
||||
oldList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "key-a"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
newList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "key-a"},
|
||||
{APIKey: "key-b"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
{Name: "m2"},
|
||||
},
|
||||
Headers: map[string]string{"X-Test": "1"},
|
||||
},
|
||||
{
|
||||
Name: "provider-b",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-b"}},
|
||||
},
|
||||
}
|
||||
|
||||
changes := DiffOpenAICompatibility(oldList, newList)
|
||||
expectContains(t, changes, "provider added: provider-b (api-keys=1, models=0)")
|
||||
expectContains(t, changes, "provider updated: provider-a (api-keys 1 -> 2, models 1 -> 2, headers updated)")
|
||||
}
|
||||
|
||||
func TestDiffOpenAICompatibility_RemovedAndUnchanged(t *testing.T) {
|
||||
oldList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-a"}},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
}
|
||||
newList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-a"}},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
}
|
||||
if changes := DiffOpenAICompatibility(oldList, newList); len(changes) != 0 {
|
||||
t.Fatalf("expected no changes, got %v", changes)
|
||||
}
|
||||
|
||||
newList = nil
|
||||
changes := DiffOpenAICompatibility(oldList, newList)
|
||||
expectContains(t, changes, "provider removed: provider-a (api-keys=1, models=1)")
|
||||
}
|
||||
|
||||
func TestOpenAICompatKeyFallbacks(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{
|
||||
BaseURL: "http://base",
|
||||
Models: []config.OpenAICompatibilityModel{{Alias: "alias-only"}},
|
||||
}
|
||||
key, label := openAICompatKey(entry, 0)
|
||||
if key != "base:http://base" || label != "http://base" {
|
||||
t.Fatalf("expected base key, got %s/%s", key, label)
|
||||
}
|
||||
|
||||
entry.BaseURL = ""
|
||||
key, label = openAICompatKey(entry, 1)
|
||||
if key != "alias:alias-only" || label != "alias-only" {
|
||||
t.Fatalf("expected alias fallback, got %s/%s", key, label)
|
||||
}
|
||||
|
||||
entry.Models = nil
|
||||
key, label = openAICompatKey(entry, 2)
|
||||
if key != "index:2" || label != "entry-3" {
|
||||
t.Fatalf("expected index fallback, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatKey_UsesName(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{Name: "My-Provider"}
|
||||
key, label := openAICompatKey(entry, 0)
|
||||
if key != "name:My-Provider" || label != "My-Provider" {
|
||||
t.Fatalf("expected name key, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatKey_SignatureFallbackWhenOnlyAPIKeys(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k1"}, {APIKey: "k2"}},
|
||||
}
|
||||
key, label := openAICompatKey(entry, 0)
|
||||
if !strings.HasPrefix(key, "sig:") || !strings.HasPrefix(label, "compat-") {
|
||||
t.Fatalf("expected signature key, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatSignature_EmptyReturnsEmpty(t *testing.T) {
|
||||
if got := openAICompatSignature(config.OpenAICompatibility{}); got != "" {
|
||||
t.Fatalf("expected empty signature, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatSignature_StableAndNormalized(t *testing.T) {
|
||||
a := config.OpenAICompatibility{
|
||||
Name: " Provider ",
|
||||
BaseURL: "http://base",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
{Name: " "},
|
||||
{Alias: "A1"},
|
||||
},
|
||||
Headers: map[string]string{
|
||||
"X-Test": "1",
|
||||
" ": "ignored",
|
||||
},
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
{APIKey: " "},
|
||||
},
|
||||
}
|
||||
b := config.OpenAICompatibility{
|
||||
Name: "provider",
|
||||
BaseURL: "http://base",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Alias: "a1"},
|
||||
{Name: "m1"},
|
||||
},
|
||||
Headers: map[string]string{
|
||||
"x-test": "2",
|
||||
},
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k2"},
|
||||
},
|
||||
}
|
||||
|
||||
sigA := openAICompatSignature(a)
|
||||
sigB := openAICompatSignature(b)
|
||||
if sigA == "" || sigB == "" {
|
||||
t.Fatalf("expected non-empty signatures, got %q / %q", sigA, sigB)
|
||||
}
|
||||
if sigA != sigB {
|
||||
t.Fatalf("expected normalized signatures to match, got %s / %s", sigA, sigB)
|
||||
}
|
||||
|
||||
c := b
|
||||
c.Models = append(c.Models, config.OpenAICompatibilityModel{Name: "m2"})
|
||||
if sigC := openAICompatSignature(c); sigC == sigB {
|
||||
t.Fatalf("expected signature to change when models change, got %s", sigC)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountOpenAIModelsSkipsBlanks(t *testing.T) {
|
||||
models := []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
{Name: ""},
|
||||
{Alias: ""},
|
||||
{Name: " "},
|
||||
{Alias: "a1"},
|
||||
}
|
||||
if got := countOpenAIModels(models); got != 2 {
|
||||
t.Fatalf("expected 2 counted models, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatKeyUsesModelNameWhenAliasEmpty(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "model-name"}},
|
||||
}
|
||||
key, label := openAICompatKey(entry, 5)
|
||||
if key != "alias:model-name" || label != "model-name" {
|
||||
t.Fatalf("expected model-name fallback, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
273
internal/watcher/dispatcher.go
Normal file
273
internal/watcher/dispatcher.go
Normal file
@@ -0,0 +1,273 @@
|
||||
// dispatcher.go implements auth update dispatching and queue management.
|
||||
// It batches, deduplicates, and delivers auth updates to registered consumers.
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func (w *Watcher) setAuthUpdateQueue(queue chan<- AuthUpdate) {
|
||||
w.clientsMutex.Lock()
|
||||
defer w.clientsMutex.Unlock()
|
||||
w.authQueue = queue
|
||||
if w.dispatchCond == nil {
|
||||
w.dispatchCond = sync.NewCond(&w.dispatchMu)
|
||||
}
|
||||
if w.dispatchCancel != nil {
|
||||
w.dispatchCancel()
|
||||
if w.dispatchCond != nil {
|
||||
w.dispatchMu.Lock()
|
||||
w.dispatchCond.Broadcast()
|
||||
w.dispatchMu.Unlock()
|
||||
}
|
||||
w.dispatchCancel = nil
|
||||
}
|
||||
if queue != nil {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w.dispatchCancel = cancel
|
||||
go w.dispatchLoop(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) dispatchRuntimeAuthUpdate(update AuthUpdate) bool {
|
||||
if w == nil {
|
||||
return false
|
||||
}
|
||||
w.clientsMutex.Lock()
|
||||
if w.runtimeAuths == nil {
|
||||
w.runtimeAuths = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
switch update.Action {
|
||||
case AuthUpdateActionAdd, AuthUpdateActionModify:
|
||||
if update.Auth != nil && update.Auth.ID != "" {
|
||||
clone := update.Auth.Clone()
|
||||
w.runtimeAuths[clone.ID] = clone
|
||||
if w.currentAuths == nil {
|
||||
w.currentAuths = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
w.currentAuths[clone.ID] = clone.Clone()
|
||||
}
|
||||
case AuthUpdateActionDelete:
|
||||
id := update.ID
|
||||
if id == "" && update.Auth != nil {
|
||||
id = update.Auth.ID
|
||||
}
|
||||
if id != "" {
|
||||
delete(w.runtimeAuths, id)
|
||||
if w.currentAuths != nil {
|
||||
delete(w.currentAuths, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
w.clientsMutex.Unlock()
|
||||
if w.getAuthQueue() == nil {
|
||||
return false
|
||||
}
|
||||
w.dispatchAuthUpdates([]AuthUpdate{update})
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *Watcher) refreshAuthState(force bool) {
|
||||
auths := w.SnapshotCoreAuths()
|
||||
w.clientsMutex.Lock()
|
||||
if len(w.runtimeAuths) > 0 {
|
||||
for _, a := range w.runtimeAuths {
|
||||
if a != nil {
|
||||
auths = append(auths, a.Clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
updates := w.prepareAuthUpdatesLocked(auths, force)
|
||||
w.clientsMutex.Unlock()
|
||||
w.dispatchAuthUpdates(updates)
|
||||
}
|
||||
|
||||
func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth, force bool) []AuthUpdate {
|
||||
newState := make(map[string]*coreauth.Auth, len(auths))
|
||||
for _, auth := range auths {
|
||||
if auth == nil || auth.ID == "" {
|
||||
continue
|
||||
}
|
||||
newState[auth.ID] = auth.Clone()
|
||||
}
|
||||
if w.currentAuths == nil {
|
||||
w.currentAuths = newState
|
||||
if w.authQueue == nil {
|
||||
return nil
|
||||
}
|
||||
updates := make([]AuthUpdate, 0, len(newState))
|
||||
for id, auth := range newState {
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()})
|
||||
}
|
||||
return updates
|
||||
}
|
||||
if w.authQueue == nil {
|
||||
w.currentAuths = newState
|
||||
return nil
|
||||
}
|
||||
updates := make([]AuthUpdate, 0, len(newState)+len(w.currentAuths))
|
||||
for id, auth := range newState {
|
||||
if existing, ok := w.currentAuths[id]; !ok {
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()})
|
||||
} else if force || !authEqual(existing, auth) {
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: auth.Clone()})
|
||||
}
|
||||
}
|
||||
for id := range w.currentAuths {
|
||||
if _, ok := newState[id]; !ok {
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id})
|
||||
}
|
||||
}
|
||||
w.currentAuths = newState
|
||||
return updates
|
||||
}
|
||||
|
||||
func (w *Watcher) dispatchAuthUpdates(updates []AuthUpdate) {
|
||||
if len(updates) == 0 {
|
||||
return
|
||||
}
|
||||
queue := w.getAuthQueue()
|
||||
if queue == nil {
|
||||
return
|
||||
}
|
||||
baseTS := time.Now().UnixNano()
|
||||
w.dispatchMu.Lock()
|
||||
if w.pendingUpdates == nil {
|
||||
w.pendingUpdates = make(map[string]AuthUpdate)
|
||||
}
|
||||
for idx, update := range updates {
|
||||
key := w.authUpdateKey(update, baseTS+int64(idx))
|
||||
if _, exists := w.pendingUpdates[key]; !exists {
|
||||
w.pendingOrder = append(w.pendingOrder, key)
|
||||
}
|
||||
w.pendingUpdates[key] = update
|
||||
}
|
||||
if w.dispatchCond != nil {
|
||||
w.dispatchCond.Signal()
|
||||
}
|
||||
w.dispatchMu.Unlock()
|
||||
}
|
||||
|
||||
func (w *Watcher) authUpdateKey(update AuthUpdate, ts int64) string {
|
||||
if update.ID != "" {
|
||||
return update.ID
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", update.Action, ts)
|
||||
}
|
||||
|
||||
func (w *Watcher) dispatchLoop(ctx context.Context) {
|
||||
for {
|
||||
batch, ok := w.nextPendingBatch(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
queue := w.getAuthQueue()
|
||||
if queue == nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
for _, update := range batch {
|
||||
select {
|
||||
case queue <- update:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) nextPendingBatch(ctx context.Context) ([]AuthUpdate, bool) {
|
||||
w.dispatchMu.Lock()
|
||||
defer w.dispatchMu.Unlock()
|
||||
for len(w.pendingOrder) == 0 {
|
||||
if ctx.Err() != nil {
|
||||
return nil, false
|
||||
}
|
||||
w.dispatchCond.Wait()
|
||||
if ctx.Err() != nil {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
batch := make([]AuthUpdate, 0, len(w.pendingOrder))
|
||||
for _, key := range w.pendingOrder {
|
||||
batch = append(batch, w.pendingUpdates[key])
|
||||
delete(w.pendingUpdates, key)
|
||||
}
|
||||
w.pendingOrder = w.pendingOrder[:0]
|
||||
return batch, true
|
||||
}
|
||||
|
||||
func (w *Watcher) getAuthQueue() chan<- AuthUpdate {
|
||||
w.clientsMutex.RLock()
|
||||
defer w.clientsMutex.RUnlock()
|
||||
return w.authQueue
|
||||
}
|
||||
|
||||
func (w *Watcher) stopDispatch() {
|
||||
if w.dispatchCancel != nil {
|
||||
w.dispatchCancel()
|
||||
w.dispatchCancel = nil
|
||||
}
|
||||
w.dispatchMu.Lock()
|
||||
w.pendingOrder = nil
|
||||
w.pendingUpdates = nil
|
||||
if w.dispatchCond != nil {
|
||||
w.dispatchCond.Broadcast()
|
||||
}
|
||||
w.dispatchMu.Unlock()
|
||||
w.clientsMutex.Lock()
|
||||
w.authQueue = nil
|
||||
w.clientsMutex.Unlock()
|
||||
}
|
||||
|
||||
func authEqual(a, b *coreauth.Auth) bool {
|
||||
return reflect.DeepEqual(normalizeAuth(a), normalizeAuth(b))
|
||||
}
|
||||
|
||||
func normalizeAuth(a *coreauth.Auth) *coreauth.Auth {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
clone := a.Clone()
|
||||
clone.CreatedAt = time.Time{}
|
||||
clone.UpdatedAt = time.Time{}
|
||||
clone.LastRefreshedAt = time.Time{}
|
||||
clone.NextRefreshAfter = time.Time{}
|
||||
clone.Runtime = nil
|
||||
clone.Quota.NextRecoverAt = time.Time{}
|
||||
return clone
|
||||
}
|
||||
|
||||
func snapshotCoreAuths(cfg *config.Config, authDir string) []*coreauth.Auth {
|
||||
ctx := &synthesizer.SynthesisContext{
|
||||
Config: cfg,
|
||||
AuthDir: authDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: synthesizer.NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
var out []*coreauth.Auth
|
||||
|
||||
configSynth := synthesizer.NewConfigSynthesizer()
|
||||
if auths, err := configSynth.Synthesize(ctx); err == nil {
|
||||
out = append(out, auths...)
|
||||
}
|
||||
|
||||
fileSynth := synthesizer.NewFileSynthesizer()
|
||||
if auths, err := fileSynth.Synthesize(ctx); err == nil {
|
||||
out = append(out, auths...)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
194
internal/watcher/events.go
Normal file
194
internal/watcher/events.go
Normal file
@@ -0,0 +1,194 @@
|
||||
// events.go implements fsnotify event handling for config and auth file changes.
|
||||
// It normalizes paths, debounces noisy events, and triggers reload/update logic.
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func matchProvider(provider string, targets []string) (string, bool) {
|
||||
p := strings.ToLower(strings.TrimSpace(provider))
|
||||
for _, t := range targets {
|
||||
if strings.EqualFold(p, strings.TrimSpace(t)) {
|
||||
return p, true
|
||||
}
|
||||
}
|
||||
return p, false
|
||||
}
|
||||
|
||||
func (w *Watcher) start(ctx context.Context) error {
|
||||
if errAddConfig := w.watcher.Add(w.configPath); errAddConfig != nil {
|
||||
log.Errorf("failed to watch config file %s: %v", w.configPath, errAddConfig)
|
||||
return errAddConfig
|
||||
}
|
||||
log.Debugf("watching config file: %s", w.configPath)
|
||||
|
||||
if errAddAuthDir := w.watcher.Add(w.authDir); errAddAuthDir != nil {
|
||||
log.Errorf("failed to watch auth directory %s: %v", w.authDir, errAddAuthDir)
|
||||
return errAddAuthDir
|
||||
}
|
||||
log.Debugf("watching auth directory: %s", w.authDir)
|
||||
|
||||
go w.processEvents(ctx)
|
||||
|
||||
w.reloadClients(true, nil, false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Watcher) processEvents(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case event, ok := <-w.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
w.handleEvent(event)
|
||||
case errWatch, ok := <-w.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Errorf("file watcher error: %v", errWatch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
// Filter only relevant events: config file or auth-dir JSON files.
|
||||
configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename
|
||||
normalizedName := w.normalizeAuthPath(event.Name)
|
||||
normalizedConfigPath := w.normalizeAuthPath(w.configPath)
|
||||
normalizedAuthDir := w.normalizeAuthPath(w.authDir)
|
||||
isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0
|
||||
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
|
||||
isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0
|
||||
if !isConfigEvent && !isAuthJSON {
|
||||
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
|
||||
|
||||
// Handle config file changes
|
||||
if isConfigEvent {
|
||||
log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000"))
|
||||
w.scheduleConfigReload()
|
||||
return
|
||||
}
|
||||
|
||||
// Handle auth directory changes incrementally (.json only)
|
||||
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
|
||||
if w.shouldDebounceRemove(normalizedName, now) {
|
||||
log.Debugf("debouncing remove event for %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
// Atomic replace on some platforms may surface as Rename (or Remove) before the new file is ready.
|
||||
// Wait briefly; if the path exists again, treat as an update instead of removal.
|
||||
time.Sleep(replaceCheckDelay)
|
||||
if _, statErr := os.Stat(event.Name); statErr == nil {
|
||||
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
||||
w.addOrUpdateClient(event.Name)
|
||||
return
|
||||
}
|
||||
if !w.isKnownAuthFile(event.Name) {
|
||||
log.Debugf("ignoring remove for unknown auth file: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
||||
w.removeClient(event.Name)
|
||||
return
|
||||
}
|
||||
if event.Op&(fsnotify.Create|fsnotify.Write) != 0 {
|
||||
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
||||
w.addOrUpdateClient(event.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) authFileUnchanged(path string) (bool, error) {
|
||||
data, errRead := os.ReadFile(path)
|
||||
if errRead != nil {
|
||||
return false, errRead
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
curHash := hex.EncodeToString(sum[:])
|
||||
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
w.clientsMutex.RLock()
|
||||
prevHash, ok := w.lastAuthHashes[normalized]
|
||||
w.clientsMutex.RUnlock()
|
||||
if ok && prevHash == curHash {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (w *Watcher) isKnownAuthFile(path string) bool {
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
w.clientsMutex.RLock()
|
||||
defer w.clientsMutex.RUnlock()
|
||||
_, ok := w.lastAuthHashes[normalized]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (w *Watcher) normalizeAuthPath(path string) string {
|
||||
trimmed := strings.TrimSpace(path)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
cleaned := filepath.Clean(trimmed)
|
||||
if runtime.GOOS == "windows" {
|
||||
cleaned = strings.TrimPrefix(cleaned, `\\?\`)
|
||||
cleaned = strings.ToLower(cleaned)
|
||||
}
|
||||
return cleaned
|
||||
}
|
||||
|
||||
func (w *Watcher) shouldDebounceRemove(normalizedPath string, now time.Time) bool {
|
||||
if normalizedPath == "" {
|
||||
return false
|
||||
}
|
||||
w.clientsMutex.Lock()
|
||||
if w.lastRemoveTimes == nil {
|
||||
w.lastRemoveTimes = make(map[string]time.Time)
|
||||
}
|
||||
if last, ok := w.lastRemoveTimes[normalizedPath]; ok {
|
||||
if now.Sub(last) < authRemoveDebounceWindow {
|
||||
w.clientsMutex.Unlock()
|
||||
return true
|
||||
}
|
||||
}
|
||||
w.lastRemoveTimes[normalizedPath] = now
|
||||
if len(w.lastRemoveTimes) > 128 {
|
||||
cutoff := now.Add(-2 * authRemoveDebounceWindow)
|
||||
for p, t := range w.lastRemoveTimes {
|
||||
if t.Before(cutoff) {
|
||||
delete(w.lastRemoveTimes, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
w.clientsMutex.Unlock()
|
||||
return false
|
||||
}
|
||||
294
internal/watcher/synthesizer/config.go
Normal file
294
internal/watcher/synthesizer/config.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// ConfigSynthesizer generates Auth entries from configuration API keys.
|
||||
// It handles Gemini, Claude, Codex, OpenAI-compat, and Vertex-compat providers.
|
||||
type ConfigSynthesizer struct{}
|
||||
|
||||
// NewConfigSynthesizer creates a new ConfigSynthesizer instance.
|
||||
func NewConfigSynthesizer() *ConfigSynthesizer {
|
||||
return &ConfigSynthesizer{}
|
||||
}
|
||||
|
||||
// Synthesize generates Auth entries from config API keys.
|
||||
func (s *ConfigSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, error) {
|
||||
out := make([]*coreauth.Auth, 0, 32)
|
||||
if ctx == nil || ctx.Config == nil {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Gemini API Keys
|
||||
out = append(out, s.synthesizeGeminiKeys(ctx)...)
|
||||
// Claude API Keys
|
||||
out = append(out, s.synthesizeClaudeKeys(ctx)...)
|
||||
// Codex API Keys
|
||||
out = append(out, s.synthesizeCodexKeys(ctx)...)
|
||||
// OpenAI-compat
|
||||
out = append(out, s.synthesizeOpenAICompat(ctx)...)
|
||||
// Vertex-compat
|
||||
out = append(out, s.synthesizeVertexCompat(ctx)...)
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// synthesizeGeminiKeys creates Auth entries for Gemini API keys.
|
||||
func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0, len(cfg.GeminiKey))
|
||||
for i := range cfg.GeminiKey {
|
||||
entry := cfg.GeminiKey[i]
|
||||
key := strings.TrimSpace(entry.APIKey)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSpace(entry.Prefix)
|
||||
base := strings.TrimSpace(entry.BaseURL)
|
||||
proxyURL := strings.TrimSpace(entry.ProxyURL)
|
||||
id, token := idGen.Next("gemini:apikey", key, base)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:gemini[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if base != "" {
|
||||
attrs["base_url"] = base
|
||||
}
|
||||
addConfigHeadersToAttrs(entry.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: "gemini",
|
||||
Label: "gemini-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, entry.ExcludedModels, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// synthesizeClaudeKeys creates Auth entries for Claude API keys.
|
||||
func (s *ConfigSynthesizer) synthesizeClaudeKeys(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0, len(cfg.ClaudeKey))
|
||||
for i := range cfg.ClaudeKey {
|
||||
ck := cfg.ClaudeKey[i]
|
||||
key := strings.TrimSpace(ck.APIKey)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSpace(ck.Prefix)
|
||||
base := strings.TrimSpace(ck.BaseURL)
|
||||
id, token := idGen.Next("claude:apikey", key, base)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:claude[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if base != "" {
|
||||
attrs["base_url"] = base
|
||||
}
|
||||
if hash := diff.ComputeClaudeModelsHash(ck.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(ck.Headers, attrs)
|
||||
proxyURL := strings.TrimSpace(ck.ProxyURL)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: "claude",
|
||||
Label: "claude-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// synthesizeCodexKeys creates Auth entries for Codex API keys.
|
||||
func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0, len(cfg.CodexKey))
|
||||
for i := range cfg.CodexKey {
|
||||
ck := cfg.CodexKey[i]
|
||||
key := strings.TrimSpace(ck.APIKey)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSpace(ck.Prefix)
|
||||
id, token := idGen.Next("codex:apikey", key, ck.BaseURL)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:codex[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if ck.BaseURL != "" {
|
||||
attrs["base_url"] = ck.BaseURL
|
||||
}
|
||||
addConfigHeadersToAttrs(ck.Headers, attrs)
|
||||
proxyURL := strings.TrimSpace(ck.ProxyURL)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: "codex",
|
||||
Label: "codex-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// synthesizeOpenAICompat creates Auth entries for OpenAI-compatible providers.
|
||||
func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0)
|
||||
for i := range cfg.OpenAICompatibility {
|
||||
compat := &cfg.OpenAICompatibility[i]
|
||||
prefix := strings.TrimSpace(compat.Prefix)
|
||||
providerName := strings.ToLower(strings.TrimSpace(compat.Name))
|
||||
if providerName == "" {
|
||||
providerName = "openai-compatibility"
|
||||
}
|
||||
base := strings.TrimSpace(compat.BaseURL)
|
||||
|
||||
// Handle new APIKeyEntries format (preferred)
|
||||
createdEntries := 0
|
||||
for j := range compat.APIKeyEntries {
|
||||
entry := &compat.APIKeyEntries[j]
|
||||
key := strings.TrimSpace(entry.APIKey)
|
||||
proxyURL := strings.TrimSpace(entry.ProxyURL)
|
||||
idKind := fmt.Sprintf("openai-compatibility:%s", providerName)
|
||||
id, token := idGen.Next(idKind, key, base, proxyURL)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:%s[%s]", providerName, token),
|
||||
"base_url": base,
|
||||
"compat_name": compat.Name,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if key != "" {
|
||||
attrs["api_key"] = key
|
||||
}
|
||||
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(compat.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: providerName,
|
||||
Label: compat.Name,
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
out = append(out, a)
|
||||
createdEntries++
|
||||
}
|
||||
// Fallback: create entry without API key if no APIKeyEntries
|
||||
if createdEntries == 0 {
|
||||
idKind := fmt.Sprintf("openai-compatibility:%s", providerName)
|
||||
id, token := idGen.Next(idKind, base)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:%s[%s]", providerName, token),
|
||||
"base_url": base,
|
||||
"compat_name": compat.Name,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(compat.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: providerName,
|
||||
Label: compat.Name,
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
out = append(out, a)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// synthesizeVertexCompat creates Auth entries for Vertex-compatible providers.
|
||||
func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0, len(cfg.VertexCompatAPIKey))
|
||||
for i := range cfg.VertexCompatAPIKey {
|
||||
compat := &cfg.VertexCompatAPIKey[i]
|
||||
providerName := "vertex"
|
||||
base := strings.TrimSpace(compat.BaseURL)
|
||||
|
||||
key := strings.TrimSpace(compat.APIKey)
|
||||
prefix := strings.TrimSpace(compat.Prefix)
|
||||
proxyURL := strings.TrimSpace(compat.ProxyURL)
|
||||
idKind := "vertex:apikey"
|
||||
id, token := idGen.Next(idKind, key, base, proxyURL)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:vertex-apikey[%s]", token),
|
||||
"base_url": base,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if key != "" {
|
||||
attrs["api_key"] = key
|
||||
}
|
||||
if hash := diff.ComputeVertexCompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(compat.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: providerName,
|
||||
Label: "vertex-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, nil, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
613
internal/watcher/synthesizer/config_test.go
Normal file
613
internal/watcher/synthesizer/config_test.go
Normal file
@@ -0,0 +1,613 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestNewConfigSynthesizer(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
if synth == nil {
|
||||
t.Fatal("expected non-nil synthesizer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_Synthesize_NilContext(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
auths, err := synth.Synthesize(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_Synthesize_NilConfig(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: nil,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_GeminiKeys(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
geminiKeys []config.GeminiKey
|
||||
wantLen int
|
||||
validate func(*testing.T, []*coreauth.Auth)
|
||||
}{
|
||||
{
|
||||
name: "single gemini key",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{APIKey: "test-key-123", Prefix: "team-a"},
|
||||
},
|
||||
wantLen: 1,
|
||||
validate: func(t *testing.T, auths []*coreauth.Auth) {
|
||||
if auths[0].Provider != "gemini" {
|
||||
t.Errorf("expected provider gemini, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Prefix != "team-a" {
|
||||
t.Errorf("expected prefix team-a, got %s", auths[0].Prefix)
|
||||
}
|
||||
if auths[0].Label != "gemini-apikey" {
|
||||
t.Errorf("expected label gemini-apikey, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].Attributes["api_key"] != "test-key-123" {
|
||||
t.Errorf("expected api_key test-key-123, got %s", auths[0].Attributes["api_key"])
|
||||
}
|
||||
if auths[0].Status != coreauth.StatusActive {
|
||||
t.Errorf("expected status active, got %s", auths[0].Status)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "gemini key with base url and proxy",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{
|
||||
APIKey: "api-key",
|
||||
BaseURL: "https://custom.api.com",
|
||||
ProxyURL: "http://proxy.local:8080",
|
||||
Prefix: "custom",
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
validate: func(t *testing.T, auths []*coreauth.Auth) {
|
||||
if auths[0].Attributes["base_url"] != "https://custom.api.com" {
|
||||
t.Errorf("expected base_url https://custom.api.com, got %s", auths[0].Attributes["base_url"])
|
||||
}
|
||||
if auths[0].ProxyURL != "http://proxy.local:8080" {
|
||||
t.Errorf("expected proxy_url http://proxy.local:8080, got %s", auths[0].ProxyURL)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "gemini key with headers",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{
|
||||
APIKey: "api-key",
|
||||
Headers: map[string]string{"X-Custom": "value"},
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
validate: func(t *testing.T, auths []*coreauth.Auth) {
|
||||
if auths[0].Attributes["header:X-Custom"] != "value" {
|
||||
t.Errorf("expected header:X-Custom=value, got %s", auths[0].Attributes["header:X-Custom"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty api key skipped",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{APIKey: ""},
|
||||
{APIKey: " "},
|
||||
{APIKey: "valid-key"},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple gemini keys",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{APIKey: "key-1", Prefix: "a"},
|
||||
{APIKey: "key-2", Prefix: "b"},
|
||||
{APIKey: "key-3", Prefix: "c"},
|
||||
},
|
||||
wantLen: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
GeminiKey: tt.geminiKeys,
|
||||
},
|
||||
Now: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != tt.wantLen {
|
||||
t.Fatalf("expected %d auths, got %d", tt.wantLen, len(auths))
|
||||
}
|
||||
|
||||
if tt.validate != nil && len(auths) > 0 {
|
||||
tt.validate(t, auths)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_ClaudeKeys(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{
|
||||
APIKey: "sk-ant-api-xxx",
|
||||
Prefix: "main",
|
||||
BaseURL: "https://api.anthropic.com",
|
||||
Models: []config.ClaudeModel{
|
||||
{Name: "claude-3-opus"},
|
||||
{Name: "claude-3-sonnet"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "claude" {
|
||||
t.Errorf("expected provider claude, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Label != "claude-apikey" {
|
||||
t.Errorf("expected label claude-apikey, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].Prefix != "main" {
|
||||
t.Errorf("expected prefix main, got %s", auths[0].Prefix)
|
||||
}
|
||||
if auths[0].Attributes["api_key"] != "sk-ant-api-xxx" {
|
||||
t.Errorf("expected api_key sk-ant-api-xxx, got %s", auths[0].Attributes["api_key"])
|
||||
}
|
||||
if _, ok := auths[0].Attributes["models_hash"]; !ok {
|
||||
t.Error("expected models_hash in attributes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_ClaudeKeys_SkipsEmptyAndHeaders(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: ""}, // empty, should be skipped
|
||||
{APIKey: " "}, // whitespace, should be skipped
|
||||
{APIKey: "valid-key", Headers: map[string]string{"X-Custom": "value"}},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth (empty keys skipped), got %d", len(auths))
|
||||
}
|
||||
if auths[0].Attributes["header:X-Custom"] != "value" {
|
||||
t.Errorf("expected header:X-Custom=value, got %s", auths[0].Attributes["header:X-Custom"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_CodexKeys(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
CodexKey: []config.CodexKey{
|
||||
{
|
||||
APIKey: "codex-key-123",
|
||||
Prefix: "dev",
|
||||
BaseURL: "https://api.openai.com",
|
||||
ProxyURL: "http://proxy.local",
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "codex" {
|
||||
t.Errorf("expected provider codex, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Label != "codex-apikey" {
|
||||
t.Errorf("expected label codex-apikey, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].ProxyURL != "http://proxy.local" {
|
||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_CodexKeys_SkipsEmptyAndHeaders(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: ""}, // empty, should be skipped
|
||||
{APIKey: " "}, // whitespace, should be skipped
|
||||
{APIKey: "valid-key", Headers: map[string]string{"Authorization": "Bearer xyz"}},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth (empty keys skipped), got %d", len(auths))
|
||||
}
|
||||
if auths[0].Attributes["header:Authorization"] != "Bearer xyz" {
|
||||
t.Errorf("expected header:Authorization=Bearer xyz, got %s", auths[0].Attributes["header:Authorization"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_OpenAICompat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
compat []config.OpenAICompatibility
|
||||
wantLen int
|
||||
}{
|
||||
{
|
||||
name: "with APIKeyEntries",
|
||||
compat: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "CustomProvider",
|
||||
BaseURL: "https://custom.api.com",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "key-1"},
|
||||
{APIKey: "key-2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantLen: 2,
|
||||
},
|
||||
{
|
||||
name: "empty APIKeyEntries included (legacy)",
|
||||
compat: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "EmptyKeys",
|
||||
BaseURL: "https://empty.api.com",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: ""},
|
||||
{APIKey: " "},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantLen: 2,
|
||||
},
|
||||
{
|
||||
name: "without APIKeyEntries (fallback)",
|
||||
compat: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "NoKeyProvider",
|
||||
BaseURL: "https://no-key.api.com",
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
name: "empty name defaults",
|
||||
compat: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "",
|
||||
BaseURL: "https://default.api.com",
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
OpenAICompatibility: tt.compat,
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != tt.wantLen {
|
||||
t.Fatalf("expected %d auths, got %d", tt.wantLen, len(auths))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_VertexCompat(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{
|
||||
APIKey: "vertex-key-123",
|
||||
BaseURL: "https://vertex.googleapis.com",
|
||||
Prefix: "vertex-prod",
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "vertex" {
|
||||
t.Errorf("expected provider vertex, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Label != "vertex-apikey" {
|
||||
t.Errorf("expected label vertex-apikey, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].Prefix != "vertex-prod" {
|
||||
t.Errorf("expected prefix vertex-prod, got %s", auths[0].Prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_VertexCompat_SkipsEmptyAndHeaders(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "", BaseURL: "https://vertex.api"}, // empty key creates auth without api_key attr
|
||||
{APIKey: " ", BaseURL: "https://vertex.api"}, // whitespace key creates auth without api_key attr
|
||||
{APIKey: "valid-key", BaseURL: "https://vertex.api", Headers: map[string]string{"X-Vertex": "test"}},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Vertex compat doesn't skip empty keys - it creates auths without api_key attribute
|
||||
if len(auths) != 3 {
|
||||
t.Fatalf("expected 3 auths, got %d", len(auths))
|
||||
}
|
||||
// First two should not have api_key attribute
|
||||
if _, ok := auths[0].Attributes["api_key"]; ok {
|
||||
t.Error("expected first auth to not have api_key attribute")
|
||||
}
|
||||
if _, ok := auths[1].Attributes["api_key"]; ok {
|
||||
t.Error("expected second auth to not have api_key attribute")
|
||||
}
|
||||
// Third should have headers
|
||||
if auths[2].Attributes["header:X-Vertex"] != "test" {
|
||||
t.Errorf("expected header:X-Vertex=test, got %s", auths[2].Attributes["header:X-Vertex"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_OpenAICompat_WithModelsHash(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "TestProvider",
|
||||
BaseURL: "https://test.api.com",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "model-a"},
|
||||
{Name: "model-b"},
|
||||
},
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "key-with-models"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
if _, ok := auths[0].Attributes["models_hash"]; !ok {
|
||||
t.Error("expected models_hash in attributes")
|
||||
}
|
||||
if auths[0].Attributes["api_key"] != "key-with-models" {
|
||||
t.Errorf("expected api_key key-with-models, got %s", auths[0].Attributes["api_key"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_OpenAICompat_FallbackWithModels(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "NoKeyWithModels",
|
||||
BaseURL: "https://nokey.api.com",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "model-x"},
|
||||
},
|
||||
Headers: map[string]string{"X-API": "header-value"},
|
||||
// No APIKeyEntries - should use fallback path
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
if _, ok := auths[0].Attributes["models_hash"]; !ok {
|
||||
t.Error("expected models_hash in fallback path")
|
||||
}
|
||||
if auths[0].Attributes["header:X-API"] != "header-value" {
|
||||
t.Errorf("expected header:X-API=header-value, got %s", auths[0].Attributes["header:X-API"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_VertexCompat_WithModels(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{
|
||||
APIKey: "vertex-key",
|
||||
BaseURL: "https://vertex.api",
|
||||
Models: []config.VertexCompatModel{
|
||||
{Name: "gemini-pro", Alias: "pro"},
|
||||
{Name: "gemini-ultra", Alias: "ultra"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
if _, ok := auths[0].Attributes["models_hash"]; !ok {
|
||||
t.Error("expected models_hash in vertex auth with models")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_IDStability(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "stable-key", Prefix: "test"},
|
||||
},
|
||||
}
|
||||
|
||||
// Generate IDs twice with fresh generators
|
||||
synth1 := NewConfigSynthesizer()
|
||||
ctx1 := &SynthesisContext{
|
||||
Config: cfg,
|
||||
Now: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths1, _ := synth1.Synthesize(ctx1)
|
||||
|
||||
synth2 := NewConfigSynthesizer()
|
||||
ctx2 := &SynthesisContext{
|
||||
Config: cfg,
|
||||
Now: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths2, _ := synth2.Synthesize(ctx2)
|
||||
|
||||
if auths1[0].ID != auths2[0].ID {
|
||||
t.Errorf("same config should produce same ID: got %q and %q", auths1[0].ID, auths2[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_AllProviders(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "gemini-key"},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "claude-key"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "codex-key"},
|
||||
},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{Name: "compat", BaseURL: "https://compat.api"},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "vertex-key", BaseURL: "https://vertex.api"},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 5 {
|
||||
t.Fatalf("expected 5 auths, got %d", len(auths))
|
||||
}
|
||||
|
||||
providers := make(map[string]bool)
|
||||
for _, a := range auths {
|
||||
providers[a.Provider] = true
|
||||
}
|
||||
|
||||
expected := []string{"gemini", "claude", "codex", "compat", "vertex"}
|
||||
for _, p := range expected {
|
||||
if !providers[p] {
|
||||
t.Errorf("expected provider %s not found", p)
|
||||
}
|
||||
}
|
||||
}
|
||||
19
internal/watcher/synthesizer/context.go
Normal file
19
internal/watcher/synthesizer/context.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// SynthesisContext provides the context needed for auth synthesis.
|
||||
type SynthesisContext struct {
|
||||
// Config is the current configuration
|
||||
Config *config.Config
|
||||
// AuthDir is the directory containing auth files
|
||||
AuthDir string
|
||||
// Now is the current time for timestamps
|
||||
Now time.Time
|
||||
// IDGenerator generates stable IDs for auth entries
|
||||
IDGenerator *StableIDGenerator
|
||||
}
|
||||
224
internal/watcher/synthesizer/file.go
Normal file
224
internal/watcher/synthesizer/file.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// FileSynthesizer generates Auth entries from OAuth JSON files.
|
||||
// It handles file-based authentication and Gemini virtual auth generation.
|
||||
type FileSynthesizer struct{}
|
||||
|
||||
// NewFileSynthesizer creates a new FileSynthesizer instance.
|
||||
func NewFileSynthesizer() *FileSynthesizer {
|
||||
return &FileSynthesizer{}
|
||||
}
|
||||
|
||||
// Synthesize generates Auth entries from auth files in the auth directory.
|
||||
func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, error) {
|
||||
out := make([]*coreauth.Auth, 0, 16)
|
||||
if ctx == nil || ctx.AuthDir == "" {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(ctx.AuthDir)
|
||||
if err != nil {
|
||||
// Not an error if directory doesn't exist
|
||||
return out, nil
|
||||
}
|
||||
|
||||
now := ctx.Now
|
||||
cfg := ctx.Config
|
||||
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if !strings.HasSuffix(strings.ToLower(name), ".json") {
|
||||
continue
|
||||
}
|
||||
full := filepath.Join(ctx.AuthDir, name)
|
||||
data, errRead := os.ReadFile(full)
|
||||
if errRead != nil || len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
var metadata map[string]any
|
||||
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
|
||||
continue
|
||||
}
|
||||
t, _ := metadata["type"].(string)
|
||||
if t == "" {
|
||||
continue
|
||||
}
|
||||
provider := strings.ToLower(t)
|
||||
if provider == "gemini" {
|
||||
provider = "gemini-cli"
|
||||
}
|
||||
label := provider
|
||||
if email, _ := metadata["email"].(string); email != "" {
|
||||
label = email
|
||||
}
|
||||
// Use relative path under authDir as ID to stay consistent with the file-based token store
|
||||
id := full
|
||||
if rel, errRel := filepath.Rel(ctx.AuthDir, full); errRel == nil && rel != "" {
|
||||
id = rel
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if p, ok := metadata["proxy_url"].(string); ok {
|
||||
proxyURL = p
|
||||
}
|
||||
|
||||
prefix := ""
|
||||
if rawPrefix, ok := metadata["prefix"].(string); ok {
|
||||
trimmed := strings.TrimSpace(rawPrefix)
|
||||
trimmed = strings.Trim(trimmed, "/")
|
||||
if trimmed != "" && !strings.Contains(trimmed, "/") {
|
||||
prefix = trimmed
|
||||
}
|
||||
}
|
||||
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: provider,
|
||||
Label: label,
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"source": full,
|
||||
"path": full,
|
||||
},
|
||||
ProxyURL: proxyURL,
|
||||
Metadata: metadata,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, nil, "oauth")
|
||||
if provider == "gemini-cli" {
|
||||
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||
for _, v := range virtuals {
|
||||
ApplyAuthExcludedModelsMeta(v, cfg, nil, "oauth")
|
||||
}
|
||||
out = append(out, a)
|
||||
out = append(out, virtuals...)
|
||||
continue
|
||||
}
|
||||
}
|
||||
out = append(out, a)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// SynthesizeGeminiVirtualAuths creates virtual Auth entries for multi-project Gemini credentials.
|
||||
// It disables the primary auth and creates one virtual auth per project.
|
||||
func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]any, now time.Time) []*coreauth.Auth {
|
||||
if primary == nil || metadata == nil {
|
||||
return nil
|
||||
}
|
||||
projects := splitGeminiProjectIDs(metadata)
|
||||
if len(projects) <= 1 {
|
||||
return nil
|
||||
}
|
||||
email, _ := metadata["email"].(string)
|
||||
shared := geminicli.NewSharedCredential(primary.ID, email, metadata, projects)
|
||||
primary.Disabled = true
|
||||
primary.Status = coreauth.StatusDisabled
|
||||
primary.Runtime = shared
|
||||
if primary.Attributes == nil {
|
||||
primary.Attributes = make(map[string]string)
|
||||
}
|
||||
primary.Attributes["gemini_virtual_primary"] = "true"
|
||||
primary.Attributes["virtual_children"] = strings.Join(projects, ",")
|
||||
source := primary.Attributes["source"]
|
||||
authPath := primary.Attributes["path"]
|
||||
originalProvider := primary.Provider
|
||||
if originalProvider == "" {
|
||||
originalProvider = "gemini-cli"
|
||||
}
|
||||
label := primary.Label
|
||||
if label == "" {
|
||||
label = originalProvider
|
||||
}
|
||||
virtuals := make([]*coreauth.Auth, 0, len(projects))
|
||||
for _, projectID := range projects {
|
||||
attrs := map[string]string{
|
||||
"runtime_only": "true",
|
||||
"gemini_virtual_parent": primary.ID,
|
||||
"gemini_virtual_project": projectID,
|
||||
}
|
||||
if source != "" {
|
||||
attrs["source"] = source
|
||||
}
|
||||
if authPath != "" {
|
||||
attrs["path"] = authPath
|
||||
}
|
||||
metadataCopy := map[string]any{
|
||||
"email": email,
|
||||
"project_id": projectID,
|
||||
"virtual": true,
|
||||
"virtual_parent_id": primary.ID,
|
||||
"type": metadata["type"],
|
||||
}
|
||||
proxy := strings.TrimSpace(primary.ProxyURL)
|
||||
if proxy != "" {
|
||||
metadataCopy["proxy_url"] = proxy
|
||||
}
|
||||
virtual := &coreauth.Auth{
|
||||
ID: buildGeminiVirtualID(primary.ID, projectID),
|
||||
Provider: originalProvider,
|
||||
Label: fmt.Sprintf("%s [%s]", label, projectID),
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: attrs,
|
||||
Metadata: metadataCopy,
|
||||
ProxyURL: primary.ProxyURL,
|
||||
Prefix: primary.Prefix,
|
||||
CreatedAt: primary.CreatedAt,
|
||||
UpdatedAt: primary.UpdatedAt,
|
||||
Runtime: geminicli.NewVirtualCredential(projectID, shared),
|
||||
}
|
||||
virtuals = append(virtuals, virtual)
|
||||
}
|
||||
return virtuals
|
||||
}
|
||||
|
||||
// splitGeminiProjectIDs extracts and deduplicates project IDs from metadata.
|
||||
func splitGeminiProjectIDs(metadata map[string]any) []string {
|
||||
raw, _ := metadata["project_id"].(string)
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(trimmed, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
seen := make(map[string]struct{}, len(parts))
|
||||
for _, part := range parts {
|
||||
id := strings.TrimSpace(part)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
result = append(result, id)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// buildGeminiVirtualID constructs a virtual auth ID from base ID and project ID.
|
||||
func buildGeminiVirtualID(baseID, projectID string) string {
|
||||
project := strings.TrimSpace(projectID)
|
||||
if project == "" {
|
||||
project = "project"
|
||||
}
|
||||
replacer := strings.NewReplacer("/", "_", "\\", "_", " ", "_")
|
||||
return fmt.Sprintf("%s::%s", baseID, replacer.Replace(project))
|
||||
}
|
||||
612
internal/watcher/synthesizer/file_test.go
Normal file
612
internal/watcher/synthesizer/file_test.go
Normal file
@@ -0,0 +1,612 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestNewFileSynthesizer(t *testing.T) {
|
||||
synth := NewFileSynthesizer()
|
||||
if synth == nil {
|
||||
t.Fatal("expected non-nil synthesizer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_NilContext(t *testing.T) {
|
||||
synth := NewFileSynthesizer()
|
||||
auths, err := synth.Synthesize(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_EmptyAuthDir(t *testing.T) {
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: "",
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_NonExistentDir(t *testing.T) {
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: "/non/existent/path",
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create a valid auth file
|
||||
authData := map[string]any{
|
||||
"type": "claude",
|
||||
"email": "test@example.com",
|
||||
"proxy_url": "http://proxy.local",
|
||||
"prefix": "test-prefix",
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "claude-auth.json"), data, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "claude" {
|
||||
t.Errorf("expected provider claude, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Label != "test@example.com" {
|
||||
t.Errorf("expected label test@example.com, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].Prefix != "test-prefix" {
|
||||
t.Errorf("expected prefix test-prefix, got %s", auths[0].Prefix)
|
||||
}
|
||||
if auths[0].ProxyURL != "http://proxy.local" {
|
||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||
}
|
||||
if auths[0].Status != coreauth.StatusActive {
|
||||
t.Errorf("expected status active, got %s", auths[0].Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_GeminiProviderMapping(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Gemini type should be mapped to gemini-cli
|
||||
authData := map[string]any{
|
||||
"type": "gemini",
|
||||
"email": "gemini@example.com",
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "gemini-auth.json"), data, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "gemini-cli" {
|
||||
t.Errorf("gemini should be mapped to gemini-cli, got %s", auths[0].Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_SkipsInvalidFiles(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create various invalid files
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "not-json.txt"), []byte("text content"), 0644)
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "invalid.json"), []byte("not valid json"), 0644)
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "empty.json"), []byte(""), 0644)
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "no-type.json"), []byte(`{"email": "test@example.com"}`), 0644)
|
||||
|
||||
// Create one valid file
|
||||
validData, _ := json.Marshal(map[string]any{"type": "claude", "email": "valid@example.com"})
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "valid.json"), validData, 0644)
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("only valid auth file should be processed, got %d", len(auths))
|
||||
}
|
||||
if auths[0].Label != "valid@example.com" {
|
||||
t.Errorf("expected label valid@example.com, got %s", auths[0].Label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_SkipsDirectories(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create a subdirectory with a json file inside
|
||||
subDir := filepath.Join(tempDir, "subdir.json")
|
||||
err := os.Mkdir(subDir, 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create subdir: %v", err)
|
||||
}
|
||||
|
||||
// Create a valid file in root
|
||||
validData, _ := json.Marshal(map[string]any{"type": "claude"})
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "valid.json"), validData, 0644)
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_RelativeID(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authData := map[string]any{"type": "claude"}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "my-auth.json"), data, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
// ID should be relative path
|
||||
if auths[0].ID != "my-auth.json" {
|
||||
t.Errorf("expected ID my-auth.json, got %s", auths[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_PrefixValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
wantPrefix string
|
||||
}{
|
||||
{"valid prefix", "myprefix", "myprefix"},
|
||||
{"prefix with slashes trimmed", "/myprefix/", "myprefix"},
|
||||
{"prefix with spaces trimmed", " myprefix ", "myprefix"},
|
||||
{"prefix with internal slash rejected", "my/prefix", ""},
|
||||
{"empty prefix", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
authData := map[string]any{
|
||||
"type": "claude",
|
||||
"prefix": tt.prefix,
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "auth.json"), data, 0644)
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
if auths[0].Prefix != tt.wantPrefix {
|
||||
t.Errorf("expected prefix %q, got %q", tt.wantPrefix, auths[0].Prefix)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_NilInputs(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
if SynthesizeGeminiVirtualAuths(nil, nil, now) != nil {
|
||||
t.Error("expected nil for nil primary")
|
||||
}
|
||||
if SynthesizeGeminiVirtualAuths(&coreauth.Auth{}, nil, now) != nil {
|
||||
t.Error("expected nil for nil metadata")
|
||||
}
|
||||
if SynthesizeGeminiVirtualAuths(nil, map[string]any{}, now) != nil {
|
||||
t.Error("expected nil for nil primary with metadata")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_SingleProject(t *testing.T) {
|
||||
now := time.Now()
|
||||
primary := &coreauth.Auth{
|
||||
ID: "test-id",
|
||||
Provider: "gemini-cli",
|
||||
Label: "test@example.com",
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "single-project",
|
||||
"email": "test@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
if virtuals != nil {
|
||||
t.Error("single project should not create virtuals")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
|
||||
now := time.Now()
|
||||
primary := &coreauth.Auth{
|
||||
ID: "primary-id",
|
||||
Provider: "gemini-cli",
|
||||
Label: "test@example.com",
|
||||
Prefix: "test-prefix",
|
||||
ProxyURL: "http://proxy.local",
|
||||
Attributes: map[string]string{
|
||||
"source": "test-source",
|
||||
"path": "/path/to/auth",
|
||||
},
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "project-a, project-b, project-c",
|
||||
"email": "test@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
|
||||
if len(virtuals) != 3 {
|
||||
t.Fatalf("expected 3 virtuals, got %d", len(virtuals))
|
||||
}
|
||||
|
||||
// Check primary is disabled
|
||||
if !primary.Disabled {
|
||||
t.Error("expected primary to be disabled")
|
||||
}
|
||||
if primary.Status != coreauth.StatusDisabled {
|
||||
t.Errorf("expected primary status disabled, got %s", primary.Status)
|
||||
}
|
||||
if primary.Attributes["gemini_virtual_primary"] != "true" {
|
||||
t.Error("expected gemini_virtual_primary=true")
|
||||
}
|
||||
if !strings.Contains(primary.Attributes["virtual_children"], "project-a") {
|
||||
t.Error("expected virtual_children to contain project-a")
|
||||
}
|
||||
|
||||
// Check virtuals
|
||||
projectIDs := []string{"project-a", "project-b", "project-c"}
|
||||
for i, v := range virtuals {
|
||||
if v.Provider != "gemini-cli" {
|
||||
t.Errorf("expected provider gemini-cli, got %s", v.Provider)
|
||||
}
|
||||
if v.Status != coreauth.StatusActive {
|
||||
t.Errorf("expected status active, got %s", v.Status)
|
||||
}
|
||||
if v.Prefix != "test-prefix" {
|
||||
t.Errorf("expected prefix test-prefix, got %s", v.Prefix)
|
||||
}
|
||||
if v.ProxyURL != "http://proxy.local" {
|
||||
t.Errorf("expected proxy_url http://proxy.local, got %s", v.ProxyURL)
|
||||
}
|
||||
if v.Attributes["runtime_only"] != "true" {
|
||||
t.Error("expected runtime_only=true")
|
||||
}
|
||||
if v.Attributes["gemini_virtual_parent"] != "primary-id" {
|
||||
t.Errorf("expected gemini_virtual_parent=primary-id, got %s", v.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
if v.Attributes["gemini_virtual_project"] != projectIDs[i] {
|
||||
t.Errorf("expected gemini_virtual_project=%s, got %s", projectIDs[i], v.Attributes["gemini_virtual_project"])
|
||||
}
|
||||
if !strings.Contains(v.Label, "["+projectIDs[i]+"]") {
|
||||
t.Errorf("expected label to contain [%s], got %s", projectIDs[i], v.Label)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_EmptyProviderAndLabel(t *testing.T) {
|
||||
now := time.Now()
|
||||
// Test with empty Provider and Label to cover fallback branches
|
||||
primary := &coreauth.Auth{
|
||||
ID: "primary-id",
|
||||
Provider: "", // empty provider - should default to gemini-cli
|
||||
Label: "", // empty label - should default to provider
|
||||
Attributes: map[string]string{},
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "proj-a, proj-b",
|
||||
"email": "user@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
|
||||
if len(virtuals) != 2 {
|
||||
t.Fatalf("expected 2 virtuals, got %d", len(virtuals))
|
||||
}
|
||||
|
||||
// Check that empty provider defaults to gemini-cli
|
||||
if virtuals[0].Provider != "gemini-cli" {
|
||||
t.Errorf("expected provider gemini-cli (default), got %s", virtuals[0].Provider)
|
||||
}
|
||||
// Check that empty label defaults to provider
|
||||
if !strings.Contains(virtuals[0].Label, "gemini-cli") {
|
||||
t.Errorf("expected label to contain gemini-cli, got %s", virtuals[0].Label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_NilPrimaryAttributes(t *testing.T) {
|
||||
now := time.Now()
|
||||
primary := &coreauth.Auth{
|
||||
ID: "primary-id",
|
||||
Provider: "gemini-cli",
|
||||
Label: "test@example.com",
|
||||
Attributes: nil, // nil attributes
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "proj-a, proj-b",
|
||||
"email": "test@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
|
||||
if len(virtuals) != 2 {
|
||||
t.Fatalf("expected 2 virtuals, got %d", len(virtuals))
|
||||
}
|
||||
// Nil attributes should be initialized
|
||||
if primary.Attributes == nil {
|
||||
t.Error("expected primary.Attributes to be initialized")
|
||||
}
|
||||
if primary.Attributes["gemini_virtual_primary"] != "true" {
|
||||
t.Error("expected gemini_virtual_primary=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitGeminiProjectIDs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
metadata map[string]any
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "single project",
|
||||
metadata: map[string]any{"project_id": "proj-a"},
|
||||
want: []string{"proj-a"},
|
||||
},
|
||||
{
|
||||
name: "multiple projects",
|
||||
metadata: map[string]any{"project_id": "proj-a, proj-b, proj-c"},
|
||||
want: []string{"proj-a", "proj-b", "proj-c"},
|
||||
},
|
||||
{
|
||||
name: "with duplicates",
|
||||
metadata: map[string]any{"project_id": "proj-a, proj-b, proj-a"},
|
||||
want: []string{"proj-a", "proj-b"},
|
||||
},
|
||||
{
|
||||
name: "with empty parts",
|
||||
metadata: map[string]any{"project_id": "proj-a, , proj-b, "},
|
||||
want: []string{"proj-a", "proj-b"},
|
||||
},
|
||||
{
|
||||
name: "empty project_id",
|
||||
metadata: map[string]any{"project_id": ""},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "no project_id",
|
||||
metadata: map[string]any{},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
metadata: map[string]any{"project_id": " "},
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := splitGeminiProjectIDs(tt.metadata)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("expected %v, got %v", tt.want, got)
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("expected %v, got %v", tt.want, got)
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create a gemini auth file with multiple projects
|
||||
authData := map[string]any{
|
||||
"type": "gemini",
|
||||
"email": "multi@example.com",
|
||||
"project_id": "project-a, project-b, project-c",
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "gemini-multi.json"), data, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Should have 4 auths: 1 primary (disabled) + 3 virtuals
|
||||
if len(auths) != 4 {
|
||||
t.Fatalf("expected 4 auths (1 primary + 3 virtuals), got %d", len(auths))
|
||||
}
|
||||
|
||||
// First auth should be the primary (disabled)
|
||||
primary := auths[0]
|
||||
if !primary.Disabled {
|
||||
t.Error("expected primary to be disabled")
|
||||
}
|
||||
if primary.Status != coreauth.StatusDisabled {
|
||||
t.Errorf("expected primary status disabled, got %s", primary.Status)
|
||||
}
|
||||
|
||||
// Remaining auths should be virtuals
|
||||
for i := 1; i < 4; i++ {
|
||||
v := auths[i]
|
||||
if v.Status != coreauth.StatusActive {
|
||||
t.Errorf("expected virtual %d to be active, got %s", i, v.Status)
|
||||
}
|
||||
if v.Attributes["gemini_virtual_parent"] != primary.ID {
|
||||
t.Errorf("expected virtual %d parent to be %s, got %s", i, primary.ID, v.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildGeminiVirtualID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
baseID string
|
||||
projectID string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
baseID: "auth.json",
|
||||
projectID: "my-project",
|
||||
want: "auth.json::my-project",
|
||||
},
|
||||
{
|
||||
name: "with slashes",
|
||||
baseID: "path/to/auth.json",
|
||||
projectID: "project/with/slashes",
|
||||
want: "path/to/auth.json::project_with_slashes",
|
||||
},
|
||||
{
|
||||
name: "with spaces",
|
||||
baseID: "auth.json",
|
||||
projectID: "my project",
|
||||
want: "auth.json::my_project",
|
||||
},
|
||||
{
|
||||
name: "empty project",
|
||||
baseID: "auth.json",
|
||||
projectID: "",
|
||||
want: "auth.json::project",
|
||||
},
|
||||
{
|
||||
name: "whitespace project",
|
||||
baseID: "auth.json",
|
||||
projectID: " ",
|
||||
want: "auth.json::project",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := buildGeminiVirtualID(tt.baseID, tt.projectID)
|
||||
if got != tt.want {
|
||||
t.Errorf("expected %q, got %q", tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
110
internal/watcher/synthesizer/helpers.go
Normal file
110
internal/watcher/synthesizer/helpers.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// StableIDGenerator generates stable, deterministic IDs for auth entries.
|
||||
// It uses SHA256 hashing with collision handling via counters.
|
||||
// It is not safe for concurrent use.
|
||||
type StableIDGenerator struct {
|
||||
counters map[string]int
|
||||
}
|
||||
|
||||
// NewStableIDGenerator creates a new StableIDGenerator instance.
|
||||
func NewStableIDGenerator() *StableIDGenerator {
|
||||
return &StableIDGenerator{counters: make(map[string]int)}
|
||||
}
|
||||
|
||||
// Next generates a stable ID based on the kind and parts.
|
||||
// Returns the full ID (kind:hash) and the short hash portion.
|
||||
func (g *StableIDGenerator) Next(kind string, parts ...string) (string, string) {
|
||||
if g == nil {
|
||||
return kind + ":000000000000", "000000000000"
|
||||
}
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(kind))
|
||||
for _, part := range parts {
|
||||
trimmed := strings.TrimSpace(part)
|
||||
hasher.Write([]byte{0})
|
||||
hasher.Write([]byte(trimmed))
|
||||
}
|
||||
digest := hex.EncodeToString(hasher.Sum(nil))
|
||||
if len(digest) < 12 {
|
||||
digest = fmt.Sprintf("%012s", digest)
|
||||
}
|
||||
short := digest[:12]
|
||||
key := kind + ":" + short
|
||||
index := g.counters[key]
|
||||
g.counters[key] = index + 1
|
||||
if index > 0 {
|
||||
short = fmt.Sprintf("%s-%d", short, index)
|
||||
}
|
||||
return fmt.Sprintf("%s:%s", kind, short), short
|
||||
}
|
||||
|
||||
// ApplyAuthExcludedModelsMeta applies excluded models metadata to an auth entry.
|
||||
// It computes a hash of excluded models and sets the auth_kind attribute.
|
||||
func ApplyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) {
|
||||
if auth == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
authKindKey := strings.ToLower(strings.TrimSpace(authKind))
|
||||
seen := make(map[string]struct{})
|
||||
add := func(list []string) {
|
||||
for _, entry := range list {
|
||||
if trimmed := strings.TrimSpace(entry); trimmed != "" {
|
||||
key := strings.ToLower(trimmed)
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
if authKindKey == "apikey" {
|
||||
add(perKey)
|
||||
} else if cfg.OAuthExcludedModels != nil {
|
||||
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
add(cfg.OAuthExcludedModels[providerKey])
|
||||
}
|
||||
combined := make([]string, 0, len(seen))
|
||||
for k := range seen {
|
||||
combined = append(combined, k)
|
||||
}
|
||||
sort.Strings(combined)
|
||||
hash := diff.ComputeExcludedModelsHash(combined)
|
||||
if auth.Attributes == nil {
|
||||
auth.Attributes = make(map[string]string)
|
||||
}
|
||||
if hash != "" {
|
||||
auth.Attributes["excluded_models_hash"] = hash
|
||||
}
|
||||
if authKind != "" {
|
||||
auth.Attributes["auth_kind"] = authKind
|
||||
}
|
||||
}
|
||||
|
||||
// addConfigHeadersToAttrs adds header configuration to auth attributes.
|
||||
// Headers are prefixed with "header:" in the attributes map.
|
||||
func addConfigHeadersToAttrs(headers map[string]string, attrs map[string]string) {
|
||||
if len(headers) == 0 || attrs == nil {
|
||||
return
|
||||
}
|
||||
for hk, hv := range headers {
|
||||
key := strings.TrimSpace(hk)
|
||||
val := strings.TrimSpace(hv)
|
||||
if key == "" || val == "" {
|
||||
continue
|
||||
}
|
||||
attrs["header:"+key] = val
|
||||
}
|
||||
}
|
||||
264
internal/watcher/synthesizer/helpers_test.go
Normal file
264
internal/watcher/synthesizer/helpers_test.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestNewStableIDGenerator(t *testing.T) {
|
||||
gen := NewStableIDGenerator()
|
||||
if gen == nil {
|
||||
t.Fatal("expected non-nil generator")
|
||||
}
|
||||
if gen.counters == nil {
|
||||
t.Fatal("expected non-nil counters map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStableIDGenerator_Next(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
kind string
|
||||
parts []string
|
||||
wantPrefix string
|
||||
}{
|
||||
{
|
||||
name: "basic gemini apikey",
|
||||
kind: "gemini:apikey",
|
||||
parts: []string{"test-key", ""},
|
||||
wantPrefix: "gemini:apikey:",
|
||||
},
|
||||
{
|
||||
name: "claude with base url",
|
||||
kind: "claude:apikey",
|
||||
parts: []string{"sk-ant-xxx", "https://api.anthropic.com"},
|
||||
wantPrefix: "claude:apikey:",
|
||||
},
|
||||
{
|
||||
name: "empty parts",
|
||||
kind: "codex:apikey",
|
||||
parts: []string{},
|
||||
wantPrefix: "codex:apikey:",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gen := NewStableIDGenerator()
|
||||
id, short := gen.Next(tt.kind, tt.parts...)
|
||||
|
||||
if !strings.Contains(id, tt.wantPrefix) {
|
||||
t.Errorf("expected id to contain %q, got %q", tt.wantPrefix, id)
|
||||
}
|
||||
if short == "" {
|
||||
t.Error("expected non-empty short id")
|
||||
}
|
||||
if len(short) != 12 {
|
||||
t.Errorf("expected short id length 12, got %d", len(short))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStableIDGenerator_Stability(t *testing.T) {
|
||||
gen1 := NewStableIDGenerator()
|
||||
gen2 := NewStableIDGenerator()
|
||||
|
||||
id1, _ := gen1.Next("gemini:apikey", "test-key", "https://api.example.com")
|
||||
id2, _ := gen2.Next("gemini:apikey", "test-key", "https://api.example.com")
|
||||
|
||||
if id1 != id2 {
|
||||
t.Errorf("same inputs should produce same ID: got %q and %q", id1, id2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStableIDGenerator_CollisionHandling(t *testing.T) {
|
||||
gen := NewStableIDGenerator()
|
||||
|
||||
id1, short1 := gen.Next("gemini:apikey", "same-key")
|
||||
id2, short2 := gen.Next("gemini:apikey", "same-key")
|
||||
|
||||
if id1 == id2 {
|
||||
t.Error("collision should be handled with suffix")
|
||||
}
|
||||
if short1 == short2 {
|
||||
t.Error("short ids should differ")
|
||||
}
|
||||
if !strings.Contains(short2, "-1") {
|
||||
t.Errorf("second short id should contain -1 suffix, got %q", short2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStableIDGenerator_NilReceiver(t *testing.T) {
|
||||
var gen *StableIDGenerator = nil
|
||||
id, short := gen.Next("test:kind", "part")
|
||||
|
||||
if id != "test:kind:000000000000" {
|
||||
t.Errorf("expected test:kind:000000000000, got %q", id)
|
||||
}
|
||||
if short != "000000000000" {
|
||||
t.Errorf("expected 000000000000, got %q", short)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAuthExcludedModelsMeta(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
auth *coreauth.Auth
|
||||
cfg *config.Config
|
||||
perKey []string
|
||||
authKind string
|
||||
wantHash bool
|
||||
wantKind string
|
||||
}{
|
||||
{
|
||||
name: "apikey with excluded models",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "gemini",
|
||||
Attributes: make(map[string]string),
|
||||
},
|
||||
cfg: &config.Config{},
|
||||
perKey: []string{"model-a", "model-b"},
|
||||
authKind: "apikey",
|
||||
wantHash: true,
|
||||
wantKind: "apikey",
|
||||
},
|
||||
{
|
||||
name: "oauth with provider excluded models",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "claude",
|
||||
Attributes: make(map[string]string),
|
||||
},
|
||||
cfg: &config.Config{
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"claude": {"claude-2.0"},
|
||||
},
|
||||
},
|
||||
perKey: nil,
|
||||
authKind: "oauth",
|
||||
wantHash: true,
|
||||
wantKind: "oauth",
|
||||
},
|
||||
{
|
||||
name: "nil auth",
|
||||
auth: nil,
|
||||
cfg: &config.Config{},
|
||||
},
|
||||
{
|
||||
name: "nil config",
|
||||
auth: &coreauth.Auth{Provider: "test"},
|
||||
cfg: nil,
|
||||
authKind: "apikey",
|
||||
},
|
||||
{
|
||||
name: "nil attributes initialized",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "gemini",
|
||||
Attributes: nil,
|
||||
},
|
||||
cfg: &config.Config{},
|
||||
perKey: []string{"model-x"},
|
||||
authKind: "apikey",
|
||||
wantHash: true,
|
||||
wantKind: "apikey",
|
||||
},
|
||||
{
|
||||
name: "apikey with duplicate excluded models",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "gemini",
|
||||
Attributes: make(map[string]string),
|
||||
},
|
||||
cfg: &config.Config{},
|
||||
perKey: []string{"model-a", "MODEL-A", "model-b", "model-a"},
|
||||
authKind: "apikey",
|
||||
wantHash: true,
|
||||
wantKind: "apikey",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ApplyAuthExcludedModelsMeta(tt.auth, tt.cfg, tt.perKey, tt.authKind)
|
||||
|
||||
if tt.auth != nil && tt.cfg != nil {
|
||||
if tt.wantHash {
|
||||
if _, ok := tt.auth.Attributes["excluded_models_hash"]; !ok {
|
||||
t.Error("expected excluded_models_hash in attributes")
|
||||
}
|
||||
}
|
||||
if tt.wantKind != "" {
|
||||
if got := tt.auth.Attributes["auth_kind"]; got != tt.wantKind {
|
||||
t.Errorf("expected auth_kind=%s, got %s", tt.wantKind, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddConfigHeadersToAttrs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
attrs map[string]string
|
||||
want map[string]string
|
||||
}{
|
||||
{
|
||||
name: "basic headers",
|
||||
headers: map[string]string{
|
||||
"Authorization": "Bearer token",
|
||||
"X-Custom": "value",
|
||||
},
|
||||
attrs: map[string]string{"existing": "key"},
|
||||
want: map[string]string{
|
||||
"existing": "key",
|
||||
"header:Authorization": "Bearer token",
|
||||
"header:X-Custom": "value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty headers",
|
||||
headers: map[string]string{},
|
||||
attrs: map[string]string{"existing": "key"},
|
||||
want: map[string]string{"existing": "key"},
|
||||
},
|
||||
{
|
||||
name: "nil headers",
|
||||
headers: nil,
|
||||
attrs: map[string]string{"existing": "key"},
|
||||
want: map[string]string{"existing": "key"},
|
||||
},
|
||||
{
|
||||
name: "nil attrs",
|
||||
headers: map[string]string{"key": "value"},
|
||||
attrs: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "skip empty keys and values",
|
||||
headers: map[string]string{
|
||||
"": "value",
|
||||
"key": "",
|
||||
" ": "value",
|
||||
"valid": "valid-value",
|
||||
},
|
||||
attrs: make(map[string]string),
|
||||
want: map[string]string{
|
||||
"header:valid": "valid-value",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addConfigHeadersToAttrs(tt.headers, tt.attrs)
|
||||
if !reflect.DeepEqual(tt.attrs, tt.want) {
|
||||
t.Errorf("expected %v, got %v", tt.want, tt.attrs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
16
internal/watcher/synthesizer/interface.go
Normal file
16
internal/watcher/synthesizer/interface.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Package synthesizer provides auth synthesis strategies for the watcher package.
|
||||
// It implements the Strategy pattern to support multiple auth sources:
|
||||
// - ConfigSynthesizer: generates Auth entries from config API keys
|
||||
// - FileSynthesizer: generates Auth entries from OAuth JSON files
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// AuthSynthesizer defines the interface for generating Auth entries from various sources.
|
||||
type AuthSynthesizer interface {
|
||||
// Synthesize generates Auth entries from the given context.
|
||||
// Returns a slice of Auth pointers and any error encountered.
|
||||
Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, error)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
1490
internal/watcher/watcher_test.go
Normal file
1490
internal/watcher/watcher_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -84,7 +84,8 @@ func (h *GeminiAPIHandler) GeminiGetHandler(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
switch request.Action {
|
||||
action := strings.TrimPrefix(request.Action, "/")
|
||||
switch action {
|
||||
case "gemini-3-pro-preview":
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"name": "models/gemini-3-pro-preview",
|
||||
@@ -189,7 +190,7 @@ func (h *GeminiAPIHandler) GeminiHandler(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
action := strings.Split(request.Action, ":")
|
||||
action := strings.Split(strings.TrimPrefix(request.Action, "/"), ":")
|
||||
if len(action) != 2 {
|
||||
c.JSON(http.StatusNotFound, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
|
||||
@@ -5,6 +5,7 @@ package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -48,9 +49,6 @@ type BaseAPIHandler struct {
|
||||
|
||||
// Cfg holds the current application configuration.
|
||||
Cfg *config.SDKConfig
|
||||
|
||||
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
|
||||
OpenAICompatProviders []string
|
||||
}
|
||||
|
||||
// NewBaseAPIHandlers creates a new API handlers instance.
|
||||
@@ -62,11 +60,10 @@ type BaseAPIHandler struct {
|
||||
//
|
||||
// Returns:
|
||||
// - *BaseAPIHandler: A new API handlers instance
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
|
||||
return &BaseAPIHandler{
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
OpenAICompatProviders: openAICompatProviders,
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,6 +114,16 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
newCtx = context.WithValue(newCtx, "handler", handler)
|
||||
return newCtx, func(params ...interface{}) {
|
||||
if h.Cfg.RequestLog && len(params) == 1 {
|
||||
if existing, exists := c.Get("API_RESPONSE"); exists {
|
||||
if existingBytes, ok := existing.([]byte); ok && len(bytes.TrimSpace(existingBytes)) > 0 {
|
||||
switch params[0].(type) {
|
||||
case error, string:
|
||||
cancel()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var payload []byte
|
||||
switch data := params[0].(type) {
|
||||
case []byte:
|
||||
@@ -331,30 +338,19 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
|
||||
// Resolve "auto" model to an actual available model first
|
||||
resolvedModelName := util.ResolveAutoModel(modelName)
|
||||
|
||||
providerName, extractedModelName, isDynamic := h.parseDynamicModel(resolvedModelName)
|
||||
|
||||
targetModelName := resolvedModelName
|
||||
if isDynamic {
|
||||
targetModelName = extractedModelName
|
||||
}
|
||||
|
||||
// Normalize the model name to handle dynamic thinking suffixes before determining the provider.
|
||||
normalizedModel, metadata = normalizeModelMetadata(targetModelName)
|
||||
normalizedModel, metadata = normalizeModelMetadata(resolvedModelName)
|
||||
|
||||
if isDynamic {
|
||||
providers = []string{providerName}
|
||||
} else {
|
||||
// For non-dynamic models, use the normalizedModel to get the provider name.
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 && metadata != nil {
|
||||
if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok {
|
||||
if originalModel, okStr := originalRaw.(string); okStr {
|
||||
originalModel = strings.TrimSpace(originalModel)
|
||||
if originalModel != "" && !strings.EqualFold(originalModel, normalizedModel) {
|
||||
if altProviders := util.GetProviderName(originalModel); len(altProviders) > 0 {
|
||||
providers = altProviders
|
||||
normalizedModel = originalModel
|
||||
}
|
||||
// Use the normalizedModel to get the provider name.
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 && metadata != nil {
|
||||
if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok {
|
||||
if originalModel, okStr := originalRaw.(string); okStr {
|
||||
originalModel = strings.TrimSpace(originalModel)
|
||||
if originalModel != "" && !strings.EqualFold(originalModel, normalizedModel) {
|
||||
if altProviders := util.GetProviderName(originalModel); len(altProviders) > 0 {
|
||||
providers = altProviders
|
||||
normalizedModel = originalModel
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -372,30 +368,6 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
|
||||
return providers, normalizedModel, metadata, nil
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
|
||||
var providerPart, modelPart string
|
||||
for _, sep := range []string{"://"} {
|
||||
if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
|
||||
providerPart = parts[0]
|
||||
modelPart = parts[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if providerPart == "" {
|
||||
return "", modelName, false
|
||||
}
|
||||
|
||||
// Check if the provider is a configured openai-compatibility provider
|
||||
for _, pName := range h.OpenAICompatProviders {
|
||||
if pName == providerPart {
|
||||
return providerPart, modelPart, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", modelName, false
|
||||
}
|
||||
|
||||
func cloneBytes(src []byte) []byte {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
@@ -437,12 +409,53 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
|
||||
}
|
||||
}
|
||||
}
|
||||
c.Status(status)
|
||||
|
||||
errText := http.StatusText(status)
|
||||
if msg != nil && msg.Error != nil {
|
||||
_, _ = c.Writer.Write([]byte(msg.Error.Error()))
|
||||
} else {
|
||||
_, _ = c.Writer.Write([]byte(http.StatusText(status)))
|
||||
if v := strings.TrimSpace(msg.Error.Error()); v != "" {
|
||||
errText = v
|
||||
}
|
||||
}
|
||||
|
||||
// Prefer preserving upstream JSON error bodies when possible.
|
||||
buildJSONBody := func() []byte {
|
||||
trimmed := strings.TrimSpace(errText)
|
||||
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||
return []byte(trimmed)
|
||||
}
|
||||
errType := "invalid_request_error"
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
errType = "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
errType = "permission_error"
|
||||
case http.StatusTooManyRequests:
|
||||
errType = "rate_limit_error"
|
||||
default:
|
||||
if status >= http.StatusInternalServerError {
|
||||
errType = "server_error"
|
||||
}
|
||||
}
|
||||
payload, err := json.Marshal(ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: errText,
|
||||
Type: errType,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error"}}`, errText))
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
body := buildJSONBody()
|
||||
c.Set("API_RESPONSE", bytes.Clone(body))
|
||||
|
||||
if !c.Writer.Written() {
|
||||
c.Writer.Header().Set("Content-Type", "application/json")
|
||||
}
|
||||
c.Status(status)
|
||||
_, _ = c.Writer.Write(body)
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) LoggingAPIResponseError(ctx context.Context, err *interfaces.ErrorMessage) {
|
||||
|
||||
46
sdk/api/options.go
Normal file
46
sdk/api/options.go
Normal file
@@ -0,0 +1,46 @@
|
||||
// Package api exposes server option helpers for embedding CLIProxyAPI.
|
||||
//
|
||||
// It wraps internal server option types so external projects can configure the embedded
|
||||
// HTTP server without importing internal packages.
|
||||
package api
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internalapi "github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/logging"
|
||||
)
|
||||
|
||||
// ServerOption customises HTTP server construction.
|
||||
type ServerOption = internalapi.ServerOption
|
||||
|
||||
// WithMiddleware appends additional Gin middleware during server construction.
|
||||
func WithMiddleware(mw ...gin.HandlerFunc) ServerOption { return internalapi.WithMiddleware(mw...) }
|
||||
|
||||
// WithEngineConfigurator allows callers to mutate the Gin engine prior to middleware setup.
|
||||
func WithEngineConfigurator(fn func(*gin.Engine)) ServerOption {
|
||||
return internalapi.WithEngineConfigurator(fn)
|
||||
}
|
||||
|
||||
// WithRouterConfigurator appends a callback after default routes are registered.
|
||||
func WithRouterConfigurator(fn func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)) ServerOption {
|
||||
return internalapi.WithRouterConfigurator(fn)
|
||||
}
|
||||
|
||||
// WithLocalManagementPassword stores a runtime-only management password accepted for localhost requests.
|
||||
func WithLocalManagementPassword(password string) ServerOption {
|
||||
return internalapi.WithLocalManagementPassword(password)
|
||||
}
|
||||
|
||||
// WithKeepAliveEndpoint enables a keep-alive endpoint with the provided timeout and callback.
|
||||
func WithKeepAliveEndpoint(timeout time.Duration, onTimeout func()) ServerOption {
|
||||
return internalapi.WithKeepAliveEndpoint(timeout, onTimeout)
|
||||
}
|
||||
|
||||
// WithRequestLoggerFactory customises request logger creation.
|
||||
func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption {
|
||||
return internalapi.WithRequestLoggerFactory(factory)
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user