Compare commits

...

29 Commits

Author SHA1 Message Date
Luis Pater
d2394b0be9 Update .goreleaser.yml to specify archive formats for different OS targets
- Added `tar.gz` as the default archive format.
- Introduced `zip` format override for Windows builds.
2025-08-08 14:28:02 +08:00
Luis Pater
ebcd4dbf3d Fix activation URL extraction logic and improve warning message formatting
- Corrected JSON path for error code and activation URL extraction in client error handling.
- Improved readability of the activation warning message with better spacing.
2025-08-05 23:58:43 +08:00
Luis Pater
1483c31c73 Refactor API handlers organization and simplify error response handling
- Modularized handlers into dedicated packages (`gemini`, `claude`, `cli`) for better structure.
- Centralized `ErrorResponse` and `ErrorDetail` types under `handlers` package for reuse.
- Updated all handlers to utilize the shared `ErrorResponse` model.
- Introduced specialization of handler structs (`GeminiAPIHandlers`, `ClaudeCodeAPIHandlers`, `GeminiCLIAPIHandlers`) for improved clarity and separation of concerns.
- Refactored `getClient` logic with additional properties and better state management.

Refactor `translator` package by modularizing code for `claude` and `gemini`

- Moved Claude-specific logic (`PrepareClaudeRequest`, `ConvertCliToClaude`) to `translator/claude/code`.
- Moved Gemini-specific logic (`FixCLIToolResponse`) to `translator/gemini/cli` for better package structure.
- Updated affected handler imports and method references.

Add comprehensive package-level documentation across key modules

- Introduced detailed package-level documentation for core modules: `auth`, `client`, `cmd`, `handlers`, `util`, `watcher`, `config`, `translator`, and `api`.
- Enhanced code readability and maintainability by clarifying the purpose and functionality of each package.
- Aligned documentation style and tone with existing codebase conventions.

Refactor API handlers and translator modules for improved clarity and consistency

- Standardized handler struct names (`GeminiAPIHandlers`, `ClaudeCodeAPIHandlers`, `GeminiCLIAPIHandlers`, `OpenAIAPIHandlers`) and updated related comments.
- Fixed unnecessary `else` blocks in streaming logic for cleaner error handling.
- Renamed variables for better readability (`responseIdResult` to `responseIDResult`, `activationUrl` to `activationURL`, etc.).
- Addressed minor inconsistencies in API handler comments and SSE header initialization.
- Improved modularization of `claude` and `gemini` translator components.

Standardize configuration field naming for consistency across modules

- Renamed `ProxyUrl` to `ProxyURL`, `ApiKeys` to `APIKeys`, and `ConfigQuotaExceeded` to `QuotaExceeded`.
- Updated all relevant references and comments in `config`, `auth`, `api`, `util`, and `watcher`.
- Ensured consistent casing for `GlAPIKey` debug logs.
2025-08-05 23:11:31 +08:00
Luis Pater
00f33f5f3a Enhance Gemini request handling for contents support and improve error logging
- Added conditional logic to process `contents` in Gemini request templates, ensuring fallback behavior.
- Introduced detailed debug logs for quota errors and request issues.
- Updated handling of `rawJson` to construct templates more dynamically.
2025-08-04 02:51:00 +08:00
Luis Pater
3c4dc07980 Add file watcher for dynamic configuration and client reloading
- Introduced `Watcher` for monitoring updates to the configuration file and authentication directory.
- Integrated file watching into `StartService` to handle dynamic changes without restarting.
- Enhanced API server and handlers to support client and configuration updates.
- Updated `.gitignore` to include `docs/` directory.
- Modified go dependencies to include `fsnotify` for the file watcher.
2025-08-02 16:15:56 +08:00
Luis Pater
3b4634e2dc Improve getClient logic with optional content generation flag
- Added `isGenerateContent` optional parameter to `getClient` for conditional client selection.
- Updated `gemini-handlers` to utilize the new parameter for enhanced control.
2025-07-27 02:30:08 +08:00
Luis Pater
00bd6a3e46 Update .goreleaser.yml to include config.example.yaml instead of config.yaml in release assets 2025-07-26 22:19:33 +08:00
Luis Pater
5812229d9b Add .gitignore and ignore config.yaml 2025-07-26 22:10:07 +08:00
Luis Pater
0b026933a7 Update example configuration file (config.example.yaml) 2025-07-26 22:08:25 +08:00
Luis Pater
3b2ab0d7bd Fix SSE headers initialization for geminiStreamGenerateContent and internalStreamGenerateContent
- Added conditional logic to properly initialize SSE headers only when `alt` is empty.
- Ensured headers like `Content-Type`, `Cache-Control`, and `Access-Control-Allow-Origin` are set for better compatibility.
2025-07-26 17:16:55 +08:00
Luis Pater
e64fa48823 Enhance Gemini request handling with fallback support for contents
- Added conditional logic to support `contents` as a fallback to `generateContentRequest`.
- Improved template construction and ensured proper cleanup of request fields.
- Introduced debug logging for troubleshooting request generation.
2025-07-26 17:04:14 +08:00
Luis Pater
beff9282f6 Fix alt parameter handling in URL construction
- Ensured `alt` parameter is only appended when non-empty.
- Added debug logging for constructed URLs.
2025-07-26 15:51:04 +08:00
Luis Pater
31a9e2d11f Add GeminiGetHandler, enhance Gemini functionality, and enable token counting
- Added `GeminiGetHandler` for handling GET requests with extended Gemini model support.
- Introduced `geminiCountTokens` function to calculate token usage.
- Refactored `APIRequest` and related methods to support `alt` parameter for enhanced flexibility.
- Updated routes and request processing to integrate new handler and functions.
2025-07-26 06:51:49 +08:00
Luis Pater
423faae3da Add GeminiModels handler and enhance API key validation
- Introduced `GeminiModels` handler to serve Gemini model information under `/v1beta/models`.
- Updated `AuthMiddleware` to validate API keys from query parameters for improved flexibility.
- Adjusted route to use the new handler for model retrieval.
2025-07-26 04:41:55 +08:00
Luis Pater
ead71fb7ef Improve error logging and add user guidance for issue reporting
- Added fatal log in `login.go` for Cloud AI API enablement check failures, prompting users to report issues.
- Enhanced error logging in `client.go` with warning messages directing users to copy and provide error details when creating issues.
2025-07-24 04:51:09 +08:00
Luis Pater
58b7afdf1e Enhance HTTP server with custom multiplexer in Auth flow
- Replaced default `http` handler with `http.ServeMux` for improved routing control.
- Refactored callback handling to utilize the custom multiplexer.
2025-07-23 05:09:05 +08:00
Luis Pater
c86545d7e1 Add Chinese README and update project files
- Introduced `README_CN.md` to provide detailed documentation in Chinese.
- Updated `.goreleaser.yml` to include the new README file in release assets.
- Enhanced `README.md` with a language toggle link for improved accessibility.
2025-07-21 11:23:13 +08:00
Luis Pater
f49a530c1a Refactor client handling and improve error responses
- Centralized client retrieval logic with `getClient` function for reduced redundancy.
- Simplified client rotation and error handling by removing excessive load balancing logic.
- Updated server address in `auth.go` to use dynamic binding (`:8085`).
2025-07-15 17:03:18 +08:00
Luis Pater
368796349e Add Docker support with CI/CD workflow and usage instructions
- Added `.github/workflows/docker-image.yml` for automated Docker image build and push on version tags.
- Created `Dockerfile` to containerize the application.
- Updated README with instructions for running the application using Docker.
2025-07-14 16:50:51 +08:00
Luis Pater
c601542f6f Add ClaudeMessages handler for SSE-compatible chat completions
- Introduced `ClaudeMessages` to handle Claude-compatible streaming chat completions.
- Implemented client rotation, quota management, and dynamic model name mapping for better load balancing and resource utilization.
- Enhanced response streaming with real-time chunking and Claude format conversion.
- Added error handling for quota exhaustion, client disconnections, and backend failures.
2025-07-11 13:53:09 +08:00
Luis Pater
3c0c61aaf1 Add Claude compatibility and enhance API handling
- Integrated Claude API compatibility in handlers, translators, and server routes.
- Introduced `/messages` endpoint and upgraded `AuthMiddleware` for `X-Api-Key` header.
- Improved streaming response handling with `ConvertCliToClaude` for SSE compatibility.
- Enhanced request processing and tool-response mapping in translators.
- Updated README to reflect Claude integration and clarify supported features.
2025-07-11 13:46:27 +08:00
Luis Pater
edeadfc389 Restrict CLI access to localhost and update README for Gemini compatibility
- Added localhost-only access restriction to `CLIHandler` for security.
- Updated README to reflect Gemini-compatible API and local access limitation notes.
2025-07-11 10:57:23 +08:00
Luis Pater
aa9fd057fe Add FixCLIToolResponse for enhanced function call-response mapping
- Introduced `FixCLIToolResponse` in `translator` to group function calls with corresponding responses.
- Updated Gemini handlers to integrate new function for improved response handling.
- Enhanced error handling in case response mapping fails.
2025-07-11 10:17:25 +08:00
Luis Pater
b3607d3981 Add Gemini-compatible API and improve error handling
- Introduced a new Gemini-compatible API with routes under `/v1beta`.
- Added `GeminiHandler` to manage `generateContent` and `streamGenerateContent` actions.
- Enhanced `AuthMiddleware` to support `X-Goog-Api-Key` header.
- Improved client metadata handling and added conditional project ID updates in API calls.
- Updated logging to debug raw API request payloads for better traceability.
2025-07-11 04:01:45 +08:00
Luis Pater
fa8d94971f Enhance response and request handling in translators
- Refactored response handling to process multiple content parts effectively.
- Improved `tool_calls` structure with unique ID generation and enhanced mapping logic.
- Simplified `SystemInstruction` and tool message parsing in requests for better accuracy.
- Enhanced handling of function calls and tool responses with improved data integration.
2025-07-10 22:26:04 +08:00
Luis Pater
ef68a97526 Refactor API handlers and proxy logic
- Centralized `getClient` logic into a dedicated function to reduce redundancy.
- Moved proxy initialization to a new utility function `SetProxy` in `internal/util/proxy.go`.
- Replaced `Internal` handler with `CLIHandler` in `server.go` for improved clarity and consistency.
- Removed unused functions and redundant HTTP client setup across the codebase for better maintainability.
2025-07-10 17:45:28 +08:00
Luis Pater
d880d1a1ea Set the http request header and update client metadata handling 2025-07-10 14:02:10 +08:00
Luis Pater
d4104214ed Updated README.md 2025-07-10 05:31:55 +08:00
Luis Pater
273e1d9cbe Add system instruction support and enhance internal API handlers
- Introduced `SystemInstruction` field in `PrepareRequest` and `GenerateContentRequest` for better message parsing.
- Updated `SendMessage` and `SendMessageStream` to handle system instructions in client API calls.
- Enhanced error handling and manual flushing logic in response flows.
- Added new internal API endpoints `/v1internal:generateContent` and `/v1internal:streamGenerateContent`.
- Improved proxy handling and transport logic in HTTP client initialization.
2025-07-10 05:16:54 +08:00
32 changed files with 3247 additions and 684 deletions

42
.github/workflows/docker-image.yml vendored Normal file
View File

@@ -0,0 +1,42 @@
name: docker-image
on:
push:
tags:
- v*
env:
APP_NAME: CLIProxyAPI
DOCKERHUB_REPO: eceasy/cli-proxy-api
jobs:
docker:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Generate App Version
run: echo APP_VERSION=`git describe --tags --always` >> $GITHUB_ENV
- name: Build and push
uses: docker/build-push-action@v6
with:
context: .
platforms: |
linux/amd64
linux/arm64
push: true
build-args: |
APP_NAME=${{ env.APP_NAME }}
APP_VERSION=${{ env.APP_VERSION }}
tags: |
${{ env.DOCKERHUB_REPO }}:latest
${{ env.DOCKERHUB_REPO }}:${{ env.APP_VERSION }}

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
config.yaml
docs/

View File

@@ -11,7 +11,12 @@ builds:
binary: cli-proxy-api
archives:
- id: "cli-proxy-api"
format: tar.gz
format_overrides:
- goos: windows
format: zip
files:
- LICENSE
- README.md
- config.yaml
- README_CN.md
- config.example.yaml

23
Dockerfile Normal file
View File

@@ -0,0 +1,23 @@
FROM golang:1.24-alpine AS builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -o ./CLIProxyAPI ./cmd/server/
FROM alpine:3.22.0
RUN mkdir /CLIProxyAPI
COPY --from=builder ./app/CLIProxyAPI /CLIProxyAPI/CLIProxyAPI
WORKDIR /CLIProxyAPI
EXPOSE 8317
CMD ["./CLIProxyAPI"]

View File

@@ -1,16 +1,19 @@
# CLI Proxy API
A proxy server that provides an OpenAI-compatible API interface for CLI. This allows you to use CLI models with tools and libraries designed for the OpenAI API.
English | [中文](README_CN.md)
A proxy server that provides an OpenAI/Gemini/Claude compatible API interface for CLI. This allows you to use CLI models with tools and libraries designed for the OpenAI/Gemini/Claude API.
## Features
- OpenAI-compatible API endpoints for CLI models
- OpenAI/Gemini/Claude compatible API endpoints for CLI models
- Support for both streaming and non-streaming responses
- Function calling/tools support
- Multimodal input support (text and images)
- Multiple account support with load balancing
- Simple CLI authentication flow
- Support for Generative Language API Key
- Support Gemini CLI with multiple account load balancing
## Installation
@@ -135,7 +138,7 @@ console.log(response.choices[0].message.content);
- gemini-2.5-pro
- gemini-2.5-flash
- And various preview versions
- And it automates switching to various preview versions
## Configuration
@@ -147,14 +150,17 @@ The server uses a YAML configuration file (`config.yaml`) located in the project
### Configuration Options
| Parameter | Type | Default | Description |
|-------------------------------|----------|--------------------|----------------------------------------------------------------------------------------------|
| `port` | integer | 8317 | The port number on which the server will listen |
| `auth-dir` | string | "~/.cli-proxy-api" | Directory where authentication tokens are stored. Supports using `~` for home directory |
| `proxy-url` | string | "" | Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/ |
| `debug` | boolean | false | Enable debug mode for verbose logging |
| `api-keys` | string[] | [] | List of API keys that can be used to authenticate requests |
| `generative-language-api-key` | string[] | [] | List of Generative Language API keys |
| Parameter | Type | Default | Description |
|---------------------------------------|----------|--------------------|----------------------------------------------------------------------------------------------|
| `port` | integer | 8317 | The port number on which the server will listen |
| `auth-dir` | string | "~/.cli-proxy-api" | Directory where authentication tokens are stored. Supports using `~` for home directory |
| `proxy-url` | string | "" | Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/ |
| `quota-exceeded` | object | {} | Configuration for handling quota exceeded |
| `quota-exceeded.switch-project` | boolean | true | Whether to automatically switch to another project when a quota is exceeded |
| `quota-exceeded.switch-preview-model` | boolean | true | Whether to automatically switch to a preview model when a quota is exceeded |
| `debug` | boolean | false | Enable debug mode for verbose logging |
| `api-keys` | string[] | [] | List of API keys that can be used to authenticate requests |
| `generative-language-api-key` | string[] | [] | List of Generative Language API keys |
### Example Configuration File
@@ -168,10 +174,25 @@ auth-dir: "~/.cli-proxy-api"
# Enable debug logging
debug: false
# Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/
proxy-url: ""
# Quota exceeded behavior
quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
# API keys for authentication
api-keys:
- "your-api-key-1"
- "your-api-key-2"
# API keys for official Generative Language API
generative-language-api-key:
- "AIzaSy...01"
- "AIzaSy...02"
- "AIzaSy...03"
- "AIzaSy...04"
```
### Authentication Directory
@@ -186,6 +207,38 @@ The `api-keys` parameter allows you to define a list of API keys that can be use
Authorization: Bearer your-api-key-1
```
### Official Generative Language API
The `generative-language-api-key` parameter allows you to define a list of API keys that can be used to authenticate requests to the official Generative Language API.
## Gemini CLI with multiple account load balancing
Start CLI Proxy API server, and then set the `CODE_ASSIST_ENDPOINT` environment variable to the URL of the CLI Proxy API server.
```bash
export CODE_ASSIST_ENDPOINT="http://127.0.0.1:8317"
```
The server will relay the `loadCodeAssist`, `onboardUser`, and `countTokens` requests. And automatically load balance the text generation requests between the multiple accounts.
> [!NOTE]
> This feature only allows local access because I couldn't find a way to authenticate the requests.
> I hardcoded `127.0.0.1` into the load balancing.
## Run with Docker
Run the following command to login:
```bash
docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login
```
Run the following command to start the server:
```bash
docker run --rm -p 8317:8317 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest
```
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.

254
README_CN.md Normal file
View File

@@ -0,0 +1,254 @@
# CLI 代理 API
[English](README.md) | 中文
一个为 CLI 提供 OpenAI/Gemini/Claude 兼容 API 接口的代理服务器。这让您可以摆脱终端界面的束缚,将 Gemini 的强大能力以 API 的形式轻松接入到任何您喜爱的客户端或应用中。
## 功能特性
- 为 CLI 模型提供 OpenAI/Gemini/Claude 兼容的 API 端点
- 支持流式和非流式响应
- 函数调用/工具支持
- 多模态输入支持(文本和图像)
- 多账户支持与负载均衡
- 简单的 CLI 身份验证流程
- 支持 Gemini AIStudio API 密钥
- 支持 Gemini CLI 多账户轮询
## 安装
### 前置要求
- Go 1.24 或更高版本
- 有权访问 CLI 模型的 Google 账户
### 从源码构建
1. 克隆仓库:
```bash
git clone https://github.com/luispater/CLIProxyAPI.git
cd CLIProxyAPI
```
2. 构建应用程序:
```bash
go build -o cli-proxy-api ./cmd/server
```
## 使用方法
### 身份验证
在使用 API 之前,您需要使用 Google 账户进行身份验证:
```bash
./cli-proxy-api --login
```
如果您是旧版 gemini code 用户,可能需要指定项目 ID
```bash
./cli-proxy-api --login --project_id <your_project_id>
```
### 启动服务器
身份验证完成后,启动服务器:
```bash
./cli-proxy-api
```
默认情况下,服务器在端口 8317 上运行。
### API 端点
#### 列出模型
```
GET http://localhost:8317/v1/models
```
#### 聊天补全
```
POST http://localhost:8317/v1/chat/completions
```
请求体示例:
```json
{
"model": "gemini-2.5-pro",
"messages": [
{
"role": "user",
"content": "你好,你好吗?"
}
],
"stream": true
}
```
### 与 OpenAI 库一起使用
您可以通过将基础 URL 设置为本地服务器来将此代理与任何 OpenAI 兼容的库一起使用:
#### Python使用 OpenAI 库)
```python
from openai import OpenAI
client = OpenAI(
api_key="dummy", # 不使用但必需
base_url="http://localhost:8317/v1"
)
response = client.chat.completions.create(
model="gemini-2.5-pro",
messages=[
{"role": "user", "content": "你好,你好吗?"}
]
)
print(response.choices[0].message.content)
```
#### JavaScript/TypeScript
```javascript
import OpenAI from 'openai';
const openai = new OpenAI({
apiKey: 'dummy', // 不使用但必需
baseURL: 'http://localhost:8317/v1',
});
const response = await openai.chat.completions.create({
model: 'gemini-2.5-pro',
messages: [
{ role: 'user', content: '你好,你好吗?' }
],
});
console.log(response.choices[0].message.content);
```
## 支持的模型
- gemini-2.5-pro
- gemini-2.5-flash
- 并且自动切换到之前的预览版本
## 配置
服务器默认使用位于项目根目录的 YAML 配置文件(`config.yaml`)。您可以使用 `--config` 标志指定不同的配置文件路径:
```bash
./cli-proxy --config /path/to/your/config.yaml
```
### 配置选项
| 参数 | 类型 | 默认值 | 描述 |
|---------------------------------------|----------|--------------------|------------------------------------------------------------------------|
| `port` | integer | 8317 | 服务器监听的端口号 |
| `auth-dir` | string | "~/.cli-proxy-api" | 存储身份验证令牌的目录。支持使用 `~` 表示主目录 |
| `proxy-url` | string | "" | 代理 URL支持 socks5/http/https 协议示例socks5://user:pass@192.168.1.1:1080/ |
| `quota-exceeded` | object | {} | 处理配额超限的配置 |
| `quota-exceeded.switch-project` | boolean | true | 当配额超限时是否自动切换到另一个项目 |
| `quota-exceeded.switch-preview-model` | boolean | true | 当配额超限时是否自动切换到预览模型 |
| `debug` | boolean | false | 启用调试模式以进行详细日志记录 |
| `api-keys` | string[] | [] | 可用于验证请求的 API 密钥列表 |
| `generative-language-api-key` | string[] | [] | 生成式语言 API 密钥列表 |
### 配置文件示例
```yaml
# 服务器端口
port: 8317
# 身份验证目录(支持 ~ 表示主目录)
auth-dir: "~/.cli-proxy-api"
# 启用调试日志
debug: false
# 代理 URL支持 socks5/http/https 协议示例socks5://user:pass@192.168.1.1:1080/
proxy-url: ""
# 配额超限行为
quota-exceeded:
switch-project: true # 当配额超限时是否自动切换到另一个项目
switch-preview-model: true # 当配额超限时是否自动切换到预览模型
# 用于本地身份验证的 API 密钥
api-keys:
- "your-api-key-1"
- "your-api-key-2"
# AIStduio Gemini API 的 API 密钥
generative-language-api-key:
- "AIzaSy...01"
- "AIzaSy...02"
- "AIzaSy...03"
- "AIzaSy...04"
```
### 身份验证目录
`auth-dir` 参数指定身份验证令牌的存储位置。当您运行登录命令时,应用程序将在此目录中创建包含 Google 账户身份验证令牌的 JSON 文件。多个账户可用于轮询。
### API 密钥
`api-keys` 参数允许您定义可用于验证对代理服务器请求的 API 密钥列表。在向 API 发出请求时,您可以在 `Authorization` 标头中包含其中一个密钥:
```
Authorization: Bearer your-api-key-1
```
### 官方生成式语言 API
`generative-language-api-key` 参数允许您定义可用于验证对官方 AIStudio Gemini API 请求的 API 密钥列表。
## Gemini CLI 多账户负载均衡
启动 CLI 代理 API 服务器,然后将 `CODE_ASSIST_ENDPOINT` 环境变量设置为 CLI 代理 API 服务器的 URL。
```bash
export CODE_ASSIST_ENDPOINT="http://127.0.0.1:8317"
```
服务器将中继 `loadCodeAssist`、`onboardUser` 和 `countTokens` 请求。并自动在多个账户之间轮询文本生成请求。
> [!NOTE]
> 此功能仅允许本地访问,因为找不到一个可以验证请求的方法。
> 所以只能强制只有 `127.0.0.1` 可以访问。
## 使用 Docker 运行
运行以下命令进行登录:
```bash
docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login
```
运行以下命令启动服务器:
```bash
docker run --rm -p 8317:8317 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest
```
## 贡献
欢迎贡献!请随时提交 Pull Request。
1. Fork 仓库
2. 创建您的功能分支(`git checkout -b feature/amazing-feature`
3. 提交您的更改(`git commit -m 'Add some amazing feature'`
4. 推送到分支(`git push origin feature/amazing-feature`
5. 打开 Pull Request
## 许可证
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。

View File

@@ -1,3 +1,6 @@
// Package main provides the entry point for the CLI Proxy API server.
// This server acts as a proxy that provides OpenAI/Gemini/Claude compatible API interfaces
// for CLI models, allowing CLI models to be used with tools and libraries designed for standard AI APIs.
package main
import (
@@ -63,14 +66,17 @@ func main() {
var wd string
// Load configuration from the specified path or the default path.
var configFilePath string
if configPath != "" {
configFilePath = configPath
cfg, err = config.LoadConfig(configPath)
} else {
wd, err = os.Getwd()
if err != nil {
log.Fatalf("failed to get working directory: %v", err)
}
cfg, err = config.LoadConfig(path.Join(wd, "config.yaml"))
configFilePath = path.Join(wd, "config.yaml")
cfg, err = config.LoadConfig(configFilePath)
}
if err != nil {
log.Fatalf("failed to load config: %v", err)
@@ -102,6 +108,6 @@ func main() {
if login {
cmd.DoLogin(cfg, projectID)
} else {
cmd.StartService(cfg)
cmd.StartService(cfg, configFilePath)
}
}

3
go.mod
View File

@@ -8,6 +8,7 @@ require (
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
golang.org/x/net v0.37.1-0.20250305215238-2914f4677317
golang.org/x/oauth2 v0.30.0
gopkg.in/yaml.v3 v3.0.1
)
@@ -18,6 +19,7 @@ require (
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
@@ -37,7 +39,6 @@ require (
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.37.1-0.20250305215238-2914f4677317 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect

2
go.sum
View File

@@ -11,6 +11,8 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=

View File

@@ -1,422 +0,0 @@
package api
import (
"context"
"fmt"
"github.com/luispater/CLIProxyAPI/internal/api/translator"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"net/http"
"sync"
"time"
"github.com/gin-gonic/gin"
)
var (
mutex = &sync.Mutex{}
lastUsedClientIndex = 0
)
// APIHandlers contains the handlers for API endpoints.
// It holds a pool of clients to interact with the backend service.
type APIHandlers struct {
cliClients []*client.Client
cfg *config.Config
}
// NewAPIHandlers creates a new API handlers instance.
// It takes a slice of clients and a debug flag as input.
func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers {
return &APIHandlers{
cliClients: cliClients,
cfg: cfg,
}
}
// Models handles the /v1/models endpoint.
// It returns a hardcoded list of available AI models.
func (h *APIHandlers) Models(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"data": []map[string]any{
{
"id": "gemini-2.5-pro-preview-05-06",
"object": "model",
"version": "2.5-preview-05-06",
"name": "Gemini 2.5 Pro Preview 05-06",
"description": "Preview release (May 6th, 2025) of Gemini 2.5 Pro",
"context_length": 1048576,
"max_completion_tokens": 65536,
"supported_parameters": []string{
"tools",
"temperature",
"top_p",
"top_k",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"id": "gemini-2.5-pro-preview-06-05",
"object": "model",
"version": "2.5-preview-06-05",
"name": "Gemini 2.5 Pro Preview 06-05",
"description": "Preview release (June 5th, 2025) of Gemini 2.5 Pro",
"context_length": 1048576,
"max_completion_tokens": 65536,
"supported_parameters": []string{
"tools",
"temperature",
"top_p",
"top_k",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"id": "gemini-2.5-pro",
"object": "model",
"version": "2.5",
"name": "Gemini 2.5 Pro",
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
"context_length": 1048576,
"max_completion_tokens": 65536,
"supported_parameters": []string{
"tools",
"temperature",
"top_p",
"top_k",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"id": "gemini-2.5-flash-preview-04-17",
"object": "model",
"version": "2.5-preview-04-17",
"name": "Gemini 2.5 Flash Preview 04-17",
"description": "Preview release (April 17th, 2025) of Gemini 2.5 Flash",
"context_length": 1048576,
"max_completion_tokens": 65536,
"supported_parameters": []string{
"tools",
"temperature",
"top_p",
"top_k",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"id": "gemini-2.5-flash-preview-05-20",
"object": "model",
"version": "2.5-preview-05-20",
"name": "Gemini 2.5 Flash Preview 05-20",
"description": "Preview release (April 17th, 2025) of Gemini 2.5 Flash",
"context_length": 1048576,
"max_completion_tokens": 65536,
"supported_parameters": []string{
"tools",
"temperature",
"top_p",
"top_k",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"id": "gemini-2.5-flash",
"object": "model",
"version": "001",
"name": "Gemini 2.5 Flash",
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
"context_length": 1048576,
"max_completion_tokens": 65536,
"supported_parameters": []string{
"tools",
"temperature",
"top_p",
"top_k",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
},
})
}
// ChatCompletions handles the /v1/chat/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler.
func (h *APIHandlers) ChatCompletions(c *gin.Context) {
rawJson, err := c.GetRawData()
// If data retrieval fails, return a 400 Bad Request error.
if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
// Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJson, "stream")
if streamResult.Type == gjson.True {
h.handleStreamingResponse(c, rawJson)
} else {
h.handleNonStreamingResponse(c, rawJson)
}
}
// handleNonStreamingResponse handles non-streaming chat completion responses.
// It selects a client from the pool, sends the request, and aggregates the response
// before sending it back to the client.
func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "application/json")
// Handle streaming manually
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelName, contents, tools := translator.PrepareRequest(rawJson)
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
for {
// Lock the mutex to update the last used client index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
mutex.Unlock()
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
c.Status(429)
_, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
flusher.Flush()
cliCancel()
return
}
locked := false
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
isGlAPIKey := false
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
isGlAPIKey = true
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, contents, tools)
if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
break
} else {
openAIFormat := translator.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey)
if openAIFormat != "" {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
flusher.Flush()
}
cliCancel()
break
}
}
}
// handleStreamingResponse handles streaming responses
func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
// Prepare the request for the backend client.
modelName, contents, tools := translator.PrepareRequest(rawJson)
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
// Ensure the client's mutex is unlocked on function exit.
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
outLoop:
for {
// Lock the mutex to update the last used client index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
mutex.Unlock()
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
c.Status(429)
_, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
flusher.Flush()
cliCancel()
return
}
locked := false
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
isGlAPIKey := false
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
isGlAPIKey = true
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
hasFirstResponse := false
for {
select {
// Handle client disconnection.
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() // Cancel the backend request.
return
}
// Process incoming response chunks.
case chunk, okStream := <-respChan:
if !okStream {
// Stream is closed, send the final [DONE] message.
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel()
return
} else {
// Convert the chunk to OpenAI format and send it to the client.
hasFirstResponse = true
openAIFormat := translator.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey)
if openAIFormat != "" {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
flusher.Flush()
}
}
// Handle errors from the backend.
case err, okError := <-errChan:
if okError {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue outLoop
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Send a keep-alive signal to the client.
case <-time.After(500 * time.Millisecond):
if hasFirstResponse {
_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
flusher.Flush()
}
}
}
}
}

View File

@@ -0,0 +1,208 @@
// Package claude provides HTTP handlers for Claude API code-related functionality.
// This package implements Claude-compatible streaming chat completions with sophisticated
// client rotation and quota management systems to ensure high availability and optimal
// resource utilization across multiple backend clients. It handles request translation
// between Claude API format and the underlying Gemini backend, providing seamless
// API compatibility while maintaining robust error handling and connection management.
package claude
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
"github.com/luispater/CLIProxyAPI/internal/api/translator/claude/code"
"github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus"
"net/http"
"strings"
"time"
)
// ClaudeCodeAPIHandlers contains the handlers for Claude API endpoints.
// It holds a pool of clients to interact with the backend service.
type ClaudeCodeAPIHandlers struct {
*handlers.APIHandlers
}
// NewClaudeCodeAPIHandlers creates a new Claude API handlers instance.
// It takes an APIHandlers instance as input and returns a ClaudeCodeAPIHandlers.
func NewClaudeCodeAPIHandlers(apiHandlers *handlers.APIHandlers) *ClaudeCodeAPIHandlers {
return &ClaudeCodeAPIHandlers{
APIHandlers: apiHandlers,
}
}
// ClaudeMessages handles Claude-compatible streaming chat completions.
// This function implements a sophisticated client rotation and quota management system
// to ensure high availability and optimal resource utilization across multiple backend clients.
func (h *ClaudeCodeAPIHandlers) ClaudeMessages(c *gin.Context) {
// Extract raw JSON data from the incoming request
rawJSON, err := c.GetRawData()
// If data retrieval fails, return a 400 Bad Request error.
if err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
// Set up Server-Sent Events (SSE) headers for streaming response
// These headers are essential for maintaining a persistent connection
// and enabling real-time streaming of chat completions
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response.
// This is crucial for streaming as it allows immediate sending of data chunks
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
// Parse and prepare the Claude request, extracting model name, system instructions,
// conversation contents, and available tools from the raw JSON
modelName, systemInstruction, contents, tools := code.PrepareClaudeRequest(rawJSON)
// Map Claude model names to corresponding Gemini models
// This allows the proxy to handle Claude API calls using Gemini backend
if modelName == "claude-sonnet-4-20250514" {
modelName = "gemini-2.5-pro"
} else if modelName == "claude-3-5-haiku-20241022" {
modelName = "gemini-2.5-flash"
}
// Create a cancellable context for the backend client request
// This allows proper cleanup and cancellation of ongoing requests
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
// Ensure the client's mutex is unlocked on function exit.
// This prevents deadlocks and ensures proper resource cleanup
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
// Main client rotation loop with quota management
// This loop implements a sophisticated load balancing and failover mechanism
outLoop:
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.GetClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
flusher.Flush()
cliCancel()
return
}
// Determine the authentication method being used by the selected client
// This affects how responses are formatted and logged
isGlAPIKey := false
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
isGlAPIKey = true
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Initiate streaming communication with the backend client
// This returns two channels: one for response chunks and one for errors
includeThoughts := false
if userAgent, hasKey := c.Request.Header["User-Agent"]; hasKey {
includeThoughts = !strings.Contains(userAgent[0], "claude-cli")
}
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools, includeThoughts)
// Track response state for proper Claude format conversion
hasFirstResponse := false
responseType := 0
responseIndex := 0
// Main streaming loop - handles multiple concurrent events using Go channels
// This select statement manages four different types of events simultaneously
for {
select {
// Case 1: Handle client disconnection
// Detects when the HTTP client has disconnected and cleans up resources
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() // Cancel the backend request to prevent resource leaks
return
}
// Case 2: Process incoming response chunks from the backend
// This handles the actual streaming data from the AI model
case chunk, okStream := <-respChan:
if !okStream {
// Stream has ended - send the final message_stop event
// This follows the Claude API specification for stream termination
_, _ = c.Writer.Write([]byte(`event: message_stop`))
_, _ = c.Writer.Write([]byte("\n"))
_, _ = c.Writer.Write([]byte(`data: {"type":"message_stop"}`))
_, _ = c.Writer.Write([]byte("\n\n\n"))
flusher.Flush()
cliCancel()
return
}
// Convert the backend response to Claude-compatible format
// This translation layer ensures API compatibility
claudeFormat := code.ConvertCliToClaude(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex)
if claudeFormat != "" {
_, _ = c.Writer.Write([]byte(claudeFormat))
flusher.Flush() // Immediately send the chunk to the client
}
hasFirstResponse = true
// Case 3: Handle errors from the backend
// This manages various error conditions and implements retry logic
case errInfo, okError := <-errChan:
if okError {
// Special handling for quota exceeded errors
// If configured, attempt to switch to a different project/client
if errInfo.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
continue outLoop // Restart the client selection process
} else {
// Forward other errors directly to the client
c.Status(errInfo.StatusCode)
_, _ = fmt.Fprint(c.Writer, errInfo.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Case 4: Send periodic keep-alive signals
// Prevents connection timeouts during long-running requests
case <-time.After(500 * time.Millisecond):
if hasFirstResponse {
// Send a ping event to maintain the connection
// This is especially important for slow AI model responses
output := "event: ping\n"
output = output + `data: {"type": "ping"}`
output = output + "\n\n\n"
_, _ = c.Writer.Write([]byte(output))
flusher.Flush()
}
}
}
}
}

View File

@@ -0,0 +1,268 @@
// Package cli provides HTTP handlers for Gemini CLI API functionality.
// This package implements handlers that process CLI-specific requests for Gemini API operations,
// including content generation and streaming content generation endpoints.
// The handlers restrict access to localhost only and manage communication with the backend service.
package cli
import (
"bytes"
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"io"
"net/http"
"strings"
"time"
)
// GeminiCLIAPIHandlers contains the handlers for Gemini CLI API endpoints.
// It holds a pool of clients to interact with the backend service.
type GeminiCLIAPIHandlers struct {
*handlers.APIHandlers
}
// NewGeminiCLIAPIHandlers creates a new Gemini CLI API handlers instance.
// It takes an APIHandlers instance as input and returns a GeminiCLIAPIHandlers.
func NewGeminiCLIAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiCLIAPIHandlers {
return &GeminiCLIAPIHandlers{
APIHandlers: apiHandlers,
}
}
// CLIHandler handles CLI-specific requests for Gemini API operations.
// It restricts access to localhost only and routes requests to appropriate internal handlers.
func (h *GeminiCLIAPIHandlers) CLIHandler(c *gin.Context) {
if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") {
c.JSON(http.StatusForbidden, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "CLI reply only allow local access",
Type: "forbidden",
},
})
return
}
rawJSON, _ := c.GetRawData()
requestRawURI := c.Request.URL.Path
if requestRawURI == "/v1internal:generateContent" {
h.internalGenerateContent(c, rawJSON)
} else if requestRawURI == "/v1internal:streamGenerateContent" {
h.internalStreamGenerateContent(c, rawJSON)
} else {
reqBody := bytes.NewBuffer(rawJSON)
req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody)
if err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
for key, value := range c.Request.Header {
req.Header[key] = value
}
httpClient, err := util.SetProxy(h.Cfg, &http.Client{})
if err != nil {
log.Fatalf("set proxy failed: %v", err)
}
resp, err := httpClient.Do(req)
if err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
defer func() {
if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}()
bodyBytes, _ := io.ReadAll(resp.Body)
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: string(bodyBytes),
Type: "invalid_request_error",
},
})
return
}
defer func() {
_ = resp.Body.Close()
}()
for key, value := range resp.Header {
c.Header(key, value[0])
}
output, err := io.ReadAll(resp.Body)
if err != nil {
log.Errorf("Failed to read response body: %v", err)
return
}
_, _ = c.Writer.Write(output)
}
}
func (h *GeminiCLIAPIHandlers) internalStreamGenerateContent(c *gin.Context, rawJSON []byte) {
alt := h.GetAlt(c)
if alt == "" {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelResult := gjson.GetBytes(rawJSON, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
// Ensure the client's mutex is unlocked on function exit.
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
outLoop:
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.GetClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
flusher.Flush()
cliCancel()
return
}
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, "")
hasFirstResponse := false
for {
select {
// Handle client disconnection.
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() // Cancel the backend request.
return
}
// Process incoming response chunks.
case chunk, okStream := <-respChan:
if !okStream {
cliCancel()
return
}
hasFirstResponse = true
if cliClient.GetGenerativeLanguageAPIKey() != "" {
chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
}
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
flusher.Flush()
// Handle errors from the backend.
case err, okError := <-errChan:
if okError {
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
continue outLoop
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Send a keep-alive signal to the client.
case <-time.After(500 * time.Millisecond):
if hasFirstResponse {
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
}
}
}
}
}
func (h *GeminiCLIAPIHandlers) internalGenerateContent(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json")
modelResult := gjson.GetBytes(rawJSON, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.GetClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
cliCancel()
return
}
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, "")
if err != nil {
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode)
_, _ = c.Writer.Write([]byte(err.Error.Error()))
cliCancel()
}
break
} else {
_, _ = c.Writer.Write(resp)
cliCancel()
break
}
}
}

View File

@@ -0,0 +1,437 @@
// Package gemini provides HTTP handlers for Gemini API endpoints.
// This package implements handlers for managing Gemini model operations including
// model listing, content generation, streaming content generation, and token counting.
// It serves as a proxy layer between clients and the Gemini backend service,
// handling request translation, client management, and response processing.
package gemini
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
"github.com/luispater/CLIProxyAPI/internal/api/translator/gemini/cli"
"github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"net/http"
"strings"
"time"
)
// GeminiAPIHandlers contains the handlers for Gemini API endpoints.
// It holds a pool of clients to interact with the backend service.
type GeminiAPIHandlers struct {
*handlers.APIHandlers
}
// NewGeminiAPIHandlers creates a new Gemini API handlers instance.
// It takes an APIHandlers instance as input and returns a GeminiAPIHandlers.
func NewGeminiAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiAPIHandlers {
return &GeminiAPIHandlers{
APIHandlers: apiHandlers,
}
}
// GeminiModels handles the Gemini models listing endpoint.
// It returns a JSON response containing available Gemini models and their specifications.
func (h *GeminiAPIHandlers) GeminiModels(c *gin.Context) {
c.Status(http.StatusOK)
c.Header("Content-Type", "application/json; charset=UTF-8")
_, _ = c.Writer.Write([]byte(`{"models":[{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini `))
_, _ = c.Writer.Write([]byte(`2.5 Flash","description":"Stable version of Gemini 2.5 Flash, our mid-size multimod`))
_, _ = c.Writer.Write([]byte(`al model that supports up to 1 million tokens, released in June of 2025.","inputTok`))
_, _ = c.Writer.Write([]byte(`enLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateCo`))
_, _ = c.Writer.Write([]byte(`ntent","countTokens","createCachedContent","batchGenerateContent"],"temperature":1,`))
_, _ = c.Writer.Write([]byte(`"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true},{"name":"models/gemini-2.`))
_, _ = c.Writer.Write([]byte(`5-pro","version":"2.5","displayName":"Gemini 2.5 Pro","description":"Stable release`))
_, _ = c.Writer.Write([]byte(` (June 17th, 2025) of Gemini 2.5 Pro","inputTokenLimit":1048576,"outputTokenLimit":`))
_, _ = c.Writer.Write([]byte(`65536,"supportedGenerationMethods":["generateContent","countTokens","createCachedCo`))
_, _ = c.Writer.Write([]byte(`ntent","batchGenerateContent"],"temperature":1,"topP":0.95,"topK":64,"maxTemperatur`))
_, _ = c.Writer.Write([]byte(`e":2,"thinking":true}],"nextPageToken":""}`))
}
// GeminiGetHandler handles GET requests for specific Gemini model information.
// It returns detailed information about a specific Gemini model based on the action parameter.
func (h *GeminiAPIHandlers) GeminiGetHandler(c *gin.Context) {
var request struct {
Action string `uri:"action" binding:"required"`
}
if err := c.ShouldBindUri(&request); err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
if request.Action == "gemini-2.5-pro" {
c.Status(http.StatusOK)
c.Header("Content-Type", "application/json; charset=UTF-8")
_, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-pro","version":"2.5","displayName":"Gemini 2.5 Pro",`))
_, _ = c.Writer.Write([]byte(`"description":"Stable release (June 17th, 2025) of Gemini 2.5 Pro","inputTokenL`))
_, _ = c.Writer.Write([]byte(`imit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateC`))
_, _ = c.Writer.Write([]byte(`ontent","countTokens","createCachedContent","batchGenerateContent"],"temperatur`))
_, _ = c.Writer.Write([]byte(`e":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`))
} else if request.Action == "gemini-2.5-flash" {
c.Status(http.StatusOK)
c.Header("Content-Type", "application/json; charset=UTF-8")
_, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini 2.5 Fla`))
_, _ = c.Writer.Write([]byte(`sh","description":"Stable version of Gemini 2.5 Flash, our mid-size multimodal `))
_, _ = c.Writer.Write([]byte(`model that supports up to 1 million tokens, released in June of 2025.","inputTo`))
_, _ = c.Writer.Write([]byte(`kenLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["gener`))
_, _ = c.Writer.Write([]byte(`ateContent","countTokens","createCachedContent","batchGenerateContent"],"temper`))
_, _ = c.Writer.Write([]byte(`ature":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`))
} else {
c.Status(http.StatusNotFound)
_, _ = c.Writer.Write([]byte(
`{"error":{"message":"Not Found","code":404,"status":"NOT_FOUND"}}`,
))
}
}
// GeminiHandler handles POST requests for Gemini API operations.
// It routes requests to appropriate handlers based on the action parameter (model:method format).
func (h *GeminiAPIHandlers) GeminiHandler(c *gin.Context) {
var request struct {
Action string `uri:"action" binding:"required"`
}
if err := c.ShouldBindUri(&request); err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
action := strings.Split(request.Action, ":")
if len(action) != 2 {
c.JSON(http.StatusNotFound, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("%s not found.", c.Request.URL.Path),
Type: "invalid_request_error",
},
})
return
}
modelName := action[0]
method := action[1]
rawJSON, _ := c.GetRawData()
rawJSON, _ = sjson.SetBytes(rawJSON, "model", []byte(modelName))
if method == "generateContent" {
h.geminiGenerateContent(c, rawJSON)
} else if method == "streamGenerateContent" {
h.geminiStreamGenerateContent(c, rawJSON)
} else if method == "countTokens" {
h.geminiCountTokens(c, rawJSON)
}
}
func (h *GeminiAPIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJSON []byte) {
alt := h.GetAlt(c)
if alt == "" {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelResult := gjson.GetBytes(rawJSON, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
// Ensure the client's mutex is unlocked on function exit.
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
outLoop:
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.GetClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
flusher.Flush()
cliCancel()
return
}
template := ""
parsed := gjson.Parse(string(rawJSON))
contents := parsed.Get("request.contents")
if contents.Exists() {
template = string(rawJSON)
} else {
template = `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
template, _ = sjson.Delete(template, "request.model")
}
template, errFixCLIToolResponse := cli.FixCLIToolResponse(template)
if errFixCLIToolResponse != nil {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: errFixCLIToolResponse.Error(),
Type: "server_error",
},
})
cliCancel()
return
}
systemInstructionResult := gjson.Get(template, "request.system_instruction")
if systemInstructionResult.Exists() {
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
template, _ = sjson.Delete(template, "request.system_instruction")
}
rawJSON = []byte(template)
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, alt)
for {
select {
// Handle client disconnection.
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() // Cancel the backend request.
return
}
// Process incoming response chunks.
case chunk, okStream := <-respChan:
if !okStream {
cliCancel()
return
}
if cliClient.GetGenerativeLanguageAPIKey() == "" {
if alt == "" {
responseResult := gjson.GetBytes(chunk, "response")
if responseResult.Exists() {
chunk = []byte(responseResult.Raw)
}
} else {
chunkTemplate := "[]"
responseResult := gjson.ParseBytes(chunk)
if responseResult.IsArray() {
responseResultItems := responseResult.Array()
for i := 0; i < len(responseResultItems); i++ {
responseResultItem := responseResultItems[i]
if responseResultItem.Get("response").Exists() {
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
}
}
}
chunk = []byte(chunkTemplate)
}
}
if alt == "" {
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
} else {
_, _ = c.Writer.Write(chunk)
}
flusher.Flush()
// Handle errors from the backend.
case err, okError := <-errChan:
if okError {
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
log.Debugf("quota exceeded, switch client")
continue outLoop
} else {
log.Debugf("error code :%d, error: %v", err.StatusCode, err.Error.Error())
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Send a keep-alive signal to the client.
case <-time.After(500 * time.Millisecond):
}
}
}
}
func (h *GeminiAPIHandlers) geminiCountTokens(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json")
alt := h.GetAlt(c)
// orgrawJSON := rawJSON
modelResult := gjson.GetBytes(rawJSON, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.GetClient(modelName, false)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
cliCancel()
return
}
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
template := `{"request":{}}`
if gjson.GetBytes(rawJSON, "generateContentRequest").Exists() {
template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJSON, "generateContentRequest").Raw)
template, _ = sjson.Delete(template, "generateContentRequest")
} else if gjson.GetBytes(rawJSON, "contents").Exists() {
template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJSON, "contents").Raw)
template, _ = sjson.Delete(template, "contents")
}
rawJSON = []byte(template)
}
resp, err := cliClient.SendRawTokenCount(cliCtx, rawJSON, alt)
if err != nil {
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode)
_, _ = c.Writer.Write([]byte(err.Error.Error()))
cliCancel()
// log.Debugf(err.Error.Error())
// log.Debugf(string(rawJSON))
// log.Debugf(string(orgrawJSON))
}
break
} else {
if cliClient.GetGenerativeLanguageAPIKey() == "" {
responseResult := gjson.GetBytes(resp, "response")
if responseResult.Exists() {
resp = []byte(responseResult.Raw)
}
}
_, _ = c.Writer.Write(resp)
cliCancel()
break
}
}
}
func (h *GeminiAPIHandlers) geminiGenerateContent(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json")
alt := h.GetAlt(c)
modelResult := gjson.GetBytes(rawJSON, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.GetClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
cliCancel()
return
}
template := ""
parsed := gjson.Parse(string(rawJSON))
contents := parsed.Get("request.contents")
if contents.Exists() {
template = string(rawJSON)
} else {
template = `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
template, _ = sjson.Delete(template, "request.model")
}
template, errFixCLIToolResponse := cli.FixCLIToolResponse(template)
if errFixCLIToolResponse != nil {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: errFixCLIToolResponse.Error(),
Type: "server_error",
},
})
cliCancel()
return
}
systemInstructionResult := gjson.Get(template, "request.system_instruction")
if systemInstructionResult.Exists() {
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
template, _ = sjson.Delete(template, "request.system_instruction")
}
rawJSON = []byte(template)
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, alt)
if err != nil {
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode)
_, _ = c.Writer.Write([]byte(err.Error.Error()))
cliCancel()
}
break
} else {
if cliClient.GetGenerativeLanguageAPIKey() == "" {
responseResult := gjson.GetBytes(resp, "response")
if responseResult.Exists() {
resp = []byte(responseResult.Raw)
}
}
_, _ = c.Writer.Write(resp)
cliCancel()
break
}
}
}

View File

@@ -0,0 +1,122 @@
// Package handlers provides core API handler functionality for the CLI Proxy API server.
// It includes common types, client management, load balancing, and error handling
// shared across all API endpoint handlers (OpenAI, Claude, Gemini).
package handlers
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"sync"
)
// ErrorResponse represents a standard error response format for the API.
// It contains a single ErrorDetail field.
type ErrorResponse struct {
Error ErrorDetail `json:"error"`
}
// ErrorDetail provides specific information about an error that occurred.
// It includes a human-readable message, an error type, and an optional error code.
type ErrorDetail struct {
// A human-readable message providing more details about the error.
Message string `json:"message"`
// The type of error that occurred (e.g., "invalid_request_error").
Type string `json:"type"`
// A short code identifying the error, if applicable.
Code string `json:"code,omitempty"`
}
// APIHandlers contains the handlers for API endpoints.
// It holds a pool of clients to interact with the backend service.
type APIHandlers struct {
CliClients []*client.Client
Cfg *config.Config
Mutex *sync.Mutex
LastUsedClientIndex int
}
// NewAPIHandlers creates a new API handlers instance.
// It takes a slice of clients and a debug flag as input.
func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers {
return &APIHandlers{
CliClients: cliClients,
Cfg: cfg,
Mutex: &sync.Mutex{},
LastUsedClientIndex: 0,
}
}
// UpdateClients updates the handlers' client list and configuration
func (h *APIHandlers) UpdateClients(clients []*client.Client, cfg *config.Config) {
h.CliClients = clients
h.Cfg = cfg
}
// GetClient returns an available client from the pool using round-robin load balancing.
// It checks for quota limits and tries to find an unlocked client for immediate use.
// The modelName parameter is used to check quota status for specific models.
func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) {
if len(h.CliClients) == 0 {
return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")}
}
var cliClient *client.Client
// Lock the mutex to update the last used client index
h.Mutex.Lock()
startIndex := h.LastUsedClientIndex
if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 {
currentIndex := (startIndex + 1) % len(h.CliClients)
h.LastUsedClientIndex = currentIndex
}
h.Mutex.Unlock()
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.CliClients); i++ {
cliClient = h.CliClients[(startIndex+1+i)%len(h.CliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
return nil, &client.ErrorMessage{StatusCode: 429, Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)}
}
locked := false
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.CliClients[0]
cliClient.RequestMutex.Lock()
}
return cliClient, nil
}
// GetAlt extracts the 'alt' parameter from the request query string.
// It checks both 'alt' and '$alt' parameters and returns the appropriate value.
func (h *APIHandlers) GetAlt(c *gin.Context) string {
var alt string
var hasAlt bool
alt, hasAlt = c.GetQuery("alt")
if !hasAlt {
alt, _ = c.GetQuery("$alt")
}
if alt == "sse" {
return ""
}
return alt
}

View File

@@ -0,0 +1,264 @@
// Package openai provides HTTP handlers for OpenAI API endpoints.
// This package implements the OpenAI-compatible API interface, including model listing
// and chat completion functionality. It supports both streaming and non-streaming responses,
// and manages a pool of clients to interact with backend services.
// The handlers translate OpenAI API requests to the appropriate backend format and
// convert responses back to OpenAI-compatible format.
package openai
import (
"context"
"fmt"
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
"github.com/luispater/CLIProxyAPI/internal/api/translator/openai"
"github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"net/http"
"time"
"github.com/gin-gonic/gin"
)
// OpenAIAPIHandlers contains the handlers for OpenAI API endpoints.
// It holds a pool of clients to interact with the backend service.
type OpenAIAPIHandlers struct {
*handlers.APIHandlers
}
// NewOpenAIAPIHandlers creates a new OpenAI API handlers instance.
// It takes an APIHandlers instance as input and returns an OpenAIAPIHandlers.
func NewOpenAIAPIHandlers(apiHandlers *handlers.APIHandlers) *OpenAIAPIHandlers {
return &OpenAIAPIHandlers{
APIHandlers: apiHandlers,
}
}
// Models handles the /v1/models endpoint.
// It returns a hardcoded list of available AI models.
func (h *OpenAIAPIHandlers) Models(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"data": []map[string]any{
{
"id": "gemini-2.5-pro",
"object": "model",
"version": "2.5",
"name": "Gemini 2.5 Pro",
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
"context_length": 1048576,
"max_completion_tokens": 65536,
"supported_parameters": []string{
"tools",
"temperature",
"top_p",
"top_k",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"id": "gemini-2.5-flash",
"object": "model",
"version": "001",
"name": "Gemini 2.5 Flash",
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
"context_length": 1048576,
"max_completion_tokens": 65536,
"supported_parameters": []string{
"tools",
"temperature",
"top_p",
"top_k",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
},
})
}
// ChatCompletions handles the /v1/chat/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler.
func (h *OpenAIAPIHandlers) ChatCompletions(c *gin.Context) {
rawJSON, err := c.GetRawData()
// If data retrieval fails, return a 400 Bad Request error.
if err != nil {
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
// Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJSON, "stream")
if streamResult.Type == gjson.True {
h.handleStreamingResponse(c, rawJSON)
} else {
h.handleNonStreamingResponse(c, rawJSON)
}
}
// handleNonStreamingResponse handles non-streaming chat completion responses.
// It selects a client from the pool, sends the request, and aggregates the response
// before sending it back to the client.
func (h *OpenAIAPIHandlers) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "application/json")
modelName, systemInstruction, contents, tools := openai.PrepareRequest(rawJSON)
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.GetClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
cliCancel()
return
}
isGlAPIKey := false
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
isGlAPIKey = true
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendMessage(cliCtx, rawJSON, modelName, systemInstruction, contents, tools)
if err != nil {
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode)
_, _ = c.Writer.Write([]byte(err.Error.Error()))
cliCancel()
}
break
} else {
openAIFormat := openai.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey)
if openAIFormat != "" {
_, _ = c.Writer.Write([]byte(openAIFormat))
}
cliCancel()
break
}
}
}
// handleStreamingResponse handles streaming responses
func (h *OpenAIAPIHandlers) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
Error: handlers.ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
// Prepare the request for the backend client.
modelName, systemInstruction, contents, tools := openai.PrepareRequest(rawJSON)
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
// Ensure the client's mutex is unlocked on function exit.
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
outLoop:
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.GetClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
flusher.Flush()
cliCancel()
return
}
isGlAPIKey := false
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
isGlAPIKey = true
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools)
hasFirstResponse := false
for {
select {
// Handle client disconnection.
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() // Cancel the backend request.
return
}
// Process incoming response chunks.
case chunk, okStream := <-respChan:
if !okStream {
// Stream is closed, send the final [DONE] message.
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel()
return
}
// Convert the chunk to OpenAI format and send it to the client.
hasFirstResponse = true
openAIFormat := openai.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey)
if openAIFormat != "" {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
flusher.Flush()
}
// Handle errors from the backend.
case err, okError := <-errChan:
if okError {
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
continue outLoop
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Send a keep-alive signal to the client.
case <-time.After(500 * time.Millisecond):
if hasFirstResponse {
_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
flusher.Flush()
}
}
}
}
}

View File

@@ -1,18 +0,0 @@
package api
// ErrorResponse represents a standard error response format for the API.
// It contains a single ErrorDetail field.
type ErrorResponse struct {
Error ErrorDetail `json:"error"`
}
// ErrorDetail provides specific information about an error that occurred.
// It includes a human-readable message, an error type, and an optional error code.
type ErrorDetail struct {
// A human-readable message providing more details about the error.
Message string `json:"message"`
// The type of error that occurred (e.g., "invalid_request_error").
Type string `json:"type"`
// A short code identifying the error, if applicable.
Code string `json:"code,omitempty"`
}

View File

@@ -1,3 +1,7 @@
// Package api provides the HTTP API server implementation for the CLI Proxy API.
// It includes the main server struct, routing setup, middleware for CORS and authentication,
// and integration with various AI API handlers (OpenAI, Claude, Gemini).
// The server supports hot-reloading of clients and configuration.
package api
import (
@@ -5,6 +9,11 @@ import (
"errors"
"fmt"
"github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
"github.com/luispater/CLIProxyAPI/internal/api/handlers/claude"
"github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini"
"github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini/cli"
"github.com/luispater/CLIProxyAPI/internal/api/handlers/openai"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
@@ -17,7 +26,7 @@ import (
type Server struct {
engine *gin.Engine
server *http.Server
handlers *APIHandlers
handlers *handlers.APIHandlers
cfg *config.Config
}
@@ -29,9 +38,6 @@ func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
gin.SetMode(gin.ReleaseMode)
}
// Create handlers
handlers := NewAPIHandlers(cliClients, cfg)
// Create gin engine
engine := gin.New()
@@ -43,7 +49,7 @@ func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
// Create server instance
s := &Server{
engine: engine,
handlers: handlers,
handlers: handlers.NewAPIHandlers(cliClients, cfg),
cfg: cfg,
}
@@ -62,12 +68,27 @@ func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
// setupRoutes configures the API routes for the server.
// It defines the endpoints and associates them with their respective handlers.
func (s *Server) setupRoutes() {
openaiHandlers := openai.NewOpenAIAPIHandlers(s.handlers)
geminiHandlers := gemini.NewGeminiAPIHandlers(s.handlers)
geminiCLIHandlers := cli.NewGeminiCLIAPIHandlers(s.handlers)
claudeCodeHandlers := claude.NewClaudeCodeAPIHandlers(s.handlers)
// OpenAI compatible API routes
v1 := s.engine.Group("/v1")
v1.Use(AuthMiddleware(s.cfg))
{
v1.GET("/models", s.handlers.Models)
v1.POST("/chat/completions", s.handlers.ChatCompletions)
v1.GET("/models", openaiHandlers.Models)
v1.POST("/chat/completions", openaiHandlers.ChatCompletions)
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
}
// Gemini compatible API routes
v1beta := s.engine.Group("/v1beta")
v1beta.Use(AuthMiddleware(s.cfg))
{
v1beta.GET("/models", geminiHandlers.GeminiModels)
v1beta.POST("/models/:action", geminiHandlers.GeminiHandler)
v1beta.GET("/models/:action", geminiHandlers.GeminiGetHandler)
}
// Root endpoint
@@ -81,6 +102,8 @@ func (s *Server) setupRoutes() {
},
})
})
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
}
// Start begins listening for and serving HTTP requests.
@@ -127,18 +150,31 @@ func corsMiddleware() gin.HandlerFunc {
}
}
// UpdateClients updates the server's client list and configuration
func (s *Server) UpdateClients(clients []*client.Client, cfg *config.Config) {
s.cfg = cfg
s.handlers.UpdateClients(clients, cfg)
log.Infof("server clients and configuration updated: %d clients", len(clients))
}
// AuthMiddleware returns a Gin middleware handler that authenticates requests
// using API keys. If no API keys are configured, it allows all requests.
func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
if len(cfg.ApiKeys) == 0 {
if len(cfg.APIKeys) == 0 {
c.Next()
return
}
// Get the Authorization header
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
authHeaderGoogle := c.GetHeader("X-Goog-Api-Key")
authHeaderAnthropic := c.GetHeader("X-Api-Key")
// Get the API key from the query parameter
apiKeyQuery, _ := c.GetQuery("key")
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && apiKeyQuery == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "Missing API key",
})
@@ -156,9 +192,9 @@ func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
// Find the API key in the in-memory list
var foundKey string
for i := range cfg.ApiKeys {
if cfg.ApiKeys[i] == apiKey {
foundKey = cfg.ApiKeys[i]
for i := range cfg.APIKeys {
if cfg.APIKeys[i] == apiKey || cfg.APIKeys[i] == authHeaderGoogle || cfg.APIKeys[i] == authHeaderAnthropic || cfg.APIKeys[i] == apiKeyQuery {
foundKey = cfg.APIKeys[i]
break
}
}

View File

@@ -0,0 +1,169 @@
// Package code provides request translation functionality for Claude API.
// It handles parsing and transforming Claude API requests into the internal client format,
// extracting model information, system instructions, message contents, and tool declarations.
// The package also performs JSON data cleaning and transformation to ensure compatibility
// between Claude API format and the internal client's expected format.
package code
import (
"bytes"
"encoding/json"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"strings"
)
// PrepareClaudeRequest parses and transforms a Claude API request into internal client format.
// It extracts the model name, system instruction, message contents, and tool declarations
// from the raw JSON request and returns them in the format expected by the internal client.
func PrepareClaudeRequest(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
var pathsToDelete []string
root := gjson.ParseBytes(rawJSON)
walk(root, "", "additionalProperties", &pathsToDelete)
walk(root, "", "$schema", &pathsToDelete)
var err error
for _, p := range pathsToDelete {
rawJSON, err = sjson.DeleteBytes(rawJSON, p)
if err != nil {
continue
}
}
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// log.Debug(string(rawJSON))
modelName := "gemini-2.5-pro"
modelResult := gjson.GetBytes(rawJSON, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}
contents := make([]client.Content, 0)
var systemInstruction *client.Content
systemResult := gjson.GetBytes(rawJSON, "system")
if systemResult.IsArray() {
systemResults := systemResult.Array()
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
for i := 0; i < len(systemResults); i++ {
systemPromptResult := systemResults[i]
systemTypePromptResult := systemPromptResult.Get("type")
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
systemPrompt := systemPromptResult.Get("text").String()
systemPart := client.Part{Text: systemPrompt}
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
}
}
if len(systemInstruction.Parts) == 0 {
systemInstruction = nil
}
}
messagesResult := gjson.GetBytes(rawJSON, "messages")
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
for i := 0; i < len(messageResults); i++ {
messageResult := messageResults[i]
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
}
role := roleResult.String()
if role == "assistant" {
role = "model"
}
clientContent := client.Content{Role: role, Parts: []client.Part{}}
contentsResult := messageResult.Get("content")
if contentsResult.IsArray() {
contentResults := contentsResult.Array()
for j := 0; j < len(contentResults); j++ {
contentResult := contentResults[j]
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String()
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String()
var args map[string]any
if err = json.Unmarshal([]byte(functionArgs), &args); err == nil {
clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
})
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" {
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").String()
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
}
}
}
contents = append(contents, clientContent)
} else if contentsResult.Type == gjson.String {
prompt := contentsResult.String()
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
}
}
}
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw
inputSchema, _ = sjson.Delete(inputSchema, "additionalProperties")
inputSchema, _ = sjson.Delete(inputSchema, "$schema")
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parameters", inputSchema)
var toolDeclaration any
if err = json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
return modelName, systemInstruction, contents, tools
}
func walk(value gjson.Result, path, field string, pathsToDelete *[]string) {
switch value.Type {
case gjson.JSON:
value.ForEach(func(key, val gjson.Result) bool {
var childPath string
if path == "" {
childPath = key.String()
} else {
childPath = path + "." + key.String()
}
if key.String() == field {
*pathsToDelete = append(*pathsToDelete, childPath)
}
walk(val, childPath, field, pathsToDelete)
return true
})
case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null:
}
}

View File

@@ -0,0 +1,206 @@
// Package code provides response translation functionality for Claude API.
// This package handles the conversion of backend client responses into Claude-compatible
// Server-Sent Events (SSE) format, implementing a sophisticated state machine that manages
// different response types including text content, thinking processes, and function calls.
// The translation ensures proper sequencing of SSE events and maintains state across
// multiple response chunks to provide a seamless streaming experience.
package code
import (
"bytes"
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"time"
)
// ConvertCliToClaude performs sophisticated streaming response format conversion.
// This function implements a complex state machine that translates backend client responses
// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types
// and handles state transitions between content blocks, thinking processes, and function calls.
//
// Response type states: 0=none, 1=content, 2=thinking, 3=function
// The function maintains state across multiple calls to ensure proper SSE event sequencing.
func ConvertCliToClaude(rawJSON []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string {
// Normalize the response format for different API key types
// Generative Language API keys have a different response structure
if isGlAPIKey {
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
}
// Track whether tools are being used in this response chunk
usedTool := false
output := ""
// Initialize the streaming session with a message_start event
// This is only sent for the very first response chunk
if !hasFirstResponse {
output = "event: message_start\n"
// Create the initial message structure with default values
// This follows the Claude API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
// Override default values with actual response metadata if available
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
}
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
}
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
}
// Process the response parts array from the backend client
// Each part can contain text content, thinking content, or function calls
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partResults := partsResult.Array()
for i := 0; i < len(partResults); i++ {
partResult := partResults[i]
// Extract the different types of content from each part
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
// Handle text content (both regular content and thinking)
if partTextResult.Exists() {
// Process thinking content (internal reasoning)
if partResult.Get("thought").Bool() {
// Continue existing thinking block
if *responseType == 2 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
} else {
// Transition from another state to thinking
// First, close any existing content block
if *responseType != 0 {
if *responseType == 2 {
output = output + "event: content_block_delta\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
}
// Start a new thinking content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, *responseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
*responseType = 2 // Set state to thinking
}
} else {
// Process regular text content (user-visible output)
// Continue existing text block
if *responseType == 1 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
} else {
// Transition from another state to text content
// First, close any existing content block
if *responseType != 0 {
if *responseType == 2 {
output = output + "event: content_block_delta\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
}
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, *responseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
*responseType = 1 // Set state to content
}
}
} else if functionCallResult.Exists() {
// Handle function/tool calls from the AI model
// This processes tool usage requests and formats them for Claude API compatibility
usedTool = true
fcName := functionCallResult.Get("name").String()
// Handle state transitions when switching to function calls
// Close any existing function call block first
if *responseType == 3 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
*responseType = 0
}
// Special handling for thinking state transition
if *responseType == 2 {
output = output + "event: content_block_delta\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
output = output + "\n\n\n"
}
// Close any other existing content block
if *responseType != 0 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
}
// Start a new tool use content block
// This creates the structure for a function call in Claude format
output = output + "event: content_block_start\n"
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, *responseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n"
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, *responseIndex), "delta.partial_json", fcArgsResult.Raw)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
}
*responseType = 3
}
}
}
usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata")
if usageResult.Exists() && bytes.Contains(rawJSON, []byte(`"finishReason"`)) {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
output = output + "event: message_delta\n"
output = output + `data: `
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
if usedTool {
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
}
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
output = output + template + "\n\n\n"
}
}
return output
}

View File

@@ -0,0 +1,185 @@
// Package cli provides request translation functionality for Gemini CLI API.
// It handles the conversion and formatting of CLI tool responses, specifically
// transforming between different JSON formats to ensure proper conversation flow
// and API compatibility. The package focuses on intelligently grouping function
// calls with their corresponding responses, converting from linear format to
// grouped format where function calls and responses are properly associated.
package cli
import (
"encoding/json"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct {
ModelContent map[string]interface{}
FunctionCalls []gjson.Result
ResponsesNeeded int
}
// FixCLIToolResponse performs sophisticated tool response format conversion and grouping.
// This function transforms the CLI tool response format by intelligently grouping function calls
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
// It converts from a linear format (1.json) to a grouped format (2.json) where function calls
// and their responses are properly associated and structured.
func FixCLIToolResponse(input string) (string, error) {
// Parse the input JSON to extract the conversation structure
parsed := gjson.Parse(input)
// Extract the contents array which contains the conversation messages
contents := parsed.Get("request.contents")
if !contents.Exists() {
// log.Debugf(input)
return input, fmt.Errorf("contents not found in input")
}
// Initialize data structures for processing and grouping
var newContents []interface{} // Final processed contents array
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched
// Process each content object in the conversation
// This iterates through messages and groups function calls with their responses
contents.ForEach(func(key, value gjson.Result) bool {
role := value.Get("role").String()
parts := value.Get("parts")
// Check if this content has function responses
var responsePartsInThisContent []gjson.Result
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionResponse").Exists() {
responsePartsInThisContent = append(responsePartsInThisContent, part)
}
return true
})
// If this content has function responses, collect them
if len(responsePartsInThisContent) > 0 {
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
// Check if any pending groups can be satisfied
for i := len(pendingGroups) - 1; i >= 0; i-- {
group := pendingGroups[i]
if len(collectedResponses) >= group.ResponsesNeeded {
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content
var responseParts []interface{}
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue
}
responseParts = append(responseParts, responseMap)
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
}
// Remove this group as it's been satisfied
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
break
}
}
return true // Skip adding this content, responses are merged
}
// If this is a model with function calls, create a new group
if role == "model" {
var functionCallsInThisModel []gjson.Result
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
functionCallsInThisModel = append(functionCallsInThisModel, part)
}
return true
})
if len(functionCallsInThisModel) > 0 {
// Add the model content
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
// Create a new group for tracking responses
group := &FunctionCallGroup{
ModelContent: contentMap,
FunctionCalls: functionCallsInThisModel,
ResponsesNeeded: len(functionCallsInThisModel),
}
pendingGroups = append(pendingGroups, group)
} else {
// Regular model content without function calls
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
}
} else {
// Non-model content (user, etc.)
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
}
return true
})
// Handle any remaining pending groups with remaining responses
for _, group := range pendingGroups {
if len(collectedResponses) >= group.ResponsesNeeded {
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
var responseParts []interface{}
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue
}
responseParts = append(responseParts, responseMap)
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
}
}
}
// Update the original JSON with the new contents
result := input
newContentsJSON, _ := json.Marshal(newContents)
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
return result, nil
}

View File

@@ -1,3 +1,6 @@
// Package translator provides data translation and format conversion utilities
// for the CLI Proxy API. It includes MIME type mappings and other translation
// functions used across different API endpoints.
package translator
// MimeTypes is a comprehensive map of file extensions to their corresponding MIME types.

View File

@@ -1,7 +1,12 @@
package translator
// Package openai provides request translation functionality for OpenAI API.
// It handles the conversion of OpenAI-compatible request formats to the internal
// format expected by the backend client, including parsing messages, roles,
// content types (text, image, file), and tool calls.
package openai
import (
"encoding/json"
"github.com/luispater/CLIProxyAPI/internal/api/translator"
"strings"
"github.com/luispater/CLIProxyAPI/internal/client"
@@ -12,17 +17,60 @@ import (
// PrepareRequest translates a raw JSON request from an OpenAI-compatible format
// to the internal format expected by the backend client. It parses messages,
// roles, content types (text, image, file), and tool calls.
func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) {
func PrepareRequest(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
modelName := "gemini-2.5-pro"
modelResult := gjson.GetBytes(rawJson, "model")
modelResult := gjson.GetBytes(rawJSON, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}
// Process the array of messages.
// Initialize data structures for processing conversation messages
// contents: stores the processed conversation history
// systemInstruction: stores system-level instructions separate from conversation
contents := make([]client.Content, 0)
messagesResult := gjson.GetBytes(rawJson, "messages")
var systemInstruction *client.Content
messagesResult := gjson.GetBytes(rawJSON, "messages")
// Pre-process tool responses to create a lookup map
// This first pass collects all tool responses so they can be matched with their corresponding calls
toolItems := make(map[string]*client.FunctionResponse)
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
}
contentResult := messageResult.Get("content")
// Extract tool responses for later matching with function calls
if roleResult.String() == "tool" {
toolCallID := messageResult.Get("tool_call_id").String()
if toolCallID != "" {
var responseData string
// Handle both string and object-based tool response formats
if contentResult.Type == gjson.String {
responseData = contentResult.String()
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
responseData = contentResult.Get("text").String()
}
// Clean up tool call ID by removing timestamp suffix
// This normalizes IDs for consistent matching between calls and responses
toolCallIDs := strings.Split(toolCallID, "-")
strings.Join(toolCallIDs, "-")
newToolCallID := strings.Join(toolCallIDs[:len(toolCallIDs)-1], "-")
// Create function response object with normalized ID and response data
functionResponse := client.FunctionResponse{Name: newToolCallID, Response: map[string]interface{}{"result": responseData}}
toolItems[toolCallID] = &functionResponse
}
}
}
}
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
@@ -37,13 +85,11 @@ func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDecl
// System messages are converted to a user message followed by a model's acknowledgment.
case "system":
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}
} else if contentResult.IsObject() {
// Handle object-based system messages.
if contentResult.Get("type").String() == "text" {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}})
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}}
}
}
// User messages can contain simple text or a multi-part body.
@@ -80,7 +126,7 @@ func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDecl
if split := strings.Split(filename, "."); len(split) > 1 {
ext = split[len(split)-1]
}
if mimeType, ok := MimeTypes[ext]; ok {
if mimeType, ok := translator.MimeTypes[ext]; ok {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: mimeType,
Data: fileData,
@@ -92,53 +138,70 @@ func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDecl
}
contents = append(contents, client.Content{Role: "user", Parts: parts})
}
// Assistant messages can contain text or tool calls.
// Assistant messages can contain text responses or tool calls
// In the internal format, assistant messages are converted to "model" role
case "assistant":
if contentResult.Type == gjson.String {
// Simple text response from the assistant
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
// Handle tool calls made by the assistant.
// Handle complex tool calls made by the assistant
// This processes function calls and matches them with their responses
functionIDs := make([]string, 0)
toolCallsResult := messageResult.Get("tool_calls")
if toolCallsResult.IsArray() {
parts := make([]client.Part, 0)
tcsResult := toolCallsResult.Array()
// Process each tool call in the assistant's message
for j := 0; j < len(tcsResult); j++ {
tcResult := tcsResult[j]
// Extract function call details
functionID := tcResult.Get("id").String()
functionIDs = append(functionIDs, functionID)
functionName := tcResult.Get("function.name").String()
functionArgs := tcResult.Get("function.arguments").String()
// Parse function arguments from JSON string to map
var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
contents = append(contents, client.Content{
Role: "model", Parts: []client.Part{{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
}},
parts = append(parts, client.Part{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
})
}
}
// Add the model's function calls to the conversation
if len(parts) > 0 {
contents = append(contents, client.Content{
Role: "model", Parts: parts,
})
// Create a separate tool response message with the collected responses
// This matches function calls with their corresponding responses
toolParts := make([]client.Part, 0)
for _, functionID := range functionIDs {
if functionResponse, ok := toolItems[functionID]; ok {
toolParts = append(toolParts, client.Part{FunctionResponse: functionResponse})
}
}
// Add the tool responses as a separate message in the conversation
contents = append(contents, client.Content{Role: "tool", Parts: toolParts})
}
}
}
// Tool messages contain the output of a tool call.
case "tool":
toolCallID := messageResult.Get("tool_call_id").String()
if toolCallID != "" {
var responseData string
if contentResult.Type == gjson.String {
responseData = contentResult.String()
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
responseData = contentResult.Get("text").String()
}
functionResponse := client.FunctionResponse{Name: toolCallID, Response: map[string]interface{}{"result": responseData}}
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
}
}
}
}
// Translate the tool declarations from the request.
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJson, "tools")
toolsResult := gjson.GetBytes(rawJSON, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
@@ -159,5 +222,5 @@ func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDecl
tools = make([]client.ToolDeclaration, 0)
}
return modelName, contents, tools
return modelName, systemInstruction, contents, tools
}

View File

@@ -1,6 +1,14 @@
package translator
// Package openai provides response translation functionality for converting between
// different API response formats and OpenAI-compatible formats. It handles both
// streaming and non-streaming responses, transforming backend client responses
// into OpenAI Server-Sent Events (SSE) format and standard JSON response formats.
// The package supports content translation, function calls, usage metadata,
// and various response attributes while maintaining compatibility with OpenAI API
// specifications.
package openai
import (
"fmt"
"time"
"github.com/tidwall/gjson"
@@ -10,21 +18,21 @@ import (
// ConvertCliToOpenAI translates a single chunk of a streaming response from the
// backend client format to the OpenAI Server-Sent Events (SSE) format.
// It returns an empty string if the chunk contains no useful data.
func ConvertCliToOpenAI(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string {
func ConvertCliToOpenAI(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string {
if isGlAPIKey {
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
}
// Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
// Extract and set the model version.
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
// Extract and set the creation timestamp.
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
if createTimeResult := gjson.GetBytes(rawJSON, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
if err == nil {
unixTimestamp = t.Unix()
@@ -35,18 +43,18 @@ func ConvertCliToOpenAI(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) st
}
// Extract and set the response ID.
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
template, _ = sjson.Set(template, "id", responseIdResult.String())
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
template, _ = sjson.Set(template, "id", responseIDResult.String())
}
// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
@@ -62,32 +70,40 @@ func ConvertCliToOpenAI(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) st
}
// Process the main content part of the response.
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partResults := partsResult.Array()
for i := 0; i < len(partResults); i++ {
partResult := partResults[i]
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() {
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
if partTextResult.Exists() {
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
}
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallTemplate)
}
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcName)
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate)
} else {
// If no usable content is found, return an empty string.
return ""
}
return template
@@ -95,16 +111,16 @@ func ConvertCliToOpenAI(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) st
// ConvertCliToOpenAINonStream aggregates response from the backend client
// convert a single, non-streaming OpenAI-compatible JSON response.
func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string {
func ConvertCliToOpenAINonStream(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string {
if isGlAPIKey {
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
}
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
if createTimeResult := gjson.GetBytes(rawJSON, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
if err == nil {
unixTimestamp = t.Unix()
@@ -114,16 +130,16 @@ func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey
template, _ = sjson.Set(template, "created", unixTimestamp)
}
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
template, _ = sjson.Set(template, "id", responseIdResult.String())
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
template, _ = sjson.Set(template, "id", responseIDResult.String())
}
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
@@ -139,7 +155,7 @@ func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey
}
// Process the main content part of the response.
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partsResults := partsResult.Array()
for i := 0; i < len(partsResults); i++ {
@@ -163,7 +179,7 @@ func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey
}
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcName)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)

View File

@@ -1,3 +1,6 @@
// Package auth provides OAuth2 authentication functionality for Google Cloud APIs.
// It handles the complete OAuth2 flow including token storage, web-based authentication,
// proxy support, and automatic token refresh. The package supports both SOCKS5 and HTTP/HTTPS proxies.
package auth
import (
@@ -39,7 +42,7 @@ var (
// initiating a new web-based OAuth flow if necessary, and refreshing tokens.
func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) {
// Configure proxy settings for the HTTP client if a proxy URL is provided.
proxyURL, err := url.Parse(cfg.ProxyUrl)
proxyURL, err := url.Parse(cfg.ProxyURL)
if err == nil {
var transport *http.Transport
if proxyURL.Scheme == "socks5" {
@@ -168,11 +171,12 @@ func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token,
codeChan := make(chan string)
errChan := make(chan error)
// Create a new HTTP server.
server := &http.Server{Addr: "localhost:8085"}
// Create a new HTTP server with its own multiplexer.
mux := http.NewServeMux()
server := &http.Server{Addr: ":8085", Handler: mux}
config.RedirectURL = "http://localhost:8085/oauth2callback"
http.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
if err := r.URL.Query().Get("error"); err != "" {
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
errChan <- fmt.Errorf("authentication failed via callback: %s", err)

View File

@@ -1,3 +1,7 @@
// Package client provides HTTP client functionality for interacting with Google Cloud AI APIs.
// It handles OAuth2 authentication, token management, request/response processing,
// streaming communication, quota management, and automatic model fallback.
// The package supports both direct API key authentication and OAuth2 flows.
package client
import (
@@ -28,8 +32,8 @@ const (
apiVersion = "v1internal"
pluginVersion = "0.1.9"
glEndPoint = "https://generativelanguage.googleapis.com/"
glApiVersion = "v1beta"
glEndPoint = "https://generativelanguage.googleapis.com"
glAPIVersion = "v1beta"
)
var (
@@ -64,30 +68,37 @@ func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Confi
}
}
// SetProjectID updates the project ID for the client's token storage.
func (c *Client) SetProjectID(projectID string) {
c.tokenStorage.ProjectID = projectID
}
// SetIsAuto configures whether the client should operate in automatic mode.
func (c *Client) SetIsAuto(auto bool) {
c.tokenStorage.Auto = auto
}
// SetIsChecked sets the checked status for the client's token storage.
func (c *Client) SetIsChecked(checked bool) {
c.tokenStorage.Checked = checked
}
// IsChecked returns whether the client's token storage has been checked.
func (c *Client) IsChecked() bool {
return c.tokenStorage.Checked
}
// IsAuto returns whether the client is operating in automatic mode.
func (c *Client) IsAuto() bool {
return c.tokenStorage.Auto
}
// GetEmail returns the email address associated with the client's token storage.
func (c *Client) GetEmail() string {
return c.tokenStorage.Email
}
// GetProjectID returns the Google Cloud project ID from the client's token storage.
func (c *Client) GetProjectID() string {
if c.tokenStorage != nil {
return c.tokenStorage.ProjectID
@@ -95,6 +106,7 @@ func (c *Client) GetProjectID() string {
return ""
}
// GetGenerativeLanguageAPIKey returns the generative language API key if configured.
func (c *Client) GetGenerativeLanguageAPIKey() string {
return c.glAPIKey
}
@@ -212,6 +224,7 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
metadataStr := getClientMetadataString()
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", getUserAgent())
req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0")
req.Header.Set("Client-Metadata", metadataStr)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
@@ -240,7 +253,7 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
}
// APIRequest handles making requests to the CLI API endpoints.
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) {
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, alt string, stream bool) (io.ReadCloser, *ErrorMessage) {
var jsonBody []byte
var err error
if byteBody, ok := body.([]byte); ok {
@@ -256,19 +269,39 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
if c.glAPIKey == "" {
// Add alt=sse for streaming
url = fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint)
if stream {
if alt == "" && stream {
url = url + "?alt=sse"
} else {
if alt != "" {
url = url + fmt.Sprintf("?$alt=%s", alt)
}
}
} else {
modelResult := gjson.GetBytes(jsonBody, "model")
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
if stream {
url = url + "?alt=sse"
if endpoint == "countTokens" {
modelResult := gjson.GetBytes(jsonBody, "model")
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint)
} else {
modelResult := gjson.GetBytes(jsonBody, "model")
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint)
if alt == "" && stream {
url = url + "?alt=sse"
} else {
if alt != "" {
url = url + fmt.Sprintf("?$alt=%s", alt)
}
}
jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction")
if systemInstructionResult.Exists() {
jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw))
jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction")
jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id")
}
}
jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
}
// log.Debug(string(jsonBody))
// log.Debug(url)
reqBody := bytes.NewBuffer(jsonBody)
req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
@@ -285,6 +318,7 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %v", errToken)}
}
req.Header.Set("User-Agent", getUserAgent())
req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0")
req.Header.Set("Client-Metadata", metadataStr)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
} else {
@@ -303,15 +337,15 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
}
}()
bodyBytes, _ := io.ReadAll(resp.Body)
// log.Debug(string(jsonBody))
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))}
}
return resp.Body, nil
}
// SendMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
// SendMessage handles a single conversational turn, including tool calls.
func (c *Client) SendMessage(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
request := GenerateContentRequest{
Contents: contents,
GenerationConfig: GenerationConfig{
@@ -320,6 +354,9 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
},
},
}
request.SystemInstruction = systemInstruction
request.Tools = tools
requestBody := map[string]interface{}{
@@ -332,7 +369,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
// log.Debug(string(byteRequestBody))
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort")
if reasoningEffortResult.String() == "none" {
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
@@ -348,17 +385,17 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
}
temperatureResult := gjson.GetBytes(rawJson, "temperature")
temperatureResult := gjson.GetBytes(rawJSON, "temperature")
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
}
topPResult := gjson.GetBytes(rawJson, "top_p")
topPResult := gjson.GetBytes(rawJSON, "top_p")
if topPResult.Exists() && topPResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
}
topKResult := gjson.GetBytes(rawJson, "top_k")
topKResult := gjson.GetBytes(rawJSON, "top_k")
if topKResult.Exists() && topKResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
}
@@ -381,7 +418,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
}
}
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false)
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, "", false)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
@@ -401,8 +438,275 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
}
}
// SendMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
// SendMessageStream handles streaming conversational turns with comprehensive parameter management.
// This function implements a sophisticated streaming system that supports tool calls, reasoning modes,
// quota management, and automatic model fallback. It returns two channels for asynchronous communication:
// one for streaming response data and another for error handling.
func (c *Client) SendMessageStream(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) {
// Define the data prefix used in Server-Sent Events streaming format
dataTag := []byte("data: ")
// Create channels for asynchronous communication
// errChan: delivers error messages during streaming
// dataChan: delivers response data chunks
errChan := make(chan *ErrorMessage)
dataChan := make(chan []byte)
// Launch a goroutine to handle the streaming process asynchronously
// This allows the function to return immediately while processing continues in the background
go func() {
// Ensure channels are properly closed when the goroutine exits
defer close(errChan)
defer close(dataChan)
// Configure thinking/reasoning capabilities
// Default to including thoughts unless explicitly disabled
includeThoughtsFlag := true
if len(includeThoughts) > 0 {
includeThoughtsFlag = includeThoughts[0]
}
// Build the base request structure for the Gemini API
// This includes conversation contents and generation configuration
request := GenerateContentRequest{
Contents: contents,
GenerationConfig: GenerationConfig{
ThinkingConfig: GenerationConfigThinkingConfig{
IncludeThoughts: includeThoughtsFlag,
},
},
}
// Add system instructions if provided
// System instructions guide the AI's behavior and response style
request.SystemInstruction = systemInstruction
// Add available tools for function calling capabilities
// Tools allow the AI to perform actions beyond text generation
request.Tools = tools
// Construct the complete request body with project context
// The project ID is essential for proper API routing and billing
requestBody := map[string]interface{}{
"project": c.GetProjectID(), // Project ID for API routing and quota management
"request": request,
"model": model,
}
// Serialize the request body to JSON for API transmission
byteRequestBody, _ := json.Marshal(requestBody)
// Parse and configure reasoning effort levels from the original request
// This maps Claude-style reasoning effort parameters to Gemini's thinking budget system
reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort")
if reasoningEffortResult.String() == "none" {
// Disable thinking entirely for fastest responses
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
} else if reasoningEffortResult.String() == "auto" {
// Let the model decide the appropriate thinking budget automatically
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
} else if reasoningEffortResult.String() == "low" {
// Minimal thinking for simple tasks (1KB thinking budget)
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
} else if reasoningEffortResult.String() == "medium" {
// Moderate thinking for complex tasks (8KB thinking budget)
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
} else if reasoningEffortResult.String() == "high" {
// Maximum thinking for very complex tasks (24KB thinking budget)
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
} else {
// Default to automatic thinking budget if no specific level is provided
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
}
// Configure temperature parameter for response randomness control
// Temperature affects the creativity vs consistency trade-off in responses
temperatureResult := gjson.GetBytes(rawJSON, "temperature")
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
}
// Configure top-p parameter for nucleus sampling
// Controls the cumulative probability threshold for token selection
topPResult := gjson.GetBytes(rawJSON, "top_p")
if topPResult.Exists() && topPResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
}
// Configure top-k parameter for limiting token candidates
// Restricts the model to consider only the top K most likely tokens
topKResult := gjson.GetBytes(rawJSON, "top_k")
if topKResult.Exists() && topKResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
}
// Initialize model name for quota management and potential fallback
modelName := model
var stream io.ReadCloser
// Quota management and model fallback loop
// This loop handles quota exceeded scenarios and automatic model switching
for {
// Check if the current model has exceeded its quota
if c.isModelQuotaExceeded(modelName) {
// Attempt to switch to a preview model if configured and using account auth
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
// Update the request body with the new model name
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
continue // Retry with the preview model
}
}
// If no fallback is available, return a quota exceeded error
errChan <- &ErrorMessage{
StatusCode: 429,
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
}
return
}
// Attempt to establish a streaming connection with the API
var err *ErrorMessage
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, "", true)
if err != nil {
// Handle quota exceeded errors by marking the model and potentially retrying
if err.StatusCode == 429 {
now := time.Now()
c.modelQuotaExceeded[modelName] = &now // Mark model as quota exceeded
// If preview model switching is enabled, retry the loop
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
continue
}
}
// Forward other errors to the error channel
errChan <- err
return
}
// Clear any previous quota exceeded status for this model
delete(c.modelQuotaExceeded, modelName)
break // Successfully established connection, exit the retry loop
}
// Process the streaming response using a scanner
// This handles the Server-Sent Events format from the API
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Bytes()
// Filter and forward only data lines (those prefixed with "data: ")
// This extracts the actual JSON content from the SSE format
if bytes.HasPrefix(line, dataTag) {
dataChan <- line[6:] // Remove "data: " prefix and send the JSON content
}
}
// Handle any scanning errors that occurred during stream processing
if errScanner := scanner.Err(); errScanner != nil {
// Send a 500 Internal Server Error for scanning failures
errChan <- &ErrorMessage{500, errScanner}
_ = stream.Close()
return
}
// Ensure the stream is properly closed to prevent resource leaks
_ = stream.Close()
}()
// Return the channels immediately for asynchronous communication
// The caller can read from these channels while the goroutine processes the request
return dataChan, errChan
}
// SendRawTokenCount handles a token count.
func (c *Client) SendRawTokenCount(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) {
modelResult := gjson.GetBytes(rawJSON, "model")
model := modelResult.String()
modelName := model
for {
if c.isModelQuotaExceeded(modelName) {
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
continue
}
}
return nil, &ErrorMessage{
StatusCode: 429,
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
}
}
respBody, err := c.APIRequest(ctx, "countTokens", rawJSON, alt, false)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
c.modelQuotaExceeded[modelName] = &now
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
continue
}
}
return nil, err
}
delete(c.modelQuotaExceeded, modelName)
bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil {
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
}
return bodyBytes, nil
}
}
// SendRawMessage handles a single conversational turn, including tool calls.
func (c *Client) SendRawMessage(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) {
if c.glAPIKey == "" {
rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID())
}
modelResult := gjson.GetBytes(rawJSON, "model")
model := modelResult.String()
modelName := model
for {
if c.isModelQuotaExceeded(modelName) {
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
continue
}
}
return nil, &ErrorMessage{
StatusCode: 429,
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
}
}
respBody, err := c.APIRequest(ctx, "generateContent", rawJSON, alt, false)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
c.modelQuotaExceeded[modelName] = &now
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
continue
}
}
return nil, err
}
delete(c.modelQuotaExceeded, modelName)
bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil {
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
}
return bodyBytes, nil
}
}
// SendRawMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendRawMessageStream(ctx context.Context, rawJSON []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) {
dataTag := []byte("data: ")
errChan := make(chan *ErrorMessage)
dataChan := make(chan []byte)
@@ -410,58 +714,12 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
defer close(errChan)
defer close(dataChan)
request := GenerateContentRequest{
Contents: contents,
GenerationConfig: GenerationConfig{
ThinkingConfig: GenerationConfigThinkingConfig{
IncludeThoughts: true,
},
},
}
request.Tools = tools
requestBody := map[string]interface{}{
"project": c.GetProjectID(), // Assuming ProjectID is available
"request": request,
"model": model,
if c.glAPIKey == "" {
rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID())
}
byteRequestBody, _ := json.Marshal(requestBody)
// log.Debug(string(byteRequestBody))
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
if reasoningEffortResult.String() == "none" {
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
} else if reasoningEffortResult.String() == "auto" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
} else if reasoningEffortResult.String() == "low" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
} else if reasoningEffortResult.String() == "medium" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
} else if reasoningEffortResult.String() == "high" {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
} else {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
}
temperatureResult := gjson.GetBytes(rawJson, "temperature")
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
}
topPResult := gjson.GetBytes(rawJson, "top_p")
if topPResult.Exists() && topPResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
}
topKResult := gjson.GetBytes(rawJson, "top_k")
if topKResult.Exists() && topKResult.Type == gjson.Number {
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
}
// log.Debug(string(byteRequestBody))
modelResult := gjson.GetBytes(rawJSON, "model")
model := modelResult.String()
modelName := model
var stream io.ReadCloser
for {
@@ -470,7 +728,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
continue
}
}
@@ -481,7 +739,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
return
}
var err *ErrorMessage
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJSON, alt, true)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
@@ -497,28 +755,39 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
break
}
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Bytes()
// log.Printf("Received stream chunk: %s", line)
if bytes.HasPrefix(line, dataTag) {
dataChan <- line[6:]
if alt == "" {
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Bytes()
if bytes.HasPrefix(line, dataTag) {
dataChan <- line[6:]
}
}
}
if errScanner := scanner.Err(); errScanner != nil {
// log.Println(err)
errChan <- &ErrorMessage{500, errScanner}
_ = stream.Close()
return
}
if errScanner := scanner.Err(); errScanner != nil {
errChan <- &ErrorMessage{500, errScanner}
_ = stream.Close()
return
}
} else {
data, err := io.ReadAll(stream)
if err != nil {
errChan <- &ErrorMessage{500, err}
_ = stream.Close()
return
}
dataChan <- data
}
_ = stream.Close()
}()
return dataChan, errChan
}
// isModelQuotaExceeded checks if the specified model has exceeded its quota
// within the last 30 minutes.
func (c *Client) isModelQuotaExceeded(model string) bool {
if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
duration := time.Now().Sub(*lastExceededTime)
@@ -530,6 +799,8 @@ func (c *Client) isModelQuotaExceeded(model string) bool {
return false
}
// getPreviewModel returns an available preview model for the given base model,
// or an empty string if no preview models are available or all are quota exceeded.
func (c *Client) getPreviewModel(model string) string {
if models, hasKey := previewModels[model]; hasKey {
for i := 0; i < len(models); i++ {
@@ -541,6 +812,8 @@ func (c *Client) getPreviewModel(model string) string {
return ""
}
// IsModelQuotaExceeded returns true if the specified model has exceeded its quota
// and no fallback options are available.
func (c *Client) IsModelQuotaExceeded(model string) bool {
if c.isModelQuotaExceeded(model) {
if c.cfg.QuotaExceeded.SwitchPreviewModel {
@@ -565,23 +838,24 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
// A simple request to test the API endpoint.
requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID)
stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), true)
stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), "", true)
if err != nil {
// If a 403 Forbidden error occurs, it likely means the API is not enabled.
if err.StatusCode == 403 {
errJson := err.Error.Error()
errJSON := err.Error.Error()
// Check for a specific error code and extract the activation URL.
if gjson.Get(errJson, "error.code").Int() == 403 {
activationUrl := gjson.Get(errJson, "error.details.0.metadata.activationUrl").String()
if activationUrl != "" {
if gjson.Get(errJSON, "0.error.code").Int() == 403 {
activationURL := gjson.Get(errJSON, "0.error.details.0.metadata.activationUrl").String()
if activationURL != "" {
log.Warnf(
"\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s",
activationUrl,
"\n\nPlease activate your account with this url:\n\n%s\n\n And execute this command again:\n%s --login --project_id %s",
activationURL,
os.Args[0],
c.tokenStorage.ProjectID,
)
}
}
log.Warnf("\n\nPlease copy this message and create an issue.\n\n%s\n\n", errJSON)
return false, nil
}
return false, err.Error
@@ -659,10 +933,10 @@ func (c *Client) SaveTokenToFile() error {
// such as IDE type, platform, and plugin version.
func getClientMetadata() map[string]string {
return map[string]string{
"ideType": "IDE_UNSPECIFIED",
"platform": getPlatform(),
"pluginType": "GEMINI",
"pluginVersion": pluginVersion,
"ideType": "IDE_UNSPECIFIED",
"platform": "PLATFORM_UNSPECIFIED",
"pluginType": "GEMINI",
// "pluginVersion": pluginVersion,
}
}
@@ -679,7 +953,8 @@ func getClientMetadataString() string {
// getUserAgent constructs the User-Agent string for HTTP requests.
func getUserAgent() string {
return fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH)
// return fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH)
return "google-api-nodejs-client/9.15.1"
}
// getPlatform determines the operating system and architecture and formats

View File

@@ -64,9 +64,10 @@ type FunctionResponse struct {
// GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint.
type GenerateContentRequest struct {
Contents []Content `json:"contents"`
Tools []ToolDeclaration `json:"tools,omitempty"`
GenerationConfig `json:"generationConfig"`
SystemInstruction *Content `json:"systemInstruction,omitempty"`
Contents []Content `json:"contents"`
Tools []ToolDeclaration `json:"tools,omitempty"`
GenerationConfig `json:"generationConfig"`
}
// GenerationConfig defines parameters that control the model's generation behavior.

View File

@@ -1,3 +1,6 @@
// Package cmd provides command-line interface functionality for the CLI Proxy API.
// It implements the main application commands including login/authentication
// and server startup, handling the complete user onboarding and service lifecycle.
package cmd
import (
@@ -73,6 +76,7 @@ func DoLogin(cfg *config.Config, projectID string) {
// If the check fails (returns false), the CheckCloudAPIIsEnabled function
// will have already printed instructions, so we can just exit.
if !isChecked {
log.Fatal("Failed to check if Cloud AI API is enabled. If you encounter an error message, please create an issue.")
return
}
}

View File

@@ -1,3 +1,8 @@
// Package cmd provides the main service execution functionality for the CLIProxyAPI.
// It contains the core logic for starting and managing the API proxy service,
// including authentication client management, server initialization, and graceful shutdown handling.
// The package handles loading authentication tokens, creating client pools, starting the API server,
// and monitoring configuration changes through file watchers.
package cmd
import (
@@ -7,12 +12,11 @@ import (
"github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
"github.com/luispater/CLIProxyAPI/internal/util"
"github.com/luispater/CLIProxyAPI/internal/watcher"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"io/fs"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
@@ -24,7 +28,7 @@ import (
// StartService initializes and starts the main API proxy service.
// It loads all available authentication tokens, creates a pool of clients,
// starts the API server, and handles graceful shutdown signals.
func StartService(cfg *config.Config) {
func StartService(cfg *config.Config, configPath string) {
// Create a pool of API clients, one for each token file found.
cliClients := make([]*client.Client, 0)
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
@@ -69,33 +73,12 @@ func StartService(cfg *config.Config) {
}
if len(cfg.GlAPIKey) > 0 {
var transport *http.Transport
proxyURL, errParse := url.Parse(cfg.ProxyUrl)
if errParse == nil {
if proxyURL.Scheme == "socks5" {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Handle HTTP/HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
}
for i := 0; i < len(cfg.GlAPIKey); i++ {
httpClient := &http.Client{}
if transport != nil {
httpClient.Transport = transport
httpClient, errSetProxy := util.SetProxy(cfg, &http.Client{})
if errSetProxy != nil {
log.Fatalf("set proxy failed: %v", errSetProxy)
}
log.Debug("Initializing with Generative Language API key...")
cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
cliClients = append(cliClients, cliClient)
@@ -105,10 +88,46 @@ func StartService(cfg *config.Config) {
// Create and start the API server with the pool of clients.
apiServer := api.NewServer(cfg, cliClients)
log.Infof("Starting API server on port %d", cfg.Port)
if err = apiServer.Start(); err != nil {
log.Fatalf("API server failed to start: %v", err)
// Start the API server in a goroutine so it doesn't block the main thread
go func() {
if err = apiServer.Start(); err != nil {
log.Fatalf("API server failed to start: %v", err)
}
}()
// Give the server a moment to start up
time.Sleep(100 * time.Millisecond)
log.Info("API server started successfully")
// Setup file watcher for config and auth directory changes
fileWatcher, errNewWatcher := watcher.NewWatcher(configPath, cfg.AuthDir, func(newClients []*client.Client, newCfg *config.Config) {
// Update the API server with new clients and configuration
apiServer.UpdateClients(newClients, newCfg)
})
if errNewWatcher != nil {
log.Fatalf("failed to create file watcher: %v", errNewWatcher)
}
// Set initial state for the watcher
fileWatcher.SetConfig(cfg)
fileWatcher.SetClients(cliClients)
// Start the file watcher
watcherCtx, watcherCancel := context.WithCancel(context.Background())
if errStartWatcher := fileWatcher.Start(watcherCtx); errStartWatcher != nil {
log.Fatalf("failed to start file watcher: %v", errStartWatcher)
}
log.Info("file watcher started for config and auth directory changes")
defer func() {
watcherCancel()
errStopWatcher := fileWatcher.Stop()
if errStopWatcher != nil {
log.Errorf("error stopping file watcher: %v", errStopWatcher)
}
}()
// Set up a channel to listen for OS signals for graceful shutdown.
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

View File

@@ -1,3 +1,7 @@
// Package config provides configuration management for the CLI Proxy API server.
// It handles loading and parsing YAML configuration files, and provides structured
// access to application settings including server port, authentication directory,
// debug settings, proxy configuration, and API keys.
package config
import (
@@ -14,17 +18,19 @@ type Config struct {
AuthDir string `yaml:"auth-dir"`
// Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug"`
// ProxyUrl is the URL of an optional proxy server to use for outbound requests.
ProxyUrl string `yaml:"proxy-url"`
// ApiKeys is a list of keys for authenticating clients to this proxy server.
ApiKeys []string `yaml:"api-keys"`
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
ProxyURL string `yaml:"proxy-url"`
// APIKeys is a list of keys for authenticating clients to this proxy server.
APIKeys []string `yaml:"api-keys"`
// QuotaExceeded defines the behavior when a quota is exceeded.
QuotaExceeded ConfigQuotaExceeded `yaml:"quota-exceeded"`
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded"`
// GlAPIKey is the API key for the generative language API.
GlAPIKey []string `yaml:"generative-language-api-key"`
}
type ConfigQuotaExceeded struct {
// QuotaExceeded defines the behavior when API quota limits are exceeded.
// It provides configuration options for automatic failover mechanisms.
type QuotaExceeded struct {
// SwitchProject indicates whether to automatically switch to another project when a quota is exceeded.
SwitchProject bool `yaml:"switch-project"`
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.

43
internal/util/proxy.go Normal file
View File

@@ -0,0 +1,43 @@
// Package util provides utility functions for the CLI Proxy API server.
// It includes helper functions for proxy configuration, HTTP client setup,
// and other common operations used across the application.
package util
import (
"context"
"github.com/luispater/CLIProxyAPI/internal/config"
"golang.org/x/net/proxy"
"net"
"net/http"
"net/url"
)
// SetProxy configures the provided HTTP client with proxy settings from the configuration.
// It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport
// to route requests through the configured proxy server.
func SetProxy(cfg *config.Config, httpClient *http.Client) (*http.Client, error) {
var transport *http.Transport
proxyURL, errParse := url.Parse(cfg.ProxyURL)
if errParse == nil {
if proxyURL.Scheme == "socks5" {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
return nil, errSOCKS5
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
}
if transport != nil {
httpClient.Transport = transport
}
return httpClient, nil
}

286
internal/watcher/watcher.go Normal file
View File

@@ -0,0 +1,286 @@
// Package watcher provides file system monitoring functionality for the CLI Proxy API.
// It watches configuration files and authentication directories for changes,
// automatically reloading clients and configuration when files are modified.
// The package handles cross-platform file system events and supports hot-reloading.
package watcher
import (
"context"
"encoding/json"
"github.com/fsnotify/fsnotify"
"github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
"github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus"
"io/fs"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
// Watcher manages file watching for configuration and authentication files
type Watcher struct {
configPath string
authDir string
config *config.Config
clients []*client.Client
clientsMutex sync.RWMutex
reloadCallback func([]*client.Client, *config.Config)
watcher *fsnotify.Watcher
}
// NewWatcher creates a new file watcher instance
func NewWatcher(configPath, authDir string, reloadCallback func([]*client.Client, *config.Config)) (*Watcher, error) {
watcher, errNewWatcher := fsnotify.NewWatcher()
if errNewWatcher != nil {
return nil, errNewWatcher
}
return &Watcher{
configPath: configPath,
authDir: authDir,
reloadCallback: reloadCallback,
watcher: watcher,
}, nil
}
// Start begins watching the configuration file and authentication directory
func (w *Watcher) Start(ctx context.Context) error {
// Watch the config file
if errAddConfig := w.watcher.Add(w.configPath); errAddConfig != nil {
log.Errorf("failed to watch config file %s: %v", w.configPath, errAddConfig)
return errAddConfig
}
log.Debugf("watching config file: %s", w.configPath)
// Watch the auth directory
if errAddAuthDir := w.watcher.Add(w.authDir); errAddAuthDir != nil {
log.Errorf("failed to watch auth directory %s: %v", w.authDir, errAddAuthDir)
return errAddAuthDir
}
log.Debugf("watching auth directory: %s", w.authDir)
// Start the event processing goroutine
go w.processEvents(ctx)
return nil
}
// Stop stops the file watcher
func (w *Watcher) Stop() error {
return w.watcher.Close()
}
// SetConfig updates the current configuration
func (w *Watcher) SetConfig(cfg *config.Config) {
w.clientsMutex.Lock()
defer w.clientsMutex.Unlock()
w.config = cfg
}
// SetClients updates the current client list
func (w *Watcher) SetClients(clients []*client.Client) {
w.clientsMutex.Lock()
defer w.clientsMutex.Unlock()
w.clients = clients
}
// processEvents handles file system events
func (w *Watcher) processEvents(ctx context.Context) {
for {
select {
case <-ctx.Done():
return
case event, ok := <-w.watcher.Events:
if !ok {
return
}
w.handleEvent(event)
case errWatch, ok := <-w.watcher.Errors:
if !ok {
return
}
log.Errorf("file watcher error: %v", errWatch)
}
}
}
// handleEvent processes individual file system events
func (w *Watcher) handleEvent(event fsnotify.Event) {
now := time.Now()
log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
// Handle config file changes
if event.Name == w.configPath && (event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create) {
log.Infof("config file changed, reloading: %s", w.configPath)
log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000"))
w.reloadConfig()
return
}
// Handle auth directory changes (only for .json files)
// Simplified: reload on any change to .json files in auth directory
if strings.HasPrefix(event.Name, w.authDir) && strings.HasSuffix(event.Name, ".json") {
log.Infof("auth file changed (%s): %s, reloading clients", event.Op.String(), filepath.Base(event.Name))
log.Debugf("auth file change details - operation: %s, file: %s, timestamp: %s",
event.Op.String(), filepath.Base(event.Name), now.Format("2006-01-02 15:04:05.000"))
w.reloadClients()
}
}
// reloadConfig reloads the configuration and triggers a full reload
func (w *Watcher) reloadConfig() {
log.Debugf("starting config reload from: %s", w.configPath)
newConfig, errLoadConfig := config.LoadConfig(w.configPath)
if errLoadConfig != nil {
log.Errorf("failed to reload config: %v", errLoadConfig)
return
}
w.clientsMutex.Lock()
oldConfig := w.config
w.config = newConfig
w.clientsMutex.Unlock()
// Log configuration changes in debug mode
if oldConfig != nil {
log.Debugf("config changes detected:")
if oldConfig.Port != newConfig.Port {
log.Debugf(" port: %d -> %d", oldConfig.Port, newConfig.Port)
}
if oldConfig.AuthDir != newConfig.AuthDir {
log.Debugf(" auth-dir: %s -> %s", oldConfig.AuthDir, newConfig.AuthDir)
}
if oldConfig.Debug != newConfig.Debug {
log.Debugf(" debug: %t -> %t", oldConfig.Debug, newConfig.Debug)
}
if oldConfig.ProxyURL != newConfig.ProxyURL {
log.Debugf(" proxy-url: %s -> %s", oldConfig.ProxyURL, newConfig.ProxyURL)
}
if len(oldConfig.APIKeys) != len(newConfig.APIKeys) {
log.Debugf(" api-keys count: %d -> %d", len(oldConfig.APIKeys), len(newConfig.APIKeys))
}
if len(oldConfig.GlAPIKey) != len(newConfig.GlAPIKey) {
log.Debugf(" generative-language-api-key count: %d -> %d", len(oldConfig.GlAPIKey), len(newConfig.GlAPIKey))
}
}
log.Infof("config successfully reloaded, triggering client reload")
// Reload clients with new config
w.reloadClients()
}
// reloadClients reloads all authentication clients
func (w *Watcher) reloadClients() {
log.Debugf("starting client reload process")
w.clientsMutex.RLock()
cfg := w.config
oldClientCount := len(w.clients)
w.clientsMutex.RUnlock()
if cfg == nil {
log.Error("config is nil, cannot reload clients")
return
}
log.Debugf("scanning auth directory: %s", cfg.AuthDir)
// Create new client list
newClients := make([]*client.Client, 0)
authFileCount := 0
successfulAuthCount := 0
// Load clients from auth directory
errWalk := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
if err != nil {
log.Debugf("error accessing path %s: %v", path, err)
return err
}
// Process only JSON files in the auth directory
if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") {
authFileCount++
log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path))
f, errOpen := os.Open(path)
if errOpen != nil {
log.Errorf("failed to open token file %s: %v", path, errOpen)
return nil // Continue processing other files
}
defer func() {
errClose := f.Close()
if errClose != nil {
log.Errorf("failed to close token file %s: %v", path, errClose)
}
}()
// Decode the token storage file
var ts auth.TokenStorage
if errDecode := json.NewDecoder(f).Decode(&ts); errDecode == nil {
// For each valid token, create an authenticated client
clientCtx := context.Background()
log.Debugf(" initializing authentication for token from %s...", filepath.Base(path))
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
if errGetClient != nil {
log.Errorf(" failed to get authenticated client for token %s: %v", path, errGetClient)
return nil // Continue processing other files
}
log.Debugf(" authentication successful for token from %s", filepath.Base(path))
// Add the new client to the pool
cliClient := client.NewClient(httpClient, &ts, cfg)
newClients = append(newClients, cliClient)
successfulAuthCount++
} else {
log.Errorf(" failed to decode token file %s: %v", path, errDecode)
}
}
return nil
})
if errWalk != nil {
log.Errorf("error walking auth directory: %v", errWalk)
return
}
log.Debugf("auth directory scan complete - found %d .json files, %d successful authentications", authFileCount, successfulAuthCount)
// Add clients for Generative Language API keys if configured
glAPIKeyCount := 0
if len(cfg.GlAPIKey) > 0 {
log.Debugf("processing %d Generative Language API keys", len(cfg.GlAPIKey))
for i := 0; i < len(cfg.GlAPIKey); i++ {
httpClient, errSetProxy := util.SetProxy(cfg, &http.Client{})
if errSetProxy != nil {
log.Errorf("set proxy failed for GL API key %d: %v", i+1, errSetProxy)
continue
}
log.Debugf(" initializing with Generative Language API key %d...", i+1)
cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
newClients = append(newClients, cliClient)
glAPIKeyCount++
}
log.Debugf("successfully initialized %d Generative Language API key clients", glAPIKeyCount)
}
// Update the client list
w.clientsMutex.Lock()
w.clients = newClients
w.clientsMutex.Unlock()
log.Infof("client reload complete - old: %d clients, new: %d clients (%d auth files + %d GL API keys)",
oldClientCount, len(newClients), successfulAuthCount, glAPIKeyCount)
// Trigger the callback to update the server
if w.reloadCallback != nil {
log.Debugf("triggering server update callback")
w.reloadCallback(newClients, cfg)
}
}