Compare commits

...

37 Commits

Author SHA1 Message Date
Luis Pater
3b4634e2dc Improve getClient logic with optional content generation flag
- Added `isGenerateContent` optional parameter to `getClient` for conditional client selection.
- Updated `gemini-handlers` to utilize the new parameter for enhanced control.
2025-07-27 02:30:08 +08:00
Luis Pater
00bd6a3e46 Update .goreleaser.yml to include config.example.yaml instead of config.yaml in release assets 2025-07-26 22:19:33 +08:00
Luis Pater
5812229d9b Add .gitignore and ignore config.yaml 2025-07-26 22:10:07 +08:00
Luis Pater
0b026933a7 Update example configuration file (config.example.yaml) 2025-07-26 22:08:25 +08:00
Luis Pater
3b2ab0d7bd Fix SSE headers initialization for geminiStreamGenerateContent and internalStreamGenerateContent
- Added conditional logic to properly initialize SSE headers only when `alt` is empty.
- Ensured headers like `Content-Type`, `Cache-Control`, and `Access-Control-Allow-Origin` are set for better compatibility.
2025-07-26 17:16:55 +08:00
Luis Pater
e64fa48823 Enhance Gemini request handling with fallback support for contents
- Added conditional logic to support `contents` as a fallback to `generateContentRequest`.
- Improved template construction and ensured proper cleanup of request fields.
- Introduced debug logging for troubleshooting request generation.
2025-07-26 17:04:14 +08:00
Luis Pater
beff9282f6 Fix alt parameter handling in URL construction
- Ensured `alt` parameter is only appended when non-empty.
- Added debug logging for constructed URLs.
2025-07-26 15:51:04 +08:00
Luis Pater
31a9e2d11f Add GeminiGetHandler, enhance Gemini functionality, and enable token counting
- Added `GeminiGetHandler` for handling GET requests with extended Gemini model support.
- Introduced `geminiCountTokens` function to calculate token usage.
- Refactored `APIRequest` and related methods to support `alt` parameter for enhanced flexibility.
- Updated routes and request processing to integrate new handler and functions.
2025-07-26 06:51:49 +08:00
Luis Pater
423faae3da Add GeminiModels handler and enhance API key validation
- Introduced `GeminiModels` handler to serve Gemini model information under `/v1beta/models`.
- Updated `AuthMiddleware` to validate API keys from query parameters for improved flexibility.
- Adjusted route to use the new handler for model retrieval.
2025-07-26 04:41:55 +08:00
Luis Pater
ead71fb7ef Improve error logging and add user guidance for issue reporting
- Added fatal log in `login.go` for Cloud AI API enablement check failures, prompting users to report issues.
- Enhanced error logging in `client.go` with warning messages directing users to copy and provide error details when creating issues.
2025-07-24 04:51:09 +08:00
Luis Pater
58b7afdf1e Enhance HTTP server with custom multiplexer in Auth flow
- Replaced default `http` handler with `http.ServeMux` for improved routing control.
- Refactored callback handling to utilize the custom multiplexer.
2025-07-23 05:09:05 +08:00
Luis Pater
c86545d7e1 Add Chinese README and update project files
- Introduced `README_CN.md` to provide detailed documentation in Chinese.
- Updated `.goreleaser.yml` to include the new README file in release assets.
- Enhanced `README.md` with a language toggle link for improved accessibility.
2025-07-21 11:23:13 +08:00
Luis Pater
f49a530c1a Refactor client handling and improve error responses
- Centralized client retrieval logic with `getClient` function for reduced redundancy.
- Simplified client rotation and error handling by removing excessive load balancing logic.
- Updated server address in `auth.go` to use dynamic binding (`:8085`).
2025-07-15 17:03:18 +08:00
Luis Pater
368796349e Add Docker support with CI/CD workflow and usage instructions
- Added `.github/workflows/docker-image.yml` for automated Docker image build and push on version tags.
- Created `Dockerfile` to containerize the application.
- Updated README with instructions for running the application using Docker.
2025-07-14 16:50:51 +08:00
Luis Pater
c601542f6f Add ClaudeMessages handler for SSE-compatible chat completions
- Introduced `ClaudeMessages` to handle Claude-compatible streaming chat completions.
- Implemented client rotation, quota management, and dynamic model name mapping for better load balancing and resource utilization.
- Enhanced response streaming with real-time chunking and Claude format conversion.
- Added error handling for quota exhaustion, client disconnections, and backend failures.
2025-07-11 13:53:09 +08:00
Luis Pater
3c0c61aaf1 Add Claude compatibility and enhance API handling
- Integrated Claude API compatibility in handlers, translators, and server routes.
- Introduced `/messages` endpoint and upgraded `AuthMiddleware` for `X-Api-Key` header.
- Improved streaming response handling with `ConvertCliToClaude` for SSE compatibility.
- Enhanced request processing and tool-response mapping in translators.
- Updated README to reflect Claude integration and clarify supported features.
2025-07-11 13:46:27 +08:00
Luis Pater
edeadfc389 Restrict CLI access to localhost and update README for Gemini compatibility
- Added localhost-only access restriction to `CLIHandler` for security.
- Updated README to reflect Gemini-compatible API and local access limitation notes.
2025-07-11 10:57:23 +08:00
Luis Pater
aa9fd057fe Add FixCLIToolResponse for enhanced function call-response mapping
- Introduced `FixCLIToolResponse` in `translator` to group function calls with corresponding responses.
- Updated Gemini handlers to integrate new function for improved response handling.
- Enhanced error handling in case response mapping fails.
2025-07-11 10:17:25 +08:00
Luis Pater
b3607d3981 Add Gemini-compatible API and improve error handling
- Introduced a new Gemini-compatible API with routes under `/v1beta`.
- Added `GeminiHandler` to manage `generateContent` and `streamGenerateContent` actions.
- Enhanced `AuthMiddleware` to support `X-Goog-Api-Key` header.
- Improved client metadata handling and added conditional project ID updates in API calls.
- Updated logging to debug raw API request payloads for better traceability.
2025-07-11 04:01:45 +08:00
Luis Pater
fa8d94971f Enhance response and request handling in translators
- Refactored response handling to process multiple content parts effectively.
- Improved `tool_calls` structure with unique ID generation and enhanced mapping logic.
- Simplified `SystemInstruction` and tool message parsing in requests for better accuracy.
- Enhanced handling of function calls and tool responses with improved data integration.
2025-07-10 22:26:04 +08:00
Luis Pater
ef68a97526 Refactor API handlers and proxy logic
- Centralized `getClient` logic into a dedicated function to reduce redundancy.
- Moved proxy initialization to a new utility function `SetProxy` in `internal/util/proxy.go`.
- Replaced `Internal` handler with `CLIHandler` in `server.go` for improved clarity and consistency.
- Removed unused functions and redundant HTTP client setup across the codebase for better maintainability.
2025-07-10 17:45:28 +08:00
Luis Pater
d880d1a1ea Set the http request header and update client metadata handling 2025-07-10 14:02:10 +08:00
Luis Pater
d4104214ed Updated README.md 2025-07-10 05:31:55 +08:00
Luis Pater
273e1d9cbe Add system instruction support and enhance internal API handlers
- Introduced `SystemInstruction` field in `PrepareRequest` and `GenerateContentRequest` for better message parsing.
- Updated `SendMessage` and `SendMessageStream` to handle system instructions in client API calls.
- Enhanced error handling and manual flushing logic in response flows.
- Added new internal API endpoints `/v1internal:generateContent` and `/v1internal:streamGenerateContent`.
- Improved proxy handling and transport logic in HTTP client initialization.
2025-07-10 05:16:54 +08:00
Luis Pater
65f47c196a Merge pull request #1 from chaudhryfaisal/main
Correct config in README.md
2025-07-09 16:57:19 +08:00
Faisal Chaudhry
9be56fe8e0 Correct config in README.md 2025-07-08 23:28:55 -04:00
Luis Pater
589ae6d3aa Add support for Generative Language API Key and improve client initialization
- Added `GlAPIKey` support in configuration to enable Generative Language API.
- Integrated `GenerativeLanguageAPIKey` handling in client and API handlers.
- Updated response translators to manage generative language responses properly.
- Enhanced HTTP client initialization logic with proxy support for API requests.
- Refactored streaming and non-streaming flows to account for generative language-specific logic.
2025-07-06 02:13:11 +08:00
Luis Pater
7cb76ae1a5 Enhance quota management and refactor configuration handling
- Introduced `QuotaExceeded` settings in configuration to handle quota limits more effectively.
- Added preview model switching logic to `Client` to automatically use fallback models on quota exhaustion.
- Refactored `APIHandlers` to leverage new configuration structure.
- Simplified server initialization and removed redundant `ServerConfig` structure.
- Streamlined client initialization by unifying configuration handling throughout the project.
- Improved error handling and response mechanisms in both streaming and non-streaming flows.
2025-07-05 07:53:46 +08:00
Luis Pater
e73f165070 Refactor API handlers to streamline response handling
- Replaced channel-based handling in `SendMessage` flow with direct synchronous execution.
- Introduced `hasFirstResponse` flag to manage keep-alive signals in streaming handler.
- Simplified error handling and removed redundant code for enhanced readability and maintainability.
2025-07-05 04:10:00 +08:00
Luis Pater
512f2d5247 Refactor API request flow and streamline response handling
- Replaced `SendMessageStream` with synchronous `SendMessage` in API handlers for better manageability.
- Simplified `ConvertCliToOpenAINonStream` to reduce complexity and improve efficiency.
- Adjusted `client.go` functions to handle both streaming and non-streaming API requests more effectively.
- Improved error handling and channel communication in API handlers.
- Removed redundant and unused code for cleaner implementation.
2025-07-05 02:27:34 +08:00
Luis Pater
bf086464dd Add archive configuration to .goreleaser.yml
- Included LICENSE, README.md, and config.yaml in the archive section for cli-proxy-api.
2025-07-04 18:50:55 +08:00
Luis Pater
5ec6450c50 Numerous Comments Added and Extensive Optimization Performed using Roo-Code with CLIProxyAPI itself. 2025-07-04 18:44:55 +08:00
Luis Pater
8dd7f8e82f Update model name to include release date in API handlers 2025-07-04 17:26:23 +08:00
Luis Pater
582280f4c5 Refactor token management, client initialization, and project handling
- Consolidated `TokenStorage` struct into `internal/auth/models.go` for better organization.
- Updated `Client` to use `TokenStorage` for managing email and project ID.
- Simplified `SetupUser` method to ensure proper token and project assignment.
- Refactored API handlers to leverage new `GetEmail` and `GetProjectID` methods in `Client`.
- Cleanup: Removed unused structures and redundant code from `client.go` and `auth.go`.
- Adjusted CLI flow in `login.go` and `run.go` for streamlined user onboarding.
2025-07-04 17:08:58 +08:00
Luis Pater
57ead9a4bc Refactor user onboarding and token management
- Enhanced the `Client` initialization to include `TokenStorage` and configuration parameters.
- Replaced `SaveTokenToFile` with a `Client` method for better encapsulation.
- Improved onboarding flow with project ID verification and API enablement checks.
- Refactored token saving logic to ensure proper handling of directory creation and JSON encoding.
- Removed unused file-related code in `auth.go` for improved maintainability.
2025-07-04 07:53:07 +08:00
Luis Pater
79acea5976 Refactor authentication and service initialization code
- Moved login and service management logic to `internal/cmd` package (`login.go` and `run.go`).
- Introduced `DoLogin` and `StartService` functions for modularity.
- Enhanced error handling by using structured `ErrorMessage` in `Client`.
- Improved token file saving process and added project-specific token identification.
- Updated API handlers to handle more detailed error responses, including status codes.
2025-07-04 00:43:15 +08:00
Luis Pater
d29245666e Add SOCKS5 and HTTP/HTTPS proxy support
- Updated `GetAuthenticatedClient` to handle proxy configuration via `proxy-url`.
- Extended `Config` to include `proxy-url` property.
- Adjusted error handling and removed unused JSON error response logic for API handlers.
- Updated documentation and configuration examples to reflect new proxy settings.
2025-07-03 16:50:20 +08:00
26 changed files with 3538 additions and 1002 deletions

42
.github/workflows/docker-image.yml vendored Normal file
View File

@@ -0,0 +1,42 @@
name: docker-image
on:
push:
tags:
- v*
env:
APP_NAME: CLIProxyAPI
DOCKERHUB_REPO: eceasy/cli-proxy-api
jobs:
docker:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Set up QEMU
uses: docker/setup-qemu-action@v3
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to DockerHub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Generate App Version
run: echo APP_VERSION=`git describe --tags --always` >> $GITHUB_ENV
- name: Build and push
uses: docker/build-push-action@v6
with:
context: .
platforms: |
linux/amd64
linux/arm64
push: true
build-args: |
APP_NAME=${{ env.APP_NAME }}
APP_VERSION=${{ env.APP_VERSION }}
tags: |
${{ env.DOCKERHUB_REPO }}:latest
${{ env.DOCKERHUB_REPO }}:${{ env.APP_VERSION }}

1
.gitignore vendored Normal file
View File

@@ -0,0 +1 @@
config.yaml

View File

@@ -8,4 +8,11 @@ builds:
- amd64
- arm64
main: ./cmd/server/
binary: cli-proxy-api
binary: cli-proxy-api
archives:
- id: "cli-proxy-api"
files:
- LICENSE
- README.md
- README_CN.md
- config.example.yaml

23
Dockerfile Normal file
View File

@@ -0,0 +1,23 @@
FROM golang:1.24-alpine AS builder
WORKDIR /app
COPY go.mod go.sum ./
RUN go mod download
COPY . .
RUN CGO_ENABLED=0 GOOS=linux go build -o ./CLIProxyAPI ./cmd/server/
FROM alpine:3.22.0
RUN mkdir /CLIProxyAPI
COPY --from=builder ./app/CLIProxyAPI /CLIProxyAPI/CLIProxyAPI
WORKDIR /CLIProxyAPI
EXPOSE 8317
CMD ["./CLIProxyAPI"]

View File

@@ -1,15 +1,19 @@
# CLI Proxy API
A proxy server that provides an OpenAI-compatible API interface for CLI. This allows you to use CLI models with tools and libraries designed for the OpenAI API.
English | [中文](README_CN.md)
A proxy server that provides an OpenAI/Gemini/Claude compatible API interface for CLI. This allows you to use CLI models with tools and libraries designed for the OpenAI/Gemini/Claude API.
## Features
- OpenAI-compatible API endpoints for CLI models
- OpenAI/Gemini/Claude compatible API endpoints for CLI models
- Support for both streaming and non-streaming responses
- Function calling/tools support
- Multimodal input support (text and images)
- Multiple account support with load balancing
- Simple CLI authentication flow
- Support for Generative Language API Key
- Support Gemini CLI with multiple account load balancing
## Installation
@@ -134,7 +138,7 @@ console.log(response.choices[0].message.content);
- gemini-2.5-pro
- gemini-2.5-flash
- And various preview versions
- And it automates switching to various preview versions
## Configuration
@@ -146,12 +150,17 @@ The server uses a YAML configuration file (`config.yaml`) located in the project
### Configuration Options
| Parameter | Type | Default | Description |
|-----------|------|--------------------|-------------|
| `port` | integer | 8317 | The port number on which the server will listen |
| `auth_dir` | string | "~/.cli-proxy-api" | Directory where authentication tokens are stored. Supports using `~` for home directory |
| `debug` | boolean | false | Enable debug mode for verbose logging |
| `api_keys` | string[] | [] | List of API keys that can be used to authenticate requests |
| Parameter | Type | Default | Description |
|---------------------------------------|----------|--------------------|----------------------------------------------------------------------------------------------|
| `port` | integer | 8317 | The port number on which the server will listen |
| `auth-dir` | string | "~/.cli-proxy-api" | Directory where authentication tokens are stored. Supports using `~` for home directory |
| `proxy-url` | string | "" | Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/ |
| `quota-exceeded` | object | {} | Configuration for handling quota exceeded |
| `quota-exceeded.switch-project` | boolean | true | Whether to automatically switch to another project when a quota is exceeded |
| `quota-exceeded.switch-preview-model` | boolean | true | Whether to automatically switch to a preview model when a quota is exceeded |
| `debug` | boolean | false | Enable debug mode for verbose logging |
| `api-keys` | string[] | [] | List of API keys that can be used to authenticate requests |
| `generative-language-api-key` | string[] | [] | List of Generative Language API keys |
### Example Configuration File
@@ -160,29 +169,76 @@ The server uses a YAML configuration file (`config.yaml`) located in the project
port: 8317
# Authentication directory (supports ~ for home directory)
auth_dir: "~/.cli-proxy-api"
auth-dir: "~/.cli-proxy-api"
# Enable debug logging
debug: false
# Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/
proxy-url: ""
# Quota exceeded behavior
quota-exceeded:
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
# API keys for authentication
api_keys:
api-keys:
- "your-api-key-1"
- "your-api-key-2"
# API keys for official Generative Language API
generative-language-api-key:
- "AIzaSy...01"
- "AIzaSy...02"
- "AIzaSy...03"
- "AIzaSy...04"
```
### Authentication Directory
The `auth_dir` parameter specifies where authentication tokens are stored. When you run the login command, the application will create JSON files in this directory containing the authentication tokens for your Google accounts. Multiple accounts can be used for load balancing.
The `auth-dir` parameter specifies where authentication tokens are stored. When you run the login command, the application will create JSON files in this directory containing the authentication tokens for your Google accounts. Multiple accounts can be used for load balancing.
### API Keys
The `api_keys` parameter allows you to define a list of API keys that can be used to authenticate requests to your proxy server. When making requests to the API, you can include one of these keys in the `Authorization` header:
The `api-keys` parameter allows you to define a list of API keys that can be used to authenticate requests to your proxy server. When making requests to the API, you can include one of these keys in the `Authorization` header:
```
Authorization: Bearer your-api-key-1
```
### Official Generative Language API
The `generative-language-api-key` parameter allows you to define a list of API keys that can be used to authenticate requests to the official Generative Language API.
## Gemini CLI with multiple account load balancing
Start CLI Proxy API server, and then set the `CODE_ASSIST_ENDPOINT` environment variable to the URL of the CLI Proxy API server.
```bash
export CODE_ASSIST_ENDPOINT="http://127.0.0.1:8317"
```
The server will relay the `loadCodeAssist`, `onboardUser`, and `countTokens` requests. And automatically load balance the text generation requests between the multiple accounts.
> [!NOTE]
> This feature only allows local access because I couldn't find a way to authenticate the requests.
> I hardcoded `127.0.0.1` into the load balancing.
## Run with Docker
Run the following command to login:
```bash
docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login
```
Run the following command to start the server:
```bash
docker run --rm -p 8317:8317 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest
```
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.

254
README_CN.md Normal file
View File

@@ -0,0 +1,254 @@
# CLI 代理 API
[English](README.md) | 中文
一个为 CLI 提供 OpenAI/Gemini/Claude 兼容 API 接口的代理服务器。这让您可以摆脱终端界面的束缚,将 Gemini 的强大能力以 API 的形式轻松接入到任何您喜爱的客户端或应用中。
## 功能特性
- 为 CLI 模型提供 OpenAI/Gemini/Claude 兼容的 API 端点
- 支持流式和非流式响应
- 函数调用/工具支持
- 多模态输入支持(文本和图像)
- 多账户支持与负载均衡
- 简单的 CLI 身份验证流程
- 支持 Gemini AIStudio API 密钥
- 支持 Gemini CLI 多账户轮询
## 安装
### 前置要求
- Go 1.24 或更高版本
- 有权访问 CLI 模型的 Google 账户
### 从源码构建
1. 克隆仓库:
```bash
git clone https://github.com/luispater/CLIProxyAPI.git
cd CLIProxyAPI
```
2. 构建应用程序:
```bash
go build -o cli-proxy-api ./cmd/server
```
## 使用方法
### 身份验证
在使用 API 之前,您需要使用 Google 账户进行身份验证:
```bash
./cli-proxy-api --login
```
如果您是旧版 gemini code 用户,可能需要指定项目 ID
```bash
./cli-proxy-api --login --project_id <your_project_id>
```
### 启动服务器
身份验证完成后,启动服务器:
```bash
./cli-proxy-api
```
默认情况下,服务器在端口 8317 上运行。
### API 端点
#### 列出模型
```
GET http://localhost:8317/v1/models
```
#### 聊天补全
```
POST http://localhost:8317/v1/chat/completions
```
请求体示例:
```json
{
"model": "gemini-2.5-pro",
"messages": [
{
"role": "user",
"content": "你好,你好吗?"
}
],
"stream": true
}
```
### 与 OpenAI 库一起使用
您可以通过将基础 URL 设置为本地服务器来将此代理与任何 OpenAI 兼容的库一起使用:
#### Python使用 OpenAI 库)
```python
from openai import OpenAI
client = OpenAI(
api_key="dummy", # 不使用但必需
base_url="http://localhost:8317/v1"
)
response = client.chat.completions.create(
model="gemini-2.5-pro",
messages=[
{"role": "user", "content": "你好,你好吗?"}
]
)
print(response.choices[0].message.content)
```
#### JavaScript/TypeScript
```javascript
import OpenAI from 'openai';
const openai = new OpenAI({
apiKey: 'dummy', // 不使用但必需
baseURL: 'http://localhost:8317/v1',
});
const response = await openai.chat.completions.create({
model: 'gemini-2.5-pro',
messages: [
{ role: 'user', content: '你好,你好吗?' }
],
});
console.log(response.choices[0].message.content);
```
## 支持的模型
- gemini-2.5-pro
- gemini-2.5-flash
- 并且自动切换到之前的预览版本
## 配置
服务器默认使用位于项目根目录的 YAML 配置文件(`config.yaml`)。您可以使用 `--config` 标志指定不同的配置文件路径:
```bash
./cli-proxy --config /path/to/your/config.yaml
```
### 配置选项
| 参数 | 类型 | 默认值 | 描述 |
|---------------------------------------|----------|--------------------|------------------------------------------------------------------------|
| `port` | integer | 8317 | 服务器监听的端口号 |
| `auth-dir` | string | "~/.cli-proxy-api" | 存储身份验证令牌的目录。支持使用 `~` 表示主目录 |
| `proxy-url` | string | "" | 代理 URL支持 socks5/http/https 协议示例socks5://user:pass@192.168.1.1:1080/ |
| `quota-exceeded` | object | {} | 处理配额超限的配置 |
| `quota-exceeded.switch-project` | boolean | true | 当配额超限时是否自动切换到另一个项目 |
| `quota-exceeded.switch-preview-model` | boolean | true | 当配额超限时是否自动切换到预览模型 |
| `debug` | boolean | false | 启用调试模式以进行详细日志记录 |
| `api-keys` | string[] | [] | 可用于验证请求的 API 密钥列表 |
| `generative-language-api-key` | string[] | [] | 生成式语言 API 密钥列表 |
### 配置文件示例
```yaml
# 服务器端口
port: 8317
# 身份验证目录(支持 ~ 表示主目录)
auth-dir: "~/.cli-proxy-api"
# 启用调试日志
debug: false
# 代理 URL支持 socks5/http/https 协议示例socks5://user:pass@192.168.1.1:1080/
proxy-url: ""
# 配额超限行为
quota-exceeded:
switch-project: true # 当配额超限时是否自动切换到另一个项目
switch-preview-model: true # 当配额超限时是否自动切换到预览模型
# 用于本地身份验证的 API 密钥
api-keys:
- "your-api-key-1"
- "your-api-key-2"
# AIStduio Gemini API 的 API 密钥
generative-language-api-key:
- "AIzaSy...01"
- "AIzaSy...02"
- "AIzaSy...03"
- "AIzaSy...04"
```
### 身份验证目录
`auth-dir` 参数指定身份验证令牌的存储位置。当您运行登录命令时,应用程序将在此目录中创建包含 Google 账户身份验证令牌的 JSON 文件。多个账户可用于轮询。
### API 密钥
`api-keys` 参数允许您定义可用于验证对代理服务器请求的 API 密钥列表。在向 API 发出请求时,您可以在 `Authorization` 标头中包含其中一个密钥:
```
Authorization: Bearer your-api-key-1
```
### 官方生成式语言 API
`generative-language-api-key` 参数允许您定义可用于验证对官方 AIStudio Gemini API 请求的 API 密钥列表。
## Gemini CLI 多账户负载均衡
启动 CLI 代理 API 服务器,然后将 `CODE_ASSIST_ENDPOINT` 环境变量设置为 CLI 代理 API 服务器的 URL。
```bash
export CODE_ASSIST_ENDPOINT="http://127.0.0.1:8317"
```
服务器将中继 `loadCodeAssist`、`onboardUser` 和 `countTokens` 请求。并自动在多个账户之间轮询文本生成请求。
> [!NOTE]
> 此功能仅允许本地访问,因为找不到一个可以验证请求的方法。
> 所以只能强制只有 `127.0.0.1` 可以访问。
## 使用 Docker 运行
运行以下命令进行登录:
```bash
docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login
```
运行以下命令启动服务器:
```bash
docker run --rm -p 8317:8317 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest
```
## 贡献
欢迎贡献!请随时提交 Pull Request。
1. Fork 仓库
2. 创建您的功能分支(`git checkout -b feature/amazing-feature`
3. 提交您的更改(`git commit -m 'Add some amazing feature'`
4. 推送到分支(`git push origin feature/amazing-feature`
5. 打开 Pull Request
## 许可证
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。

View File

@@ -2,28 +2,21 @@ package main
import (
"bytes"
"context"
"encoding/json"
"flag"
"fmt"
"github.com/luispater/CLIProxyAPI/internal/api"
"github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/cmd"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"io/fs"
"os"
"os/signal"
"path"
"path/filepath"
"strings"
"syscall"
"time"
)
// LogFormatter defines a custom log format for logrus.
type LogFormatter struct {
}
// Format renders a single log entry.
func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
var b *bytes.Buffer
if entry.Buffer != nil {
@@ -34,33 +27,42 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
timestamp := entry.Time.Format("2006-01-02 15:04:05")
var newLog string
// Customize the log format to include timestamp, level, caller file/line, and message.
newLog = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, path.Base(entry.Caller.File), entry.Caller.Line, entry.Message)
b.WriteString(newLog)
return b.Bytes(), nil
}
// init initializes the logger configuration.
func init() {
// Set logger output to standard output.
log.SetOutput(os.Stdout)
// Enable reporting the caller function's file and line number.
log.SetReportCaller(true)
// Set the custom log formatter.
log.SetFormatter(&LogFormatter{})
}
// main is the entry point of the application.
func main() {
var login bool
var projectID string
var configPath string
// Define command-line flags.
flag.BoolVar(&login, "login", false, "Login Google Account")
flag.StringVar(&projectID, "project_id", "", "Project ID")
flag.StringVar(&configPath, "config", "", "Configure File Path")
// Parse the command-line flags.
flag.Parse()
var err error
var cfg *config.Config
var wd string
// Load configuration from the specified path or the default path.
if configPath != "" {
cfg, err = config.LoadConfig(configPath)
} else {
@@ -74,12 +76,14 @@ func main() {
log.Fatalf("failed to load config: %v", err)
}
// Set the log level based on the configuration.
if cfg.Debug {
log.SetLevel(log.DebugLevel)
} else {
log.SetLevel(log.InfoLevel)
}
// Expand the tilde (~) in the auth directory path to the user's home directory.
if strings.HasPrefix(cfg.AuthDir, "~") {
home, errUserHomeDir := os.UserHomeDir()
if errUserHomeDir != nil {
@@ -94,148 +98,10 @@ func main() {
}
}
// Either perform login or start the service based on the 'login' flag.
if login {
var ts auth.TokenStorage
if projectID != "" {
ts.ProjectID = projectID
}
// 2. Initialize authenticated HTTP Client
clientCtx := context.Background()
log.Info("Initializing authentication...")
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg.AuthDir)
if errGetClient != nil {
log.Fatalf("failed to get authenticated client: %v", errGetClient)
return
}
log.Info("Authentication successful.")
// 3. Initialize CLI Client
cliClient := client.NewClient(httpClient)
if err = cliClient.SetupUser(clientCtx, ts.Email, projectID); err != nil {
if err.Error() == "failed to start user onboarding, need define a project id" {
log.Error("failed to start user onboarding")
project, errGetProjectList := cliClient.GetProjectList(clientCtx)
if errGetProjectList != nil {
log.Fatalf("failed to complete user setup: %v", err)
} else {
log.Infof("Your account %s needs specify a project id.", ts.Email)
log.Info("========================================================================")
for i := 0; i < len(project.Projects); i++ {
log.Infof("Project ID: %s", project.Projects[i].ProjectID)
log.Infof("Project Name: %s", project.Projects[i].Name)
log.Info("========================================================================")
}
log.Infof("Please run this command to login again:\n\n%s --login --project_id <project_id>\n", os.Args[0])
}
} else {
// Log as a warning because in some cases, the CLI might still be usable
// or the user might want to retry setup later.
log.Fatalf("failed to complete user setup: %v", err)
}
}
cmd.DoLogin(cfg, projectID)
} else {
// Create API server configuration
apiConfig := &api.ServerConfig{
Port: fmt.Sprintf("%d", cfg.Port),
Debug: cfg.Debug,
ApiKeys: cfg.ApiKeys,
}
cliClients := make([]*client.Client, 0)
err = filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
if err != nil {
return err
}
if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") {
log.Debugf(path)
f, errOpen := os.Open(path)
if errOpen != nil {
return errOpen
}
defer func() {
_ = f.Close()
}()
var ts auth.TokenStorage
if err = json.NewDecoder(f).Decode(&ts); err == nil {
// 2. Initialize authenticated HTTP Client
clientCtx := context.Background()
log.Info("Initializing authentication...")
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg.AuthDir)
if errGetClient != nil {
log.Fatalf("failed to get authenticated client: %v", errGetClient)
return errGetClient
}
log.Info("Authentication successful.")
// 3. Initialize CLI Client
cliClient := client.NewClient(httpClient)
if err = cliClient.SetupUser(clientCtx, ts.Email, ts.ProjectID); err != nil {
if err.Error() == "failed to start user onboarding, need define a project id" {
log.Error("failed to start user onboarding")
project, errGetProjectList := cliClient.GetProjectList(clientCtx)
if errGetProjectList != nil {
log.Fatalf("failed to complete user setup: %v", err)
} else {
log.Infof("Your account %s needs specify a project id.", ts.Email)
log.Info("========================================================================")
for i := 0; i < len(project.Projects); i++ {
log.Infof("Project ID: %s", project.Projects[i].ProjectID)
log.Infof("Project Name: %s", project.Projects[i].Name)
log.Info("========================================================================")
}
log.Infof("Please run this command to login again:\n\n%s --login --project_id <project_id>\n", os.Args[0])
}
} else {
// Log as a warning because in some cases, the CLI might still be usable
// or the user might want to retry setup later.
log.Fatalf("failed to complete user setup: %v", err)
}
} else {
cliClients = append(cliClients, cliClient)
}
}
}
return nil
})
// Create API server
apiServer := api.NewServer(apiConfig, cliClients)
log.Infof("Starting API server on port %s", apiConfig.Port)
if err = apiServer.Start(); err != nil {
log.Fatalf("API server failed to start: %v", err)
return
}
// Set up graceful shutdown
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
for {
select {
case <-sigChan:
log.Debugf("Received shutdown signal. Cleaning up...")
// Create shutdown context
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
_ = ctx // Mark ctx as used to avoid error, as apiServer.Stop(ctx) is commented out
// Stop API server
if err = apiServer.Stop(ctx); err != nil {
log.Debugf("Error stopping API server: %v", err)
}
cancel()
log.Debugf("Cleanup completed. Exiting...")
os.Exit(0)
case <-time.After(5 * time.Second):
}
}
cmd.StartService(cfg)
}
}

15
config.example.yaml Normal file
View File

@@ -0,0 +1,15 @@
port: 8317
auth-dir: "~/.cli-proxy-api"
debug: true
proxy-url: ""
quota-exceeded:
switch-project: true
switch-preview-model: true
api-keys:
- "12345"
- "23456"
generative-language-api-key:
- "AIzaSy...01"
- "AIzaSy...02"
- "AIzaSy...03"
- "AIzaSy...04"

View File

@@ -1,6 +0,0 @@
port: 8317
auth_dir: "~/.cli-proxy-api"
debug: false
api_keys:
- "12345"
- "23456"

View File

@@ -0,0 +1,188 @@
package api
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/api/translator"
"github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus"
"net/http"
"strings"
"time"
)
// ClaudeMessages handles Claude-compatible streaming chat completions.
// This function implements a sophisticated client rotation and quota management system
// to ensure high availability and optimal resource utilization across multiple backend clients.
func (h *APIHandlers) ClaudeMessages(c *gin.Context) {
// Extract raw JSON data from the incoming request
rawJson, err := c.GetRawData()
// If data retrieval fails, return a 400 Bad Request error.
if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
// Set up Server-Sent Events (SSE) headers for streaming response
// These headers are essential for maintaining a persistent connection
// and enabling real-time streaming of chat completions
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response.
// This is crucial for streaming as it allows immediate sending of data chunks
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
// Parse and prepare the Claude request, extracting model name, system instructions,
// conversation contents, and available tools from the raw JSON
modelName, systemInstruction, contents, tools := translator.PrepareClaudeRequest(rawJson)
// Map Claude model names to corresponding Gemini models
// This allows the proxy to handle Claude API calls using Gemini backend
if modelName == "claude-sonnet-4-20250514" {
modelName = "gemini-2.5-pro"
} else if modelName == "claude-3-5-haiku-20241022" {
modelName = "gemini-2.5-flash"
}
// Create a cancellable context for the backend client request
// This allows proper cleanup and cancellation of ongoing requests
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
// Ensure the client's mutex is unlocked on function exit.
// This prevents deadlocks and ensures proper resource cleanup
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
// Main client rotation loop with quota management
// This loop implements a sophisticated load balancing and failover mechanism
outLoop:
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
flusher.Flush()
cliCancel()
return
}
// Determine the authentication method being used by the selected client
// This affects how responses are formatted and logged
isGlAPIKey := false
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
isGlAPIKey = true
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Initiate streaming communication with the backend client
// This returns two channels: one for response chunks and one for errors
includeThoughts := false
if userAgent, hasKey := c.Request.Header["User-Agent"]; hasKey {
includeThoughts = !strings.Contains(userAgent[0], "claude-cli")
}
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, systemInstruction, contents, tools, includeThoughts)
// Track response state for proper Claude format conversion
hasFirstResponse := false
responseType := 0
responseIndex := 0
// Main streaming loop - handles multiple concurrent events using Go channels
// This select statement manages four different types of events simultaneously
for {
select {
// Case 1: Handle client disconnection
// Detects when the HTTP client has disconnected and cleans up resources
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() // Cancel the backend request to prevent resource leaks
return
}
// Case 2: Process incoming response chunks from the backend
// This handles the actual streaming data from the AI model
case chunk, okStream := <-respChan:
if !okStream {
// Stream has ended - send the final message_stop event
// This follows the Claude API specification for stream termination
_, _ = c.Writer.Write([]byte(`event: message_stop`))
_, _ = c.Writer.Write([]byte("\n"))
_, _ = c.Writer.Write([]byte(`data: {"type":"message_stop"}`))
_, _ = c.Writer.Write([]byte("\n\n\n"))
flusher.Flush()
cliCancel()
return
} else {
// Convert the backend response to Claude-compatible format
// This translation layer ensures API compatibility
claudeFormat := translator.ConvertCliToClaude(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex)
if claudeFormat != "" {
_, _ = c.Writer.Write([]byte(claudeFormat))
flusher.Flush() // Immediately send the chunk to the client
}
hasFirstResponse = true
}
// Case 3: Handle errors from the backend
// This manages various error conditions and implements retry logic
case errInfo, okError := <-errChan:
if okError {
// Special handling for quota exceeded errors
// If configured, attempt to switch to a different project/client
if errInfo.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue outLoop // Restart the client selection process
} else {
// Forward other errors directly to the client
c.Status(errInfo.StatusCode)
_, _ = fmt.Fprint(c.Writer, errInfo.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Case 4: Send periodic keep-alive signals
// Prevents connection timeouts during long-running requests
case <-time.After(500 * time.Millisecond):
if hasFirstResponse {
// Send a ping event to maintain the connection
// This is especially important for slow AI model responses
output := "event: ping\n"
output = output + `data: {"type": "ping"}`
output = output + "\n\n\n"
_, _ = c.Writer.Write([]byte(output))
flusher.Flush()
}
}
}
}
}

View File

@@ -0,0 +1,248 @@
package api
import (
"bytes"
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"io"
"net/http"
"strings"
"time"
)
func (h *APIHandlers) CLIHandler(c *gin.Context) {
if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") {
c.JSON(http.StatusForbidden, ErrorResponse{
Error: ErrorDetail{
Message: "CLI reply only allow local access",
Type: "forbidden",
},
})
return
}
rawJson, _ := c.GetRawData()
requestRawURI := c.Request.URL.Path
if requestRawURI == "/v1internal:generateContent" {
h.internalGenerateContent(c, rawJson)
} else if requestRawURI == "/v1internal:streamGenerateContent" {
h.internalStreamGenerateContent(c, rawJson)
} else {
reqBody := bytes.NewBuffer(rawJson)
req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody)
if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
for key, value := range c.Request.Header {
req.Header[key] = value
}
httpClient, err := util.SetProxy(h.cfg, &http.Client{})
if err != nil {
log.Fatalf("set proxy failed: %v", err)
}
resp, err := httpClient.Do(req)
if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
defer func() {
if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}()
bodyBytes, _ := io.ReadAll(resp.Body)
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: string(bodyBytes),
Type: "invalid_request_error",
},
})
return
}
defer func() {
_ = resp.Body.Close()
}()
for key, value := range resp.Header {
c.Header(key, value[0])
}
output, err := io.ReadAll(resp.Body)
if err != nil {
log.Errorf("Failed to read response body: %v", err)
return
}
_, _ = c.Writer.Write(output)
}
}
func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) {
alt := h.getAlt(c)
if alt == "" {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelResult := gjson.GetBytes(rawJson, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
// Ensure the client's mutex is unlocked on function exit.
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
outLoop:
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
flusher.Flush()
cliCancel()
return
}
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, "")
hasFirstResponse := false
for {
select {
// Handle client disconnection.
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() // Cancel the backend request.
return
}
// Process incoming response chunks.
case chunk, okStream := <-respChan:
if !okStream {
cliCancel()
return
} else {
hasFirstResponse = true
if cliClient.GetGenerativeLanguageAPIKey() != "" {
chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
}
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
flusher.Flush()
}
// Handle errors from the backend.
case err, okError := <-errChan:
if okError {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue outLoop
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Send a keep-alive signal to the client.
case <-time.After(500 * time.Millisecond):
if hasFirstResponse {
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
}
}
}
}
}
func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "application/json")
modelResult := gjson.GetBytes(rawJson, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
cliCancel()
return
}
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendRawMessage(cliCtx, rawJson, "")
if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode)
_, _ = c.Writer.Write([]byte(err.Error.Error()))
cliCancel()
}
break
} else {
_, _ = c.Writer.Write(resp)
cliCancel()
break
}
}
}

View File

@@ -0,0 +1,409 @@
package api
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/api/translator"
"github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"net/http"
"strings"
"time"
)
func (h *APIHandlers) GeminiModels(c *gin.Context) {
c.Status(http.StatusOK)
c.Header("Content-Type", "application/json; charset=UTF-8")
_, _ = c.Writer.Write([]byte(`{"models":[{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini `))
_, _ = c.Writer.Write([]byte(`2.5 Flash","description":"Stable version of Gemini 2.5 Flash, our mid-size multimod`))
_, _ = c.Writer.Write([]byte(`al model that supports up to 1 million tokens, released in June of 2025.","inputTok`))
_, _ = c.Writer.Write([]byte(`enLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateCo`))
_, _ = c.Writer.Write([]byte(`ntent","countTokens","createCachedContent","batchGenerateContent"],"temperature":1,`))
_, _ = c.Writer.Write([]byte(`"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true},{"name":"models/gemini-2.`))
_, _ = c.Writer.Write([]byte(`5-pro","version":"2.5","displayName":"Gemini 2.5 Pro","description":"Stable release`))
_, _ = c.Writer.Write([]byte(` (June 17th, 2025) of Gemini 2.5 Pro","inputTokenLimit":1048576,"outputTokenLimit":`))
_, _ = c.Writer.Write([]byte(`65536,"supportedGenerationMethods":["generateContent","countTokens","createCachedCo`))
_, _ = c.Writer.Write([]byte(`ntent","batchGenerateContent"],"temperature":1,"topP":0.95,"topK":64,"maxTemperatur`))
_, _ = c.Writer.Write([]byte(`e":2,"thinking":true}],"nextPageToken":""}`))
}
func (h *APIHandlers) GeminiGetHandler(c *gin.Context) {
var request struct {
Action string `uri:"action" binding:"required"`
}
if err := c.ShouldBindUri(&request); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
if request.Action == "gemini-2.5-pro" {
c.Status(http.StatusOK)
c.Header("Content-Type", "application/json; charset=UTF-8")
_, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-pro","version":"2.5","displayName":"Gemini 2.5 Pro",`))
_, _ = c.Writer.Write([]byte(`"description":"Stable release (June 17th, 2025) of Gemini 2.5 Pro","inputTokenL`))
_, _ = c.Writer.Write([]byte(`imit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateC`))
_, _ = c.Writer.Write([]byte(`ontent","countTokens","createCachedContent","batchGenerateContent"],"temperatur`))
_, _ = c.Writer.Write([]byte(`e":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`))
} else if request.Action == "gemini-2.5-flash" {
c.Status(http.StatusOK)
c.Header("Content-Type", "application/json; charset=UTF-8")
_, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini 2.5 Fla`))
_, _ = c.Writer.Write([]byte(`sh","description":"Stable version of Gemini 2.5 Flash, our mid-size multimodal `))
_, _ = c.Writer.Write([]byte(`model that supports up to 1 million tokens, released in June of 2025.","inputTo`))
_, _ = c.Writer.Write([]byte(`kenLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["gener`))
_, _ = c.Writer.Write([]byte(`ateContent","countTokens","createCachedContent","batchGenerateContent"],"temper`))
_, _ = c.Writer.Write([]byte(`ature":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`))
} else {
c.Status(http.StatusNotFound)
_, _ = c.Writer.Write([]byte(
`{"error":{"message":"Not Found","code":404,"status":"NOT_FOUND"}}`,
))
}
}
func (h *APIHandlers) GeminiHandler(c *gin.Context) {
var request struct {
Action string `uri:"action" binding:"required"`
}
if err := c.ShouldBindUri(&request); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
action := strings.Split(request.Action, ":")
if len(action) != 2 {
c.JSON(http.StatusNotFound, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("%s not found.", c.Request.URL.Path),
Type: "invalid_request_error",
},
})
return
}
modelName := action[0]
method := action[1]
rawJson, _ := c.GetRawData()
rawJson, _ = sjson.SetBytes(rawJson, "model", []byte(modelName))
if method == "generateContent" {
h.geminiGenerateContent(c, rawJson)
} else if method == "streamGenerateContent" {
h.geminiStreamGenerateContent(c, rawJson)
} else if method == "countTokens" {
h.geminiCountTokens(c, rawJson)
}
}
func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte) {
alt := h.getAlt(c)
if alt == "" {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelResult := gjson.GetBytes(rawJson, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
// Ensure the client's mutex is unlocked on function exit.
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
outLoop:
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
flusher.Flush()
cliCancel()
return
}
template := `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJson))
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
template, _ = sjson.Delete(template, "request.model")
template, errFixCLIToolResponse := translator.FixCLIToolResponse(template)
if errFixCLIToolResponse != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: errFixCLIToolResponse.Error(),
Type: "server_error",
},
})
cliCancel()
return
}
systemInstructionResult := gjson.Get(template, "request.system_instruction")
if systemInstructionResult.Exists() {
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
template, _ = sjson.Delete(template, "request.system_instruction")
}
rawJson = []byte(template)
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, alt)
for {
select {
// Handle client disconnection.
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() // Cancel the backend request.
return
}
// Process incoming response chunks.
case chunk, okStream := <-respChan:
if !okStream {
cliCancel()
return
} else {
if cliClient.GetGenerativeLanguageAPIKey() == "" {
if alt == "" {
responseResult := gjson.GetBytes(chunk, "response")
if responseResult.Exists() {
chunk = []byte(responseResult.Raw)
}
} else {
chunkTemplate := "[]"
responseResult := gjson.ParseBytes(chunk)
if responseResult.IsArray() {
responseResultItems := responseResult.Array()
for i := 0; i < len(responseResultItems); i++ {
responseResultItem := responseResultItems[i]
if responseResultItem.Get("response").Exists() {
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
}
}
}
chunk = []byte(chunkTemplate)
}
}
if alt == "" {
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
} else {
_, _ = c.Writer.Write(chunk)
}
flusher.Flush()
}
// Handle errors from the backend.
case err, okError := <-errChan:
if okError {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue outLoop
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Send a keep-alive signal to the client.
case <-time.After(500 * time.Millisecond):
}
}
}
}
func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "application/json")
alt := h.getAlt(c)
// orgRawJson := rawJson
modelResult := gjson.GetBytes(rawJson, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName, false)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
cliCancel()
return
}
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
template := `{"request":{}}`
if gjson.GetBytes(rawJson, "generateContentRequest").Exists() {
template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJson, "generateContentRequest").Raw)
template, _ = sjson.Delete(template, "generateContentRequest")
} else if gjson.GetBytes(rawJson, "contents").Exists() {
template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJson, "contents").Raw)
template, _ = sjson.Delete(template, "contents")
}
rawJson = []byte(template)
}
resp, err := cliClient.SendRawTokenCount(cliCtx, rawJson, alt)
if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode)
_, _ = c.Writer.Write([]byte(err.Error.Error()))
cliCancel()
// log.Debugf(err.Error.Error())
// log.Debugf(string(rawJson))
// log.Debugf(string(orgRawJson))
}
break
} else {
if cliClient.GetGenerativeLanguageAPIKey() == "" {
responseResult := gjson.GetBytes(resp, "response")
if responseResult.Exists() {
resp = []byte(responseResult.Raw)
}
}
_, _ = c.Writer.Write(resp)
cliCancel()
break
}
}
}
func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "application/json")
alt := h.getAlt(c)
modelResult := gjson.GetBytes(rawJson, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
cliCancel()
return
}
template := `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJson))
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
template, _ = sjson.Delete(template, "request.model")
template, errFixCLIToolResponse := translator.FixCLIToolResponse(template)
if errFixCLIToolResponse != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: errFixCLIToolResponse.Error(),
Type: "server_error",
},
})
cliCancel()
return
}
systemInstructionResult := gjson.Get(template, "request.system_instruction")
if systemInstructionResult.Exists() {
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
template, _ = sjson.Delete(template, "request.system_instruction")
}
rawJson = []byte(template)
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendRawMessage(cliCtx, rawJson, alt)
if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode)
_, _ = c.Writer.Write([]byte(err.Error.Error()))
cliCancel()
}
break
} else {
if cliClient.GetGenerativeLanguageAPIKey() == "" {
responseResult := gjson.GetBytes(resp, "response")
if responseResult.Exists() {
resp = []byte(responseResult.Raw)
}
}
_, _ = c.Writer.Write(resp)
cliCancel()
break
}
}
}
func (h *APIHandlers) getAlt(c *gin.Context) string {
var alt string
var hasAlt bool
alt, hasAlt = c.GetQuery("alt")
if !hasAlt {
alt, _ = c.GetQuery("$alt")
}
if alt == "sse" {
return ""
}
return alt
}

View File

@@ -2,14 +2,13 @@ package api
import (
"context"
"encoding/json"
"fmt"
"github.com/luispater/CLIProxyAPI/internal/api/translator"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"net/http"
"strings"
"sync"
"time"
@@ -21,63 +20,27 @@ var (
lastUsedClientIndex = 0
)
// APIHandlers contains the handlers for API endpoints
// APIHandlers contains the handlers for API endpoints.
// It holds a pool of clients to interact with the backend service.
type APIHandlers struct {
cliClients []*client.Client
debug bool
cfg *config.Config
}
// NewAPIHandlers creates a new API handlers instance
func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers {
// NewAPIHandlers creates a new API handlers instance.
// It takes a slice of clients and a debug flag as input.
func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers {
return &APIHandlers{
cliClients: cliClients,
debug: debug,
cfg: cfg,
}
}
// Models handles the /v1/models endpoint.
// It returns a hardcoded list of available AI models.
func (h *APIHandlers) Models(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"data": []map[string]any{
{
"id": "gemini-2.5-pro-preview-05-06",
"object": "model",
"version": "2.5-preview-05-06",
"name": "Gemini 2.5 Pro Preview 05-06",
"description": "Preview release (May 6th, 2025) of Gemini 2.5 Pro",
"context_length": 1048576,
"max_completion_tokens": 65536,
"supported_parameters": []string{
"tools",
"temperature",
"top_p",
"top_k",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"id": "gemini-2.5-pro-preview-06-05",
"object": "model",
"version": "2.5-preview-06-05",
"name": "Gemini 2.5 Pro Preview",
"description": "Preview release (June 5th, 2025) of Gemini 2.5 Pro",
"context_length": 1048576,
"max_completion_tokens": 65536,
"supported_parameters": []string{
"tools",
"temperature",
"top_p",
"top_k",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"id": "gemini-2.5-pro",
"object": "model",
@@ -98,46 +61,6 @@ func (h *APIHandlers) Models(c *gin.Context) {
"maxTemperature": 2,
"thinking": true,
},
{
"id": "gemini-2.5-flash-preview-04-17",
"object": "model",
"version": "2.5-preview-04-17",
"name": "Gemini 2.5 Flash Preview 04-17",
"description": "Preview release (April 17th, 2025) of Gemini 2.5 Flash",
"context_length": 1048576,
"max_completion_tokens": 65536,
"supported_parameters": []string{
"tools",
"temperature",
"top_p",
"top_k",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"id": "gemini-2.5-flash-preview-05-20",
"object": "model",
"version": "2.5-preview-05-20",
"name": "Gemini 2.5 Flash Preview 05-20",
"description": "Preview release (April 17th, 2025) of Gemini 2.5 Flash",
"context_length": 1048576,
"max_completion_tokens": 65536,
"supported_parameters": []string{
"tools",
"temperature",
"top_p",
"top_k",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"id": "gemini-2.5-flash",
"object": "model",
@@ -162,241 +85,41 @@ func (h *APIHandlers) Models(c *gin.Context) {
})
}
// ChatCompletions handles the /v1/chat/completions endpoint
func (h *APIHandlers) ChatCompletions(c *gin.Context) {
rawJson, err := c.GetRawData()
// If data retrieval fails, return 400 error
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request: %v", err), "code": 400})
return
func (h *APIHandlers) getClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) {
if len(h.cliClients) == 0 {
return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")}
}
streamResult := gjson.GetBytes(rawJson, "stream")
if streamResult.Type == gjson.True {
h.handleStreamingResponse(c, rawJson)
} else {
h.handleNonStreamingResponse(c, rawJson)
}
}
func (h *APIHandlers) prepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) {
// log.Debug(string(rawJson))
modelName := "gemini-2.5-pro"
modelResult := gjson.GetBytes(rawJson, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}
contents := make([]client.Content, 0)
messagesResult := gjson.GetBytes(rawJson, "messages")
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
contentResult := messageResult.Get("content")
if roleResult.Type == gjson.String {
if roleResult.String() == "system" {
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsObject() {
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentResult.Get("text")
if contentTextResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
}
}
}
} else if roleResult.String() == "user" {
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsObject() {
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentResult.Get("text")
if contentTextResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
}
}
} else if contentResult.IsArray() {
contentItemResults := contentResult.Array()
parts := make([]client.Part, 0)
for j := 0; j < len(contentItemResults); j++ {
contentItemResult := contentItemResults[j]
contentTypeResult := contentItemResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentItemResult.Get("text")
if contentTextResult.Type == gjson.String {
parts = append(parts, client.Part{Text: contentTextResult.String()})
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image_url" {
imageURLResult := contentItemResult.Get("image_url.url")
if imageURLResult.Type == gjson.String {
imageURL := imageURLResult.String()
if len(imageURL) > 5 {
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
if len(imageURLs) == 2 {
if len(imageURLs[1]) > 7 {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: imageURLs[0],
Data: imageURLs[1][7:],
}})
}
}
}
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "file" {
filenameResult := contentItemResult.Get("file.filename")
fileDataResult := contentItemResult.Get("file.file_data")
if filenameResult.Type == gjson.String && fileDataResult.Type == gjson.String {
filename := filenameResult.String()
splitFilename := strings.Split(filename, ".")
ext := splitFilename[len(splitFilename)-1]
mimeType, ok := MimeTypes[ext]
if !ok {
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
continue
}
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: mimeType,
Data: fileDataResult.String(),
}})
}
}
}
contents = append(contents, client.Content{Role: "user", Parts: parts})
}
} else if roleResult.String() == "assistant" {
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsObject() {
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentResult.Get("text")
if contentTextResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
}
}
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
toolCallsResult := messageResult.Get("tool_calls")
if toolCallsResult.IsArray() {
tcsResult := toolCallsResult.Array()
for j := 0; j < len(tcsResult); j++ {
tcResult := tcsResult[j]
functionNameResult := tcResult.Get("function.name")
functionArguments := tcResult.Get("function.arguments")
if functionNameResult.Exists() && functionNameResult.Type == gjson.String && functionArguments.Exists() && functionArguments.Type == gjson.String {
var args map[string]any
err := json.Unmarshal([]byte(functionArguments.String()), &args)
if err == nil {
contents = append(contents, client.Content{
Role: "model", Parts: []client.Part{
{
FunctionCall: &client.FunctionCall{
Name: functionNameResult.String(),
Args: args,
},
},
},
})
}
}
}
}
}
} else if roleResult.String() == "tool" {
toolCallIDResult := messageResult.Get("tool_call_id")
if toolCallIDResult.Exists() && toolCallIDResult.Type == gjson.String {
if contentResult.Type == gjson.String {
functionResponse := client.FunctionResponse{Name: toolCallIDResult.String(), Response: map[string]interface{}{"result": contentResult.String()}}
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
} else if contentResult.IsObject() {
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentResult.Get("text")
if contentTextResult.Type == gjson.String {
functionResponse := client.FunctionResponse{Name: toolCallIDResult.String(), Response: map[string]interface{}{"result": contentResult.String()}}
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
}
}
}
}
}
}
}
}
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJson, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolTypeResult := toolsResults[i].Get("type")
if toolTypeResult.Type != gjson.String || toolTypeResult.String() != "function" {
continue
}
functionTypeResult := toolsResults[i].Get("function")
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
var functionDeclaration any
err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration)
if err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
return modelName, contents, tools
}
// handleNonStreamingResponse handles non-streaming responses
func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "application/json")
// Handle streaming manually
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelName, contents, tools := h.prepareRequest(rawJson)
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
// Lock the mutex to update the last used page index
// Lock the mutex to update the last used client index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 {
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
}
mutex.Unlock()
// Reorder the pages to start from the last used index
reorderedPages := make([]*client.Client, len(h.cliClients))
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
return nil, &client.ErrorMessage{StatusCode: 429, Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)}
}
locked := false
for i := 0; i < len(reorderedPages); i++ {
cliClient = reorderedPages[i]
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
@@ -407,40 +130,84 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
cliClient.RequestMutex.Lock()
}
log.Debugf("Request use account: %s", cliClient.Email)
jsonTemplate := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
return cliClient, nil
}
// ChatCompletions handles the /v1/chat/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler.
func (h *APIHandlers) ChatCompletions(c *gin.Context) {
rawJson, err := c.GetRawData()
// If data retrieval fails, return a 400 Bad Request error.
if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
// Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJson, "stream")
if streamResult.Type == gjson.True {
h.handleStreamingResponse(c, rawJson)
} else {
h.handleNonStreamingResponse(c, rawJson)
}
}
// handleNonStreamingResponse handles non-streaming chat completion responses.
// It selects a client from the pool, sends the request, and aggregates the response
// before sending it back to the client.
func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "application/json")
modelName, systemInstruction, contents, tools := translator.PrepareRequest(rawJson)
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
for {
select {
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel()
return
}
case chunk, okStream := <-respChan:
if !okStream {
_, _ = fmt.Fprint(c.Writer, jsonTemplate)
flusher.Flush()
cliCancel()
return
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
cliCancel()
return
}
isGlAPIKey := false
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
isGlAPIKey = true
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, systemInstruction, contents, tools)
if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue
} else {
jsonTemplate = h.convertCliToOpenAINonStream(jsonTemplate, chunk)
}
case err, okError := <-errChan:
if okError {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: err.Error(),
Type: "server_error",
},
})
c.Status(err.StatusCode)
_, _ = c.Writer.Write([]byte(err.Error.Error()))
cliCancel()
return
}
case <-time.After(500 * time.Millisecond):
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
break
} else {
openAIFormat := translator.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey)
if openAIFormat != "" {
_, _ = c.Writer.Write([]byte(openAIFormat))
}
cliCancel()
break
}
}
}
@@ -452,7 +219,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Handle streaming manually
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{
@@ -463,262 +230,86 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
})
return
}
modelName, contents, tools := h.prepareRequest(rawJson)
// Prepare the request for the backend client.
modelName, systemInstruction, contents, tools := translator.PrepareRequest(rawJson)
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
// Ensure the client's mutex is unlocked on function exit.
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
// Lock the mutex to update the last used page index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
mutex.Unlock()
// Reorder the pages to start from the last used index
reorderedPages := make([]*client.Client, len(h.cliClients))
for i := 0; i < len(h.cliClients); i++ {
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
}
locked := false
for i := 0; i < len(reorderedPages); i++ {
cliClient = reorderedPages[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
log.Debugf("Request use account: %s", cliClient.Email)
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
outLoop:
for {
select {
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel()
return
}
case chunk, okStream := <-respChan:
if !okStream {
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel()
return
} else {
openAIFormat := h.convertCliToOpenAI(chunk)
if openAIFormat != "" {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
flusher.Flush()
cliCancel()
return
}
isGlAPIKey := false
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
isGlAPIKey = true
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, systemInstruction, contents, tools)
hasFirstResponse := false
for {
select {
// Handle client disconnection.
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() // Cancel the backend request.
return
}
// Process incoming response chunks.
case chunk, okStream := <-respChan:
if !okStream {
// Stream is closed, send the final [DONE] message.
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel()
return
} else {
// Convert the chunk to OpenAI format and send it to the client.
hasFirstResponse = true
openAIFormat := translator.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey)
if openAIFormat != "" {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
flusher.Flush()
}
}
// Handle errors from the backend.
case err, okError := <-errChan:
if okError {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue outLoop
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Send a keep-alive signal to the client.
case <-time.After(500 * time.Millisecond):
if hasFirstResponse {
_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
flusher.Flush()
}
}
case err, okError := <-errChan:
if okError {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: err.Error(),
Type: "server_error",
},
})
cliCancel()
return
}
case <-time.After(500 * time.Millisecond):
_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
flusher.Flush()
}
}
}
func (h *APIHandlers) convertCliToOpenAI(rawJson []byte) string {
// log.Debugf(string(rawJson))
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion")
if modelVersionResult.Exists() && modelVersionResult.Type == gjson.String {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
createTimeResult := gjson.GetBytes(rawJson, "response.createTime")
if createTimeResult.Exists() && createTimeResult.Type == gjson.String {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
var unixTimestamp int64
if err == nil {
unixTimestamp = t.Unix()
} else {
unixTimestamp = time.Now().Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
}
responseIdResult := gjson.GetBytes(rawJson, "response.responseId")
if responseIdResult.Exists() && responseIdResult.Type == gjson.String {
template, _ = sjson.Set(template, "id", responseIdResult.String())
}
finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason")
if finishReasonResult.Exists() && finishReasonResult.Type == gjson.String {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
candidatesTokenCountResult := usageResult.Get("candidatesTokenCount")
if candidatesTokenCountResult.Exists() && candidatesTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
totalTokenCountResult := usageResult.Get("totalTokenCount")
if totalTokenCountResult.Exists() && totalTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
thoughtsTokenCountResult := usageResult.Get("thoughtsTokenCount")
promptTokenCountResult := usageResult.Get("promptTokenCount")
if promptTokenCountResult.Exists() && promptTokenCountResult.Type == gjson.Number {
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int()+thoughtsTokenCountResult.Int())
} else {
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int())
}
}
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCountResult.Int())
}
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() && partTextResult.Type == gjson.String {
partThoughtResult := partResult.Get("thought")
if partThoughtResult.Exists() && partThoughtResult.Type == gjson.True {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]`
fcNameResult := functionCallResult.Get("name")
if fcNameResult.Exists() && fcNameResult.Type == gjson.String {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcNameResult.String())
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcNameResult.String())
}
fcArgsResult := functionCallResult.Get("args")
if fcArgsResult.Exists() && fcArgsResult.IsObject() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate)
} else {
return ""
}
return template
}
func (h *APIHandlers) convertCliToOpenAINonStream(template string, rawJson []byte) string {
modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion")
if modelVersionResult.Exists() && modelVersionResult.Type == gjson.String {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
createTimeResult := gjson.GetBytes(rawJson, "response.createTime")
if createTimeResult.Exists() && createTimeResult.Type == gjson.String {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
var unixTimestamp int64
if err == nil {
unixTimestamp = t.Unix()
} else {
unixTimestamp = time.Now().Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
}
responseIdResult := gjson.GetBytes(rawJson, "response.responseId")
if responseIdResult.Exists() && responseIdResult.Type == gjson.String {
template, _ = sjson.Set(template, "id", responseIdResult.String())
}
finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason")
if finishReasonResult.Exists() && finishReasonResult.Type == gjson.String {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
candidatesTokenCountResult := usageResult.Get("candidatesTokenCount")
if candidatesTokenCountResult.Exists() && candidatesTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
totalTokenCountResult := usageResult.Get("totalTokenCount")
if totalTokenCountResult.Exists() && totalTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
thoughtsTokenCountResult := usageResult.Get("thoughtsTokenCount")
promptTokenCountResult := usageResult.Get("promptTokenCount")
if promptTokenCountResult.Exists() && promptTokenCountResult.Type == gjson.Number {
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int()+thoughtsTokenCountResult.Int())
} else {
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int())
}
}
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCountResult.Int())
}
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() && partTextResult.Type == gjson.String {
partThoughtResult := partResult.Get("thought")
if partThoughtResult.Exists() && partThoughtResult.Type == gjson.True {
reasoningContentResult := gjson.Get(template, "choices.0.message.reasoning_content")
if reasoningContentResult.Type == gjson.String {
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningContentResult.String()+partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String())
}
} else {
reasoningContentResult := gjson.Get(template, "choices.0.message.content")
if reasoningContentResult.Type == gjson.String {
template, _ = sjson.Set(template, "choices.0.message.content", reasoningContentResult.String()+partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String())
}
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
} else if functionCallResult.Exists() {
toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls")
if !toolCallsResult.Exists() || toolCallsResult.Type == gjson.Null {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
}
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcNameResult := functionCallResult.Get("name")
if fcNameResult.Exists() && fcNameResult.Type == gjson.String {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcNameResult.String())
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcNameResult.String())
}
fcArgsResult := functionCallResult.Get("args")
if fcArgsResult.Exists() && fcArgsResult.IsObject() {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
} else {
return ""
}
return template
}

View File

@@ -1,13 +1,18 @@
package api
// ErrorResponse represents an error response
// ErrorResponse represents a standard error response format for the API.
// It contains a single ErrorDetail field.
type ErrorResponse struct {
Error ErrorDetail `json:"error"`
}
// ErrorDetail represents error details
// ErrorDetail provides specific information about an error that occurred.
// It includes a human-readable message, an error type, and an optional error code.
type ErrorDetail struct {
// A human-readable message providing more details about the error.
Message string `json:"message"`
Type string `json:"type"`
Code string `json:"code,omitempty"`
// The type of error that occurred (e.g., "invalid_request_error").
Type string `json:"type"`
// A short code identifying the error, if applicable.
Code string `json:"code,omitempty"`
}

View File

@@ -6,35 +6,31 @@ import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"net/http"
"strings"
)
// Server represents the API server
// Server represents the main API server.
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
type Server struct {
engine *gin.Engine
server *http.Server
handlers *APIHandlers
cfg *ServerConfig
cfg *config.Config
}
// ServerConfig contains configuration for the API server
type ServerConfig struct {
Port string
Debug bool
ApiKeys []string
}
// NewServer creates a new API server instance
func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
// NewServer creates and initializes a new API server instance.
// It sets up the Gin engine, middleware, routes, and handlers.
func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
// Set gin mode
if !config.Debug {
if !cfg.Debug {
gin.SetMode(gin.ReleaseMode)
}
// Create handlers
handlers := NewAPIHandlers(cliClients, config.Debug)
handlers := NewAPIHandlers(cliClients, cfg)
// Create gin engine
engine := gin.New()
@@ -48,7 +44,7 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
s := &Server{
engine: engine,
handlers: handlers,
cfg: config,
cfg: cfg,
}
// Setup routes
@@ -56,14 +52,15 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
// Create HTTP server
s.server = &http.Server{
Addr: ":" + config.Port,
Addr: fmt.Sprintf(":%d", cfg.Port),
Handler: engine,
}
return s
}
// setupRoutes configures the API routes
// setupRoutes configures the API routes for the server.
// It defines the endpoints and associates them with their respective handlers.
func (s *Server) setupRoutes() {
// OpenAI compatible API routes
v1 := s.engine.Group("/v1")
@@ -71,6 +68,16 @@ func (s *Server) setupRoutes() {
{
v1.GET("/models", s.handlers.Models)
v1.POST("/chat/completions", s.handlers.ChatCompletions)
v1.POST("/messages", s.handlers.ClaudeMessages)
}
// Gemini compatible API routes
v1beta := s.engine.Group("/v1beta")
v1beta.Use(AuthMiddleware(s.cfg))
{
v1beta.GET("/models", s.handlers.GeminiModels)
v1beta.POST("/models/:action", s.handlers.GeminiHandler)
v1beta.GET("/models/:action", s.handlers.GeminiGetHandler)
}
// Root endpoint
@@ -84,13 +91,16 @@ func (s *Server) setupRoutes() {
},
})
})
s.engine.POST("/v1internal:method", s.handlers.CLIHandler)
}
// Start starts the API server
// Start begins listening for and serving HTTP requests.
// It's a blocking call and will only return on an unrecoverable error.
func (s *Server) Start() error {
log.Debugf("Starting API server on %s", s.server.Addr)
// Start the HTTP server
// Start the HTTP server.
if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to start HTTP server: %v", err)
}
@@ -98,11 +108,12 @@ func (s *Server) Start() error {
return nil
}
// Stop gracefully stops the API server
// Stop gracefully shuts down the API server without interrupting any
// active connections.
func (s *Server) Stop(ctx context.Context) error {
log.Debug("Stopping API server...")
// Shutdown the HTTP server
// Shutdown the HTTP server.
if err := s.server.Shutdown(ctx); err != nil {
return fmt.Errorf("failed to shutdown HTTP server: %v", err)
}
@@ -111,7 +122,8 @@ func (s *Server) Stop(ctx context.Context) error {
return nil
}
// corsMiddleware adds CORS headers
// corsMiddleware returns a Gin middleware handler that adds CORS headers
// to every response, allowing cross-origin requests.
func corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
c.Header("Access-Control-Allow-Origin", "*")
@@ -127,8 +139,9 @@ func corsMiddleware() gin.HandlerFunc {
}
}
// AuthMiddleware authenticates requests using API keys
func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc {
// AuthMiddleware returns a Gin middleware handler that authenticates requests
// using API keys. If no API keys are configured, it allows all requests.
func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
return func(c *gin.Context) {
if len(cfg.ApiKeys) == 0 {
c.Next()
@@ -137,7 +150,13 @@ func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc {
// Get the Authorization header
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
authHeaderGoogle := c.GetHeader("X-Goog-Api-Key")
authHeaderAnthropic := c.GetHeader("X-Api-Key")
// Get the API key from the query parameter
apiKeyQuery, _ := c.GetQuery("key")
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && apiKeyQuery == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "Missing API key",
})
@@ -156,7 +175,7 @@ func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc {
// Find the API key in the in-memory list
var foundKey string
for i := range cfg.ApiKeys {
if cfg.ApiKeys[i] == apiKey {
if cfg.ApiKeys[i] == apiKey || cfg.ApiKeys[i] == authHeaderGoogle || cfg.ApiKeys[i] == authHeaderAnthropic || cfg.ApiKeys[i] == apiKeyQuery {
foundKey = cfg.ApiKeys[i]
break
}

View File

@@ -1,5 +1,7 @@
package api
package translator
// MimeTypes is a comprehensive map of file extensions to their corresponding MIME types.
// This is used to identify the type of file being uploaded or processed.
var MimeTypes = map[string]string{
"ez": "application/andrew-inset",
"aw": "application/applixware",

View File

@@ -0,0 +1,544 @@
package translator
import (
"bytes"
"encoding/json"
"fmt"
"github.com/tidwall/sjson"
"strings"
"github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// PrepareRequest translates a raw JSON request from an OpenAI-compatible format
// to the internal format expected by the backend client. It parses messages,
// roles, content types (text, image, file), and tool calls.
func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
modelName := "gemini-2.5-pro"
modelResult := gjson.GetBytes(rawJson, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}
// Initialize data structures for processing conversation messages
// contents: stores the processed conversation history
// systemInstruction: stores system-level instructions separate from conversation
contents := make([]client.Content, 0)
var systemInstruction *client.Content
messagesResult := gjson.GetBytes(rawJson, "messages")
// Pre-process tool responses to create a lookup map
// This first pass collects all tool responses so they can be matched with their corresponding calls
toolItems := make(map[string]*client.FunctionResponse)
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
}
contentResult := messageResult.Get("content")
// Extract tool responses for later matching with function calls
if roleResult.String() == "tool" {
toolCallID := messageResult.Get("tool_call_id").String()
if toolCallID != "" {
var responseData string
// Handle both string and object-based tool response formats
if contentResult.Type == gjson.String {
responseData = contentResult.String()
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
responseData = contentResult.Get("text").String()
}
// Clean up tool call ID by removing timestamp suffix
// This normalizes IDs for consistent matching between calls and responses
toolCallIDs := strings.Split(toolCallID, "-")
strings.Join(toolCallIDs, "-")
newToolCallID := strings.Join(toolCallIDs[:len(toolCallIDs)-1], "-")
// Create function response object with normalized ID and response data
functionResponse := client.FunctionResponse{Name: newToolCallID, Response: map[string]interface{}{"result": responseData}}
toolItems[toolCallID] = &functionResponse
}
}
}
}
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
contentResult := messageResult.Get("content")
if roleResult.Type != gjson.String {
continue
}
switch roleResult.String() {
// System messages are converted to a user message followed by a model's acknowledgment.
case "system":
if contentResult.Type == gjson.String {
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}
} else if contentResult.IsObject() {
// Handle object-based system messages.
if contentResult.Get("type").String() == "text" {
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}}
}
}
// User messages can contain simple text or a multi-part body.
case "user":
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsArray() {
// Handle multi-part user messages (text, images, files).
contentItemResults := contentResult.Array()
parts := make([]client.Part, 0)
for j := 0; j < len(contentItemResults); j++ {
contentItemResult := contentItemResults[j]
contentTypeResult := contentItemResult.Get("type")
switch contentTypeResult.String() {
case "text":
parts = append(parts, client.Part{Text: contentItemResult.Get("text").String()})
case "image_url":
// Parse data URI for images.
imageURL := contentItemResult.Get("image_url.url").String()
if len(imageURL) > 5 {
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
if len(imageURLs) == 2 && len(imageURLs[1]) > 7 {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: imageURLs[0],
Data: imageURLs[1][7:],
}})
}
}
case "file":
// Handle file attachments by determining MIME type from extension.
filename := contentItemResult.Get("file.filename").String()
fileData := contentItemResult.Get("file.file_data").String()
ext := ""
if split := strings.Split(filename, "."); len(split) > 1 {
ext = split[len(split)-1]
}
if mimeType, ok := MimeTypes[ext]; ok {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: mimeType,
Data: fileData,
}})
} else {
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
}
}
}
contents = append(contents, client.Content{Role: "user", Parts: parts})
}
// Assistant messages can contain text responses or tool calls
// In the internal format, assistant messages are converted to "model" role
case "assistant":
if contentResult.Type == gjson.String {
// Simple text response from the assistant
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
// Handle complex tool calls made by the assistant
// This processes function calls and matches them with their responses
functionIDs := make([]string, 0)
toolCallsResult := messageResult.Get("tool_calls")
if toolCallsResult.IsArray() {
parts := make([]client.Part, 0)
tcsResult := toolCallsResult.Array()
// Process each tool call in the assistant's message
for j := 0; j < len(tcsResult); j++ {
tcResult := tcsResult[j]
// Extract function call details
functionID := tcResult.Get("id").String()
functionIDs = append(functionIDs, functionID)
functionName := tcResult.Get("function.name").String()
functionArgs := tcResult.Get("function.arguments").String()
// Parse function arguments from JSON string to map
var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
parts = append(parts, client.Part{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
})
}
}
// Add the model's function calls to the conversation
if len(parts) > 0 {
contents = append(contents, client.Content{
Role: "model", Parts: parts,
})
// Create a separate tool response message with the collected responses
// This matches function calls with their corresponding responses
toolParts := make([]client.Part, 0)
for _, functionID := range functionIDs {
if functionResponse, ok := toolItems[functionID]; ok {
toolParts = append(toolParts, client.Part{FunctionResponse: functionResponse})
}
}
// Add the tool responses as a separate message in the conversation
contents = append(contents, client.Content{Role: "tool", Parts: toolParts})
}
}
}
}
}
}
// Translate the tool declarations from the request.
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJson, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
if toolResult.Get("type").String() == "function" {
functionTypeResult := toolResult.Get("function")
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
var functionDeclaration any
if err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
}
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
return modelName, systemInstruction, contents, tools
}
// FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct {
ModelContent map[string]interface{}
FunctionCalls []gjson.Result
ResponsesNeeded int
}
// FixCLIToolResponse performs sophisticated tool response format conversion and grouping.
// This function transforms the CLI tool response format by intelligently grouping function calls
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
// It converts from a linear format (1.json) to a grouped format (2.json) where function calls
// and their responses are properly associated and structured.
func FixCLIToolResponse(input string) (string, error) {
// Parse the input JSON to extract the conversation structure
parsed := gjson.Parse(input)
// Extract the contents array which contains the conversation messages
contents := parsed.Get("request.contents")
if !contents.Exists() {
return input, fmt.Errorf("contents not found in input")
}
// Initialize data structures for processing and grouping
var newContents []interface{} // Final processed contents array
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
var collectedResponses []gjson.Result // Standalone responses to be matched
// Process each content object in the conversation
// This iterates through messages and groups function calls with their responses
contents.ForEach(func(key, value gjson.Result) bool {
role := value.Get("role").String()
parts := value.Get("parts")
// Check if this content has function responses
var responsePartsInThisContent []gjson.Result
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionResponse").Exists() {
responsePartsInThisContent = append(responsePartsInThisContent, part)
}
return true
})
// If this content has function responses, collect them
if len(responsePartsInThisContent) > 0 {
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
// Check if any pending groups can be satisfied
for i := len(pendingGroups) - 1; i >= 0; i-- {
group := pendingGroups[i]
if len(collectedResponses) >= group.ResponsesNeeded {
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content
var responseParts []interface{}
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue
}
responseParts = append(responseParts, responseMap)
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
}
// Remove this group as it's been satisfied
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
break
}
}
return true // Skip adding this content, responses are merged
}
// If this is a model with function calls, create a new group
if role == "model" {
var functionCallsInThisModel []gjson.Result
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
functionCallsInThisModel = append(functionCallsInThisModel, part)
}
return true
})
if len(functionCallsInThisModel) > 0 {
// Add the model content
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
// Create a new group for tracking responses
group := &FunctionCallGroup{
ModelContent: contentMap,
FunctionCalls: functionCallsInThisModel,
ResponsesNeeded: len(functionCallsInThisModel),
}
pendingGroups = append(pendingGroups, group)
} else {
// Regular model content without function calls
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
}
} else {
// Non-model content (user, etc.)
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
}
return true
})
// Handle any remaining pending groups with remaining responses
for _, group := range pendingGroups {
if len(collectedResponses) >= group.ResponsesNeeded {
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
var responseParts []interface{}
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue
}
responseParts = append(responseParts, responseMap)
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
}
}
}
// Update the original JSON with the new contents
result := input
newContentsJSON, _ := json.Marshal(newContents)
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
return result, nil
}
func PrepareClaudeRequest(rawJson []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
var pathsToDelete []string
root := gjson.ParseBytes(rawJson)
walk(root, "", "additionalProperties", &pathsToDelete)
walk(root, "", "$schema", &pathsToDelete)
var err error
for _, p := range pathsToDelete {
rawJson, err = sjson.DeleteBytes(rawJson, p)
if err != nil {
continue
}
}
rawJson = bytes.Replace(rawJson, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// log.Debug(string(rawJson))
modelName := "gemini-2.5-pro"
modelResult := gjson.GetBytes(rawJson, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}
contents := make([]client.Content, 0)
var systemInstruction *client.Content
systemResult := gjson.GetBytes(rawJson, "system")
if systemResult.IsArray() {
systemResults := systemResult.Array()
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
for i := 0; i < len(systemResults); i++ {
systemPromptResult := systemResults[i]
systemTypePromptResult := systemPromptResult.Get("type")
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
systemPrompt := systemPromptResult.Get("text").String()
systemPart := client.Part{Text: systemPrompt}
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
}
}
if len(systemInstruction.Parts) == 0 {
systemInstruction = nil
}
}
messagesResult := gjson.GetBytes(rawJson, "messages")
if messagesResult.IsArray() {
messageResults := messagesResult.Array()
for i := 0; i < len(messageResults); i++ {
messageResult := messageResults[i]
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
}
role := roleResult.String()
if role == "assistant" {
role = "model"
}
clientContent := client.Content{Role: role, Parts: []client.Part{}}
contentsResult := messageResult.Get("content")
if contentsResult.IsArray() {
contentResults := contentsResult.Array()
for j := 0; j < len(contentResults); j++ {
contentResult := contentResults[j]
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String()
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String()
var args map[string]any
if err = json.Unmarshal([]byte(functionArgs), &args); err == nil {
clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
})
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" {
funcName := toolCallID
toolCallIDs := strings.Split(toolCallID, "-")
if len(toolCallIDs) > 1 {
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
}
responseData := contentResult.Get("content").String()
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
}
}
}
contents = append(contents, clientContent)
} else if contentsResult.Type == gjson.String {
prompt := contentsResult.String()
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
}
}
}
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJson, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw
inputSchema, _ = sjson.Delete(inputSchema, "additionalProperties")
inputSchema, _ = sjson.Delete(inputSchema, "$schema")
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parameters", inputSchema)
var toolDeclaration any
if err = json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
return modelName, systemInstruction, contents, tools
}
func walk(value gjson.Result, path, field string, pathsToDelete *[]string) {
switch value.Type {
case gjson.JSON:
value.ForEach(func(key, val gjson.Result) bool {
var childPath string
if path == "" {
childPath = key.String()
} else {
childPath = path + "." + key.String()
}
if key.String() == field {
*pathsToDelete = append(*pathsToDelete, childPath)
}
walk(val, childPath, field, pathsToDelete)
return true
})
case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null:
}
}

View File

@@ -0,0 +1,382 @@
package translator
import (
"bytes"
"fmt"
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertCliToOpenAI translates a single chunk of a streaming response from the
// backend client format to the OpenAI Server-Sent Events (SSE) format.
// It returns an empty string if the chunk contains no useful data.
func ConvertCliToOpenAI(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string {
if isGlAPIKey {
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
}
// Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
// Extract and set the model version.
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
// Extract and set the creation timestamp.
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
if err == nil {
unixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
} else {
template, _ = sjson.Set(template, "created", unixTimestamp)
}
// Extract and set the response ID.
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
template, _ = sjson.Set(template, "id", responseIdResult.String())
}
// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
}
// Process the main content part of the response.
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partResults := partsResult.Array()
for i := 0; i < len(partResults); i++ {
partResult := partResults[i]
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() {
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
}
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallTemplate)
}
}
}
return template
}
// ConvertCliToOpenAINonStream aggregates response from the backend client
// convert a single, non-streaming OpenAI-compatible JSON response.
func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) string {
if isGlAPIKey {
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
}
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
if err == nil {
unixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
} else {
template, _ = sjson.Set(template, "created", unixTimestamp)
}
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
template, _ = sjson.Set(template, "id", responseIdResult.String())
}
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
}
// Process the main content part of the response.
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partsResults := partsResult.Array()
for i := 0; i < len(partsResults); i++ {
partResult := partsResults[i]
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() {
// Append text content, distinguishing between regular content and reasoning.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
} else if functionCallResult.Exists() {
// Append function call content to the tool_calls array.
toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls")
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
}
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
} else {
// If no usable content is found, return an empty string.
return ""
}
}
}
return template
}
// ConvertCliToClaude performs sophisticated streaming response format conversion.
// This function implements a complex state machine that translates backend client responses
// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types
// and handles state transitions between content blocks, thinking processes, and function calls.
//
// Response type states: 0=none, 1=content, 2=thinking, 3=function
// The function maintains state across multiple calls to ensure proper SSE event sequencing.
func ConvertCliToClaude(rawJson []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string {
// Normalize the response format for different API key types
// Generative Language API keys have a different response structure
if isGlAPIKey {
rawJson, _ = sjson.SetRawBytes(rawJson, "response", rawJson)
}
// Track whether tools are being used in this response chunk
usedTool := false
output := ""
// Initialize the streaming session with a message_start event
// This is only sent for the very first response chunk
if !hasFirstResponse {
output = "event: message_start\n"
// Create the initial message structure with default values
// This follows the Claude API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
// Override default values with actual response metadata if available
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
}
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIdResult.String())
}
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
}
// Process the response parts array from the backend client
// Each part can contain text content, thinking content, or function calls
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partResults := partsResult.Array()
for i := 0; i < len(partResults); i++ {
partResult := partResults[i]
// Extract the different types of content from each part
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
// Handle text content (both regular content and thinking)
if partTextResult.Exists() {
// Process thinking content (internal reasoning)
if partResult.Get("thought").Bool() {
// Continue existing thinking block
if *responseType == 2 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
} else {
// Transition from another state to thinking
// First, close any existing content block
if *responseType != 0 {
if *responseType == 2 {
output = output + "event: content_block_delta\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
}
// Start a new thinking content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, *responseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
*responseType = 2 // Set state to thinking
}
} else {
// Process regular text content (user-visible output)
// Continue existing text block
if *responseType == 1 {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
} else {
// Transition from another state to text content
// First, close any existing content block
if *responseType != 0 {
if *responseType == 2 {
output = output + "event: content_block_delta\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
output = output + "\n\n\n"
}
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
}
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, *responseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
*responseType = 1 // Set state to content
}
}
} else if functionCallResult.Exists() {
// Handle function/tool calls from the AI model
// This processes tool usage requests and formats them for Claude API compatibility
usedTool = true
fcName := functionCallResult.Get("name").String()
// Handle state transitions when switching to function calls
// Close any existing function call block first
if *responseType == 3 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
*responseType = 0
}
// Special handling for thinking state transition
if *responseType == 2 {
output = output + "event: content_block_delta\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
output = output + "\n\n\n"
}
// Close any other existing content block
if *responseType != 0 {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
*responseIndex++
}
// Start a new tool use content block
// This creates the structure for a function call in Claude format
output = output + "event: content_block_start\n"
// Create the tool use block with unique ID and function details
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, *responseIndex)
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
data, _ = sjson.Set(data, "content_block.name", fcName)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
output = output + "event: content_block_delta\n"
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, *responseIndex), "delta.partial_json", fcArgsResult.Raw)
output = output + fmt.Sprintf("data: %s\n\n\n", data)
}
*responseType = 3
}
}
}
usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
if usageResult.Exists() && bytes.Contains(rawJson, []byte(`"finishReason"`)) {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
output = output + "event: content_block_stop\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
output = output + "\n\n\n"
output = output + "event: message_delta\n"
output = output + `data: `
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
if usedTool {
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
}
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
output = output + template + "\n\n\n"
}
}
return output
}

View File

@@ -5,15 +5,18 @@ import (
"encoding/json"
"errors"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"io"
"net"
"net/http"
"os"
"path/filepath"
"net/url"
"time"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"github.com/tidwall/gjson"
"golang.org/x/net/proxy"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
@@ -31,48 +34,79 @@ var (
}
)
type TokenStorage struct {
Token any `json:"token"`
ProjectID string `json:"project_id"`
Email string `json:"email"`
}
// GetAuthenticatedClient configures and returns an HTTP client ready for making authenticated API calls.
// It manages the entire OAuth2 flow, including handling proxies, loading existing tokens,
// initiating a new web-based OAuth flow if necessary, and refreshing tokens.
func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) {
// Configure proxy settings for the HTTP client if a proxy URL is provided.
proxyURL, err := url.Parse(cfg.ProxyUrl)
if err == nil {
var transport *http.Transport
if proxyURL.Scheme == "socks5" {
// Handle SOCKS5 proxy.
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
auth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
if errSOCKS5 != nil {
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Handle HTTP/HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
// GetAuthenticatedClient configures and returns an HTTP client with OAuth2 tokens.
// It handles the entire flow: loading, refreshing, and fetching new tokens.
func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, authDir string) (*http.Client, error) {
if transport != nil {
proxyClient := &http.Client{Transport: transport}
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
}
}
// Configure the OAuth2 client.
conf := &oauth2.Config{
ClientID: oauthClientID,
ClientSecret: oauthClientSecret,
RedirectURL: "http://localhost:8085/oauth2callback", // Placeholder, will be updated
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
Scopes: oauthScopes,
Endpoint: google.Endpoint,
}
var token *oauth2.Token
var err error
// If no token is found in storage, initiate the web-based OAuth flow.
if ts.Token == nil {
log.Info("Could not load token from file, starting OAuth flow.")
token, err = getTokenFromWeb(ctx, conf)
if err != nil {
return nil, fmt.Errorf("failed to get token from web: %w", err)
}
ts, err = saveTokenToFile(ctx, conf, token, ts.ProjectID, authDir)
if err != nil {
// Log the error but proceed, as we have a valid token for the session.
log.Errorf("Warning: failed to save token to file: %v", err)
// After getting a new token, create a new token storage object with user info.
newTs, errCreateTokenStorage := createTokenStorage(ctx, conf, token, ts.ProjectID)
if errCreateTokenStorage != nil {
log.Errorf("Warning: failed to create token storage: %v", errCreateTokenStorage)
return nil, errCreateTokenStorage
}
}
tsToken, _ := json.Marshal(ts.Token)
if err = json.Unmarshal(tsToken, &token); err != nil {
return nil, err
*ts = *newTs
}
// Unmarshal the stored token into an oauth2.Token object.
tsToken, _ := json.Marshal(ts.Token)
if err = json.Unmarshal(tsToken, &token); err != nil {
return nil, fmt.Errorf("failed to unmarshal token: %w", err)
}
// Return an HTTP client that automatically handles token refreshing.
return conf.Client(ctx, token), nil
}
// saveTokenToFile saves a token to the local credentials file.
func saveTokenToFile(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID, authDir string) (*TokenStorage, error) {
// createTokenStorage creates a new TokenStorage object. It fetches the user's email
// using the provided token and populates the storage structure.
func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*TokenStorage, error) {
httpClient := config.Client(ctx, token)
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
if err != nil {
@@ -86,7 +120,9 @@ func saveTokenToFile(ctx context.Context, config *oauth2.Config, token *oauth2.T
return nil, fmt.Errorf("failed to execute request: %w", err)
}
defer func() {
_ = resp.Body.Close()
if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}()
bodyBytes, _ := io.ReadAll(resp.Body)
@@ -101,19 +137,6 @@ func saveTokenToFile(ctx context.Context, config *oauth2.Config, token *oauth2.T
log.Info("Failed to get user email from token")
}
log.Infof("Saving credentials to %s", filepath.Join(authDir, fmt.Sprintf("%s.json", emailResult.String())))
if err = os.MkdirAll(authDir, 0700); err != nil {
return nil, fmt.Errorf("failed to create directory: %w", err)
}
f, err := os.Create(filepath.Join(authDir, fmt.Sprintf("%s.json", emailResult.String())))
if err != nil {
return nil, fmt.Errorf("failed to create token file: %w", err)
}
defer func() {
_ = f.Close()
}()
var ifToken map[string]any
jsonData, _ := json.Marshal(token)
err = json.Unmarshal(jsonData, &ifToken)
@@ -133,23 +156,24 @@ func saveTokenToFile(ctx context.Context, config *oauth2.Config, token *oauth2.T
Email: emailResult.String(),
}
if err = json.NewEncoder(f).Encode(ts); err != nil {
return nil, fmt.Errorf("failed to write token to file: %w", err)
}
return &ts, nil
}
// getTokenFromWeb starts a local server to handle the OAuth2 flow.
// getTokenFromWeb initiates the web-based OAuth2 authorization flow.
// It starts a local HTTP server to listen for the callback from Google's auth server,
// opens the user's browser to the authorization URL, and exchanges the received
// authorization code for an access token.
func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) {
// Use a channel to pass the authorization code from the HTTP handler to the main function.
codeChan := make(chan string)
errChan := make(chan error)
// Create a new HTTP server.
server := &http.Server{Addr: "localhost:8085"}
// Create a new HTTP server with its own multiplexer.
mux := http.NewServeMux()
server := &http.Server{Addr: ":8085", Handler: mux}
config.RedirectURL = "http://localhost:8085/oauth2callback"
http.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
if err := r.URL.Query().Get("error"); err != "" {
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
errChan <- fmt.Errorf("authentication failed via callback: %s", err)
@@ -174,9 +198,10 @@ func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token,
// Open the authorization URL in the user's browser.
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
log.Debugf("CLI login required.\nAttempting to open authentication page in your browser.\nIf it does not open, please navigate to this URL:\n\n%s\n\n", authURL)
log.Debugf("CLI login required.\nAttempting to open authentication page in your browser.\nIf it does not open, please navigate to this URL:\n\n%s\n", authURL)
err := open.Run(authURL)
var err error
err = open.Run(authURL)
if err != nil {
log.Errorf("Failed to open browser: %v. Please open the URL manually.", err)
}

17
internal/auth/models.go Normal file
View File

@@ -0,0 +1,17 @@
package auth
// TokenStorage defines the structure for storing OAuth2 token information,
// along with associated user and project details. This data is typically
// serialized to a JSON file for persistence.
type TokenStorage struct {
// Token holds the raw OAuth2 token data, including access and refresh tokens.
Token any `json:"token"`
// ProjectID is the Google Cloud Project ID associated with this token.
ProjectID string `json:"project_id"`
// Email is the email address of the authenticated user.
Email string `json:"email"`
// Auto indicates if the project ID was automatically selected.
Auto bool `json:"auto"`
// Checked indicates if the associated Cloud AI API has been verified as enabled.
Checked bool `json:"checked"`
}

File diff suppressed because it is too large Load Diff

91
internal/client/models.go Normal file
View File

@@ -0,0 +1,91 @@
package client
import "time"
// ErrorMessage encapsulates an error with an associated HTTP status code.
type ErrorMessage struct {
StatusCode int
Error error
}
// GCPProject represents the response structure for a Google Cloud project list request.
type GCPProject struct {
Projects []GCPProjectProjects `json:"projects"`
}
// GCPProjectLabels defines the labels associated with a GCP project.
type GCPProjectLabels struct {
GenerativeLanguage string `json:"generative-language"`
}
// GCPProjectProjects contains details about a single Google Cloud project.
type GCPProjectProjects struct {
ProjectNumber string `json:"projectNumber"`
ProjectID string `json:"projectId"`
LifecycleState string `json:"lifecycleState"`
Name string `json:"name"`
Labels GCPProjectLabels `json:"labels"`
CreateTime time.Time `json:"createTime"`
}
// Content represents a single message in a conversation, with a role and parts.
type Content struct {
Role string `json:"role"`
Parts []Part `json:"parts"`
}
// Part represents a distinct piece of content within a message, which can be
// text, inline data (like an image), a function call, or a function response.
type Part struct {
Text string `json:"text,omitempty"`
InlineData *InlineData `json:"inlineData,omitempty"`
FunctionCall *FunctionCall `json:"functionCall,omitempty"`
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
}
// InlineData represents base64-encoded data with its MIME type.
type InlineData struct {
MimeType string `json:"mime_type,omitempty"`
Data string `json:"data,omitempty"`
}
// FunctionCall represents a tool call requested by the model, including the
// function name and its arguments.
type FunctionCall struct {
Name string `json:"name"`
Args map[string]interface{} `json:"args"`
}
// FunctionResponse represents the result of a tool execution, sent back to the model.
type FunctionResponse struct {
Name string `json:"name"`
Response map[string]interface{} `json:"response"`
}
// GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint.
type GenerateContentRequest struct {
SystemInstruction *Content `json:"systemInstruction,omitempty"`
Contents []Content `json:"contents"`
Tools []ToolDeclaration `json:"tools,omitempty"`
GenerationConfig `json:"generationConfig"`
}
// GenerationConfig defines parameters that control the model's generation behavior.
type GenerationConfig struct {
ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"`
Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"`
}
// GenerationConfigThinkingConfig specifies configuration for the model's "thinking" process.
type GenerationConfigThinkingConfig struct {
// IncludeThoughts determines whether the model should output its reasoning process.
IncludeThoughts bool `json:"include_thoughts,omitempty"`
}
// ToolDeclaration defines the structure for declaring tools (like functions)
// that the model can call.
type ToolDeclaration struct {
FunctionDeclarations []interface{} `json:"functionDeclarations"`
}

86
internal/cmd/login.go Normal file
View File

@@ -0,0 +1,86 @@
package cmd
import (
"context"
"github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"os"
)
// DoLogin handles the entire user login and setup process.
// It authenticates the user, sets up the user's project, checks API enablement,
// and saves the token for future use.
func DoLogin(cfg *config.Config, projectID string) {
var err error
var ts auth.TokenStorage
if projectID != "" {
ts.ProjectID = projectID
}
// Initialize an authenticated HTTP client. This will trigger the OAuth flow if necessary.
clientCtx := context.Background()
log.Info("Initializing authentication...")
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
if errGetClient != nil {
log.Fatalf("failed to get authenticated client: %v", errGetClient)
return
}
log.Info("Authentication successful.")
// Initialize the API client.
cliClient := client.NewClient(httpClient, &ts, cfg)
// Perform the user setup process.
err = cliClient.SetupUser(clientCtx, ts.Email, projectID)
if err != nil {
// Handle the specific case where a project ID is required but not provided.
if err.Error() == "failed to start user onboarding, need define a project id" {
log.Error("Failed to start user onboarding: A project ID is required.")
// Fetch and display the user's available projects to help them choose one.
project, errGetProjectList := cliClient.GetProjectList(clientCtx)
if errGetProjectList != nil {
log.Fatalf("Failed to get project list: %v", err)
} else {
log.Infof("Your account %s needs to specify a project ID.", ts.Email)
log.Info("========================================================================")
for _, p := range project.Projects {
log.Infof("Project ID: %s", p.ProjectID)
log.Infof("Project Name: %s", p.Name)
log.Info("------------------------------------------------------------------------")
}
log.Infof("Please run this command to login again with a specific project:\n\n%s --login --project_id <project_id>\n", os.Args[0])
}
} else {
log.Fatalf("Failed to complete user setup: %v", err)
}
return // Exit after handling the error.
}
// If setup is successful, proceed to check API status and save the token.
auto := projectID == ""
cliClient.SetIsAuto(auto)
// If the project was not automatically selected, check if the Cloud AI API is enabled.
if !cliClient.IsChecked() && !cliClient.IsAuto() {
isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled()
if checkErr != nil {
log.Fatalf("Failed to check if Cloud AI API is enabled: %v", checkErr)
return
}
cliClient.SetIsChecked(isChecked)
// If the check fails (returns false), the CheckCloudAPIIsEnabled function
// will have already printed instructions, so we can just exit.
if !isChecked {
log.Fatal("Failed to check if Cloud AI API is enabled. If you encounter an error message, please create an issue.")
return
}
}
// Save the successfully obtained and verified token to a file.
err = cliClient.SaveTokenToFile()
if err != nil {
log.Fatalf("Failed to save token to file: %v", err)
}
}

115
internal/cmd/run.go Normal file
View File

@@ -0,0 +1,115 @@
package cmd
import (
"context"
"encoding/json"
"github.com/luispater/CLIProxyAPI/internal/api"
"github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
"github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus"
"io/fs"
"net/http"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"time"
)
// StartService initializes and starts the main API proxy service.
// It loads all available authentication tokens, creates a pool of clients,
// starts the API server, and handles graceful shutdown signals.
func StartService(cfg *config.Config) {
// Create a pool of API clients, one for each token file found.
cliClients := make([]*client.Client, 0)
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
if err != nil {
return err
}
// Process only JSON files in the auth directory.
if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") {
log.Debugf("Loading token from: %s", path)
f, errOpen := os.Open(path)
if errOpen != nil {
return errOpen
}
defer func() {
_ = f.Close()
}()
// Decode the token storage file.
var ts auth.TokenStorage
if err = json.NewDecoder(f).Decode(&ts); err == nil {
// For each valid token, create an authenticated client.
clientCtx := context.Background()
log.Info("Initializing authentication for token...")
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
if errGetClient != nil {
// Log fatal will exit, but we return the error for completeness.
log.Fatalf("failed to get authenticated client for token %s: %v", path, errGetClient)
return errGetClient
}
log.Info("Authentication successful.")
// Add the new client to the pool.
cliClient := client.NewClient(httpClient, &ts, cfg)
cliClients = append(cliClients, cliClient)
}
}
return nil
})
if err != nil {
log.Fatalf("Error walking auth directory: %v", err)
}
if len(cfg.GlAPIKey) > 0 {
for i := 0; i < len(cfg.GlAPIKey); i++ {
httpClient, errSetProxy := util.SetProxy(cfg, &http.Client{})
if errSetProxy != nil {
log.Fatalf("set proxy failed: %v", errSetProxy)
}
log.Debug("Initializing with Generative Language API key...")
cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
cliClients = append(cliClients, cliClient)
}
}
// Create and start the API server with the pool of clients.
apiServer := api.NewServer(cfg, cliClients)
log.Infof("Starting API server on port %d", cfg.Port)
if err = apiServer.Start(); err != nil {
log.Fatalf("API server failed to start: %v", err)
}
// Set up a channel to listen for OS signals for graceful shutdown.
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Main loop to wait for shutdown signal.
for {
select {
case <-sigChan:
log.Debugf("Received shutdown signal. Cleaning up...")
// Create a context with a timeout for the shutdown process.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
_ = cancel
// Stop the API server gracefully.
if err = apiServer.Stop(ctx); err != nil {
log.Debugf("Error stopping API server: %v", err)
}
log.Debugf("Cleanup completed. Exiting...")
os.Exit(0)
case <-time.After(5 * time.Second):
// This case is currently empty and acts as a periodic check.
// It could be used for periodic tasks in the future.
}
}
}

View File

@@ -6,32 +6,46 @@ import (
"os"
)
// Config represents the application's configuration
// Config represents the application's configuration, loaded from a YAML file.
type Config struct {
Port int `yaml:"port"`
AuthDir string `yaml:"auth_dir"`
Debug bool `yaml:"debug"`
ApiKeys []string `yaml:"api_keys"`
// Port is the network port on which the API server will listen.
Port int `yaml:"port"`
// AuthDir is the directory where authentication token files are stored.
AuthDir string `yaml:"auth-dir"`
// Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug"`
// ProxyUrl is the URL of an optional proxy server to use for outbound requests.
ProxyUrl string `yaml:"proxy-url"`
// ApiKeys is a list of keys for authenticating clients to this proxy server.
ApiKeys []string `yaml:"api-keys"`
// QuotaExceeded defines the behavior when a quota is exceeded.
QuotaExceeded ConfigQuotaExceeded `yaml:"quota-exceeded"`
// GlAPIKey is the API key for the generative language API.
GlAPIKey []string `yaml:"generative-language-api-key"`
}
// / LoadConfig loads the configuration from the specified file
type ConfigQuotaExceeded struct {
// SwitchProject indicates whether to automatically switch to another project when a quota is exceeded.
SwitchProject bool `yaml:"switch-project"`
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
SwitchPreviewModel bool `yaml:"switch-preview-model"`
}
// LoadConfig reads a YAML configuration file from the given path,
// unmarshals it into a Config struct, and returns it.
func LoadConfig(configFile string) (*Config, error) {
// Read the configuration file
// Read the entire configuration file into memory.
data, err := os.ReadFile(configFile)
// If reading the file fails
if err != nil {
// Return an error
return nil, fmt.Errorf("failed to read config file: %w", err)
}
// Parse the YAML data
// Unmarshal the YAML data into the Config struct.
var config Config
// If parsing the YAML data fails
if err = yaml.Unmarshal(data, &config); err != nil {
// Return an error
return nil, fmt.Errorf("failed to parse config file: %w", err)
}
// Return the configuration
// Return the populated configuration struct.
return &config, nil
}

37
internal/util/proxy.go Normal file
View File

@@ -0,0 +1,37 @@
package util
import (
"context"
"github.com/luispater/CLIProxyAPI/internal/config"
"golang.org/x/net/proxy"
"net"
"net/http"
"net/url"
)
func SetProxy(cfg *config.Config, httpClient *http.Client) (*http.Client, error) {
var transport *http.Transport
proxyURL, errParse := url.Parse(cfg.ProxyUrl)
if errParse == nil {
if proxyURL.Scheme == "socks5" {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
return nil, errSOCKS5
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
}
if transport != nil {
httpClient.Transport = transport
}
return httpClient, nil
}