mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 12:30:50 +08:00
Compare commits
34 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d2394b0be9 | ||
|
|
ebcd4dbf3d | ||
|
|
1483c31c73 | ||
|
|
00f33f5f3a | ||
|
|
3c4dc07980 | ||
|
|
3b4634e2dc | ||
|
|
00bd6a3e46 | ||
|
|
5812229d9b | ||
|
|
0b026933a7 | ||
|
|
3b2ab0d7bd | ||
|
|
e64fa48823 | ||
|
|
beff9282f6 | ||
|
|
31a9e2d11f | ||
|
|
423faae3da | ||
|
|
ead71fb7ef | ||
|
|
58b7afdf1e | ||
|
|
c86545d7e1 | ||
|
|
f49a530c1a | ||
|
|
368796349e | ||
|
|
c601542f6f | ||
|
|
3c0c61aaf1 | ||
|
|
edeadfc389 | ||
|
|
aa9fd057fe | ||
|
|
b3607d3981 | ||
|
|
fa8d94971f | ||
|
|
ef68a97526 | ||
|
|
d880d1a1ea | ||
|
|
d4104214ed | ||
|
|
273e1d9cbe | ||
|
|
65f47c196a | ||
|
|
9be56fe8e0 | ||
|
|
589ae6d3aa | ||
|
|
7cb76ae1a5 | ||
|
|
e73f165070 |
42
.github/workflows/docker-image.yml
vendored
Normal file
42
.github/workflows/docker-image.yml
vendored
Normal file
@@ -0,0 +1,42 @@
|
||||
name: docker-image
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- v*
|
||||
|
||||
env:
|
||||
APP_NAME: CLIProxyAPI
|
||||
DOCKERHUB_REPO: eceasy/cli-proxy-api
|
||||
|
||||
jobs:
|
||||
docker:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up QEMU
|
||||
uses: docker/setup-qemu-action@v3
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
- name: Login to DockerHub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||
- name: Generate App Version
|
||||
run: echo APP_VERSION=`git describe --tags --always` >> $GITHUB_ENV
|
||||
- name: Build and push
|
||||
uses: docker/build-push-action@v6
|
||||
with:
|
||||
context: .
|
||||
platforms: |
|
||||
linux/amd64
|
||||
linux/arm64
|
||||
push: true
|
||||
build-args: |
|
||||
APP_NAME=${{ env.APP_NAME }}
|
||||
APP_VERSION=${{ env.APP_VERSION }}
|
||||
tags: |
|
||||
${{ env.DOCKERHUB_REPO }}:latest
|
||||
${{ env.DOCKERHUB_REPO }}:${{ env.APP_VERSION }}
|
||||
2
.gitignore
vendored
Normal file
2
.gitignore
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
config.yaml
|
||||
docs/
|
||||
@@ -11,7 +11,12 @@ builds:
|
||||
binary: cli-proxy-api
|
||||
archives:
|
||||
- id: "cli-proxy-api"
|
||||
format: tar.gz
|
||||
format_overrides:
|
||||
- goos: windows
|
||||
format: zip
|
||||
files:
|
||||
- LICENSE
|
||||
- README.md
|
||||
- config.yaml
|
||||
- README_CN.md
|
||||
- config.example.yaml
|
||||
23
Dockerfile
Normal file
23
Dockerfile
Normal file
@@ -0,0 +1,23 @@
|
||||
FROM golang:1.24-alpine AS builder
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY go.mod go.sum ./
|
||||
|
||||
RUN go mod download
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN CGO_ENABLED=0 GOOS=linux go build -o ./CLIProxyAPI ./cmd/server/
|
||||
|
||||
FROM alpine:3.22.0
|
||||
|
||||
RUN mkdir /CLIProxyAPI
|
||||
|
||||
COPY --from=builder ./app/CLIProxyAPI /CLIProxyAPI/CLIProxyAPI
|
||||
|
||||
WORKDIR /CLIProxyAPI
|
||||
|
||||
EXPOSE 8317
|
||||
|
||||
CMD ["./CLIProxyAPI"]
|
||||
83
README.md
83
README.md
@@ -1,15 +1,19 @@
|
||||
# CLI Proxy API
|
||||
|
||||
A proxy server that provides an OpenAI-compatible API interface for CLI. This allows you to use CLI models with tools and libraries designed for the OpenAI API.
|
||||
English | [中文](README_CN.md)
|
||||
|
||||
A proxy server that provides an OpenAI/Gemini/Claude compatible API interface for CLI. This allows you to use CLI models with tools and libraries designed for the OpenAI/Gemini/Claude API.
|
||||
|
||||
## Features
|
||||
|
||||
- OpenAI-compatible API endpoints for CLI models
|
||||
- OpenAI/Gemini/Claude compatible API endpoints for CLI models
|
||||
- Support for both streaming and non-streaming responses
|
||||
- Function calling/tools support
|
||||
- Multimodal input support (text and images)
|
||||
- Multiple account support with load balancing
|
||||
- Simple CLI authentication flow
|
||||
- Support for Generative Language API Key
|
||||
- Support Gemini CLI with multiple account load balancing
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -134,7 +138,7 @@ console.log(response.choices[0].message.content);
|
||||
|
||||
- gemini-2.5-pro
|
||||
- gemini-2.5-flash
|
||||
- And various preview versions
|
||||
- And it automates switching to various preview versions
|
||||
|
||||
## Configuration
|
||||
|
||||
@@ -146,13 +150,17 @@ The server uses a YAML configuration file (`config.yaml`) located in the project
|
||||
|
||||
### Configuration Options
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
|-------------|----------|--------------------|----------------------------------------------------------------------------------------------|
|
||||
| `port` | integer | 8317 | The port number on which the server will listen |
|
||||
| `auth_dir` | string | "~/.cli-proxy-api" | Directory where authentication tokens are stored. Supports using `~` for home directory |
|
||||
| `proxy-url` | string | "" | Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/ |
|
||||
| `debug` | boolean | false | Enable debug mode for verbose logging |
|
||||
| `api_keys` | string[] | [] | List of API keys that can be used to authenticate requests |
|
||||
| Parameter | Type | Default | Description |
|
||||
|---------------------------------------|----------|--------------------|----------------------------------------------------------------------------------------------|
|
||||
| `port` | integer | 8317 | The port number on which the server will listen |
|
||||
| `auth-dir` | string | "~/.cli-proxy-api" | Directory where authentication tokens are stored. Supports using `~` for home directory |
|
||||
| `proxy-url` | string | "" | Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/ |
|
||||
| `quota-exceeded` | object | {} | Configuration for handling quota exceeded |
|
||||
| `quota-exceeded.switch-project` | boolean | true | Whether to automatically switch to another project when a quota is exceeded |
|
||||
| `quota-exceeded.switch-preview-model` | boolean | true | Whether to automatically switch to a preview model when a quota is exceeded |
|
||||
| `debug` | boolean | false | Enable debug mode for verbose logging |
|
||||
| `api-keys` | string[] | [] | List of API keys that can be used to authenticate requests |
|
||||
| `generative-language-api-key` | string[] | [] | List of Generative Language API keys |
|
||||
|
||||
### Example Configuration File
|
||||
|
||||
@@ -161,29 +169,76 @@ The server uses a YAML configuration file (`config.yaml`) located in the project
|
||||
port: 8317
|
||||
|
||||
# Authentication directory (supports ~ for home directory)
|
||||
auth_dir: "~/.cli-proxy-api"
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# Enable debug logging
|
||||
debug: false
|
||||
|
||||
# Proxy url, support socks5/http/https protocol, example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
|
||||
# Quota exceeded behavior
|
||||
quota-exceeded:
|
||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
||||
|
||||
# API keys for authentication
|
||||
api_keys:
|
||||
api-keys:
|
||||
- "your-api-key-1"
|
||||
- "your-api-key-2"
|
||||
|
||||
# API keys for official Generative Language API
|
||||
generative-language-api-key:
|
||||
- "AIzaSy...01"
|
||||
- "AIzaSy...02"
|
||||
- "AIzaSy...03"
|
||||
- "AIzaSy...04"
|
||||
```
|
||||
|
||||
### Authentication Directory
|
||||
|
||||
The `auth_dir` parameter specifies where authentication tokens are stored. When you run the login command, the application will create JSON files in this directory containing the authentication tokens for your Google accounts. Multiple accounts can be used for load balancing.
|
||||
The `auth-dir` parameter specifies where authentication tokens are stored. When you run the login command, the application will create JSON files in this directory containing the authentication tokens for your Google accounts. Multiple accounts can be used for load balancing.
|
||||
|
||||
### API Keys
|
||||
|
||||
The `api_keys` parameter allows you to define a list of API keys that can be used to authenticate requests to your proxy server. When making requests to the API, you can include one of these keys in the `Authorization` header:
|
||||
The `api-keys` parameter allows you to define a list of API keys that can be used to authenticate requests to your proxy server. When making requests to the API, you can include one of these keys in the `Authorization` header:
|
||||
|
||||
```
|
||||
Authorization: Bearer your-api-key-1
|
||||
```
|
||||
|
||||
### Official Generative Language API
|
||||
|
||||
The `generative-language-api-key` parameter allows you to define a list of API keys that can be used to authenticate requests to the official Generative Language API.
|
||||
|
||||
## Gemini CLI with multiple account load balancing
|
||||
|
||||
Start CLI Proxy API server, and then set the `CODE_ASSIST_ENDPOINT` environment variable to the URL of the CLI Proxy API server.
|
||||
|
||||
```bash
|
||||
export CODE_ASSIST_ENDPOINT="http://127.0.0.1:8317"
|
||||
```
|
||||
|
||||
The server will relay the `loadCodeAssist`, `onboardUser`, and `countTokens` requests. And automatically load balance the text generation requests between the multiple accounts.
|
||||
|
||||
> [!NOTE]
|
||||
> This feature only allows local access because I couldn't find a way to authenticate the requests.
|
||||
> I hardcoded `127.0.0.1` into the load balancing.
|
||||
|
||||
## Run with Docker
|
||||
|
||||
Run the following command to login:
|
||||
|
||||
```bash
|
||||
docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login
|
||||
```
|
||||
|
||||
Run the following command to start the server:
|
||||
|
||||
```bash
|
||||
docker run --rm -p 8317:8317 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest
|
||||
```
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions are welcome! Please feel free to submit a Pull Request.
|
||||
|
||||
254
README_CN.md
Normal file
254
README_CN.md
Normal file
@@ -0,0 +1,254 @@
|
||||
# CLI 代理 API
|
||||
|
||||
[English](README.md) | 中文
|
||||
|
||||
一个为 CLI 提供 OpenAI/Gemini/Claude 兼容 API 接口的代理服务器。这让您可以摆脱终端界面的束缚,将 Gemini 的强大能力以 API 的形式轻松接入到任何您喜爱的客户端或应用中。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 为 CLI 模型提供 OpenAI/Gemini/Claude 兼容的 API 端点
|
||||
- 支持流式和非流式响应
|
||||
- 函数调用/工具支持
|
||||
- 多模态输入支持(文本和图像)
|
||||
- 多账户支持与负载均衡
|
||||
- 简单的 CLI 身份验证流程
|
||||
- 支持 Gemini AIStudio API 密钥
|
||||
- 支持 Gemini CLI 多账户轮询
|
||||
|
||||
## 安装
|
||||
|
||||
### 前置要求
|
||||
|
||||
- Go 1.24 或更高版本
|
||||
- 有权访问 CLI 模型的 Google 账户
|
||||
|
||||
### 从源码构建
|
||||
|
||||
1. 克隆仓库:
|
||||
```bash
|
||||
git clone https://github.com/luispater/CLIProxyAPI.git
|
||||
cd CLIProxyAPI
|
||||
```
|
||||
|
||||
2. 构建应用程序:
|
||||
```bash
|
||||
go build -o cli-proxy-api ./cmd/server
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
### 身份验证
|
||||
|
||||
在使用 API 之前,您需要使用 Google 账户进行身份验证:
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --login
|
||||
```
|
||||
|
||||
如果您是旧版 gemini code 用户,可能需要指定项目 ID:
|
||||
|
||||
```bash
|
||||
./cli-proxy-api --login --project_id <your_project_id>
|
||||
```
|
||||
|
||||
### 启动服务器
|
||||
|
||||
身份验证完成后,启动服务器:
|
||||
|
||||
```bash
|
||||
./cli-proxy-api
|
||||
```
|
||||
|
||||
默认情况下,服务器在端口 8317 上运行。
|
||||
|
||||
### API 端点
|
||||
|
||||
#### 列出模型
|
||||
|
||||
```
|
||||
GET http://localhost:8317/v1/models
|
||||
```
|
||||
|
||||
#### 聊天补全
|
||||
|
||||
```
|
||||
POST http://localhost:8317/v1/chat/completions
|
||||
```
|
||||
|
||||
请求体示例:
|
||||
|
||||
```json
|
||||
{
|
||||
"model": "gemini-2.5-pro",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "你好,你好吗?"
|
||||
}
|
||||
],
|
||||
"stream": true
|
||||
}
|
||||
```
|
||||
|
||||
### 与 OpenAI 库一起使用
|
||||
|
||||
您可以通过将基础 URL 设置为本地服务器来将此代理与任何 OpenAI 兼容的库一起使用:
|
||||
|
||||
#### Python(使用 OpenAI 库)
|
||||
|
||||
```python
|
||||
from openai import OpenAI
|
||||
|
||||
client = OpenAI(
|
||||
api_key="dummy", # 不使用但必需
|
||||
base_url="http://localhost:8317/v1"
|
||||
)
|
||||
|
||||
response = client.chat.completions.create(
|
||||
model="gemini-2.5-pro",
|
||||
messages=[
|
||||
{"role": "user", "content": "你好,你好吗?"}
|
||||
]
|
||||
)
|
||||
|
||||
print(response.choices[0].message.content)
|
||||
```
|
||||
|
||||
#### JavaScript/TypeScript
|
||||
|
||||
```javascript
|
||||
import OpenAI from 'openai';
|
||||
|
||||
const openai = new OpenAI({
|
||||
apiKey: 'dummy', // 不使用但必需
|
||||
baseURL: 'http://localhost:8317/v1',
|
||||
});
|
||||
|
||||
const response = await openai.chat.completions.create({
|
||||
model: 'gemini-2.5-pro',
|
||||
messages: [
|
||||
{ role: 'user', content: '你好,你好吗?' }
|
||||
],
|
||||
});
|
||||
|
||||
console.log(response.choices[0].message.content);
|
||||
```
|
||||
|
||||
## 支持的模型
|
||||
|
||||
- gemini-2.5-pro
|
||||
- gemini-2.5-flash
|
||||
- 并且自动切换到之前的预览版本
|
||||
|
||||
## 配置
|
||||
|
||||
服务器默认使用位于项目根目录的 YAML 配置文件(`config.yaml`)。您可以使用 `--config` 标志指定不同的配置文件路径:
|
||||
|
||||
```bash
|
||||
./cli-proxy --config /path/to/your/config.yaml
|
||||
```
|
||||
|
||||
### 配置选项
|
||||
|
||||
| 参数 | 类型 | 默认值 | 描述 |
|
||||
|---------------------------------------|----------|--------------------|------------------------------------------------------------------------|
|
||||
| `port` | integer | 8317 | 服务器监听的端口号 |
|
||||
| `auth-dir` | string | "~/.cli-proxy-api" | 存储身份验证令牌的目录。支持使用 `~` 表示主目录 |
|
||||
| `proxy-url` | string | "" | 代理 URL,支持 socks5/http/https 协议,示例:socks5://user:pass@192.168.1.1:1080/ |
|
||||
| `quota-exceeded` | object | {} | 处理配额超限的配置 |
|
||||
| `quota-exceeded.switch-project` | boolean | true | 当配额超限时是否自动切换到另一个项目 |
|
||||
| `quota-exceeded.switch-preview-model` | boolean | true | 当配额超限时是否自动切换到预览模型 |
|
||||
| `debug` | boolean | false | 启用调试模式以进行详细日志记录 |
|
||||
| `api-keys` | string[] | [] | 可用于验证请求的 API 密钥列表 |
|
||||
| `generative-language-api-key` | string[] | [] | 生成式语言 API 密钥列表 |
|
||||
|
||||
### 配置文件示例
|
||||
|
||||
```yaml
|
||||
# 服务器端口
|
||||
port: 8317
|
||||
|
||||
# 身份验证目录(支持 ~ 表示主目录)
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
|
||||
# 启用调试日志
|
||||
debug: false
|
||||
|
||||
# 代理 URL,支持 socks5/http/https 协议,示例:socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
|
||||
# 配额超限行为
|
||||
quota-exceeded:
|
||||
switch-project: true # 当配额超限时是否自动切换到另一个项目
|
||||
switch-preview-model: true # 当配额超限时是否自动切换到预览模型
|
||||
|
||||
# 用于本地身份验证的 API 密钥
|
||||
api-keys:
|
||||
- "your-api-key-1"
|
||||
- "your-api-key-2"
|
||||
|
||||
# AIStduio Gemini API 的 API 密钥
|
||||
generative-language-api-key:
|
||||
- "AIzaSy...01"
|
||||
- "AIzaSy...02"
|
||||
- "AIzaSy...03"
|
||||
- "AIzaSy...04"
|
||||
```
|
||||
|
||||
### 身份验证目录
|
||||
|
||||
`auth-dir` 参数指定身份验证令牌的存储位置。当您运行登录命令时,应用程序将在此目录中创建包含 Google 账户身份验证令牌的 JSON 文件。多个账户可用于轮询。
|
||||
|
||||
### API 密钥
|
||||
|
||||
`api-keys` 参数允许您定义可用于验证对代理服务器请求的 API 密钥列表。在向 API 发出请求时,您可以在 `Authorization` 标头中包含其中一个密钥:
|
||||
|
||||
```
|
||||
Authorization: Bearer your-api-key-1
|
||||
```
|
||||
|
||||
### 官方生成式语言 API
|
||||
|
||||
`generative-language-api-key` 参数允许您定义可用于验证对官方 AIStudio Gemini API 请求的 API 密钥列表。
|
||||
|
||||
## Gemini CLI 多账户负载均衡
|
||||
|
||||
启动 CLI 代理 API 服务器,然后将 `CODE_ASSIST_ENDPOINT` 环境变量设置为 CLI 代理 API 服务器的 URL。
|
||||
|
||||
```bash
|
||||
export CODE_ASSIST_ENDPOINT="http://127.0.0.1:8317"
|
||||
```
|
||||
|
||||
服务器将中继 `loadCodeAssist`、`onboardUser` 和 `countTokens` 请求。并自动在多个账户之间轮询文本生成请求。
|
||||
|
||||
> [!NOTE]
|
||||
> 此功能仅允许本地访问,因为找不到一个可以验证请求的方法。
|
||||
> 所以只能强制只有 `127.0.0.1` 可以访问。
|
||||
|
||||
## 使用 Docker 运行
|
||||
|
||||
运行以下命令进行登录:
|
||||
|
||||
```bash
|
||||
docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login
|
||||
```
|
||||
|
||||
运行以下命令启动服务器:
|
||||
|
||||
```bash
|
||||
docker run --rm -p 8317:8317 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest
|
||||
```
|
||||
|
||||
## 贡献
|
||||
|
||||
欢迎贡献!请随时提交 Pull Request。
|
||||
|
||||
1. Fork 仓库
|
||||
2. 创建您的功能分支(`git checkout -b feature/amazing-feature`)
|
||||
3. 提交您的更改(`git commit -m 'Add some amazing feature'`)
|
||||
4. 推送到分支(`git push origin feature/amazing-feature`)
|
||||
5. 打开 Pull Request
|
||||
|
||||
## 许可证
|
||||
|
||||
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
||||
@@ -1,3 +1,6 @@
|
||||
// Package main provides the entry point for the CLI Proxy API server.
|
||||
// This server acts as a proxy that provides OpenAI/Gemini/Claude compatible API interfaces
|
||||
// for CLI models, allowing CLI models to be used with tools and libraries designed for standard AI APIs.
|
||||
package main
|
||||
|
||||
import (
|
||||
@@ -63,14 +66,17 @@ func main() {
|
||||
var wd string
|
||||
|
||||
// Load configuration from the specified path or the default path.
|
||||
var configFilePath string
|
||||
if configPath != "" {
|
||||
configFilePath = configPath
|
||||
cfg, err = config.LoadConfig(configPath)
|
||||
} else {
|
||||
wd, err = os.Getwd()
|
||||
if err != nil {
|
||||
log.Fatalf("failed to get working directory: %v", err)
|
||||
}
|
||||
cfg, err = config.LoadConfig(path.Join(wd, "config.yaml"))
|
||||
configFilePath = path.Join(wd, "config.yaml")
|
||||
cfg, err = config.LoadConfig(configFilePath)
|
||||
}
|
||||
if err != nil {
|
||||
log.Fatalf("failed to load config: %v", err)
|
||||
@@ -102,6 +108,6 @@ func main() {
|
||||
if login {
|
||||
cmd.DoLogin(cfg, projectID)
|
||||
} else {
|
||||
cmd.StartService(cfg)
|
||||
cmd.StartService(cfg, configFilePath)
|
||||
}
|
||||
}
|
||||
|
||||
15
config.example.yaml
Normal file
15
config.example.yaml
Normal file
@@ -0,0 +1,15 @@
|
||||
port: 8317
|
||||
auth-dir: "~/.cli-proxy-api"
|
||||
debug: true
|
||||
proxy-url: ""
|
||||
quota-exceeded:
|
||||
switch-project: true
|
||||
switch-preview-model: true
|
||||
api-keys:
|
||||
- "12345"
|
||||
- "23456"
|
||||
generative-language-api-key:
|
||||
- "AIzaSy...01"
|
||||
- "AIzaSy...02"
|
||||
- "AIzaSy...03"
|
||||
- "AIzaSy...04"
|
||||
@@ -1,7 +0,0 @@
|
||||
port: 8317
|
||||
auth_dir: "~/.cli-proxy-api"
|
||||
debug: true
|
||||
proxy-url: ""
|
||||
api_keys:
|
||||
- "12345"
|
||||
- "23456"
|
||||
3
go.mod
3
go.mod
@@ -8,6 +8,7 @@ require (
|
||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||
github.com/tidwall/gjson v1.18.0
|
||||
github.com/tidwall/sjson v1.2.5
|
||||
golang.org/x/net v0.37.1-0.20250305215238-2914f4677317
|
||||
golang.org/x/oauth2 v0.30.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
@@ -18,6 +19,7 @@ require (
|
||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
github.com/go-playground/locales v0.14.1 // indirect
|
||||
@@ -37,7 +39,6 @@ require (
|
||||
github.com/ugorji/go/codec v1.2.12 // indirect
|
||||
golang.org/x/arch v0.8.0 // indirect
|
||||
golang.org/x/crypto v0.36.0 // indirect
|
||||
golang.org/x/net v0.37.1-0.20250305215238-2914f4677317 // indirect
|
||||
golang.org/x/sys v0.31.0 // indirect
|
||||
golang.org/x/text v0.23.0 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
|
||||
2
go.sum
2
go.sum
@@ -11,6 +11,8 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
|
||||
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
|
||||
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
|
||||
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
|
||||
|
||||
@@ -1,388 +0,0 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
var (
|
||||
mutex = &sync.Mutex{}
|
||||
lastUsedClientIndex = 0
|
||||
)
|
||||
|
||||
// APIHandlers contains the handlers for API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service.
|
||||
type APIHandlers struct {
|
||||
cliClients []*client.Client
|
||||
debug bool
|
||||
}
|
||||
|
||||
// NewAPIHandlers creates a new API handlers instance.
|
||||
// It takes a slice of clients and a debug flag as input.
|
||||
func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers {
|
||||
return &APIHandlers{
|
||||
cliClients: cliClients,
|
||||
debug: debug,
|
||||
}
|
||||
}
|
||||
|
||||
// Models handles the /v1/models endpoint.
|
||||
// It returns a hardcoded list of available AI models.
|
||||
func (h *APIHandlers) Models(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": []map[string]any{
|
||||
{
|
||||
"id": "gemini-2.5-pro-preview-05-06",
|
||||
"object": "model",
|
||||
"version": "2.5-preview-05-06",
|
||||
"name": "Gemini 2.5 Pro Preview 05-06",
|
||||
"description": "Preview release (May 6th, 2025) of Gemini 2.5 Pro",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65536,
|
||||
"supported_parameters": []string{
|
||||
"tools",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
},
|
||||
"temperature": 1,
|
||||
"topP": 0.95,
|
||||
"topK": 64,
|
||||
"maxTemperature": 2,
|
||||
"thinking": true,
|
||||
},
|
||||
{
|
||||
"id": "gemini-2.5-pro-preview-06-05",
|
||||
"object": "model",
|
||||
"version": "2.5-preview-06-05",
|
||||
"name": "Gemini 2.5 Pro Preview 06-05",
|
||||
"description": "Preview release (June 5th, 2025) of Gemini 2.5 Pro",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65536,
|
||||
"supported_parameters": []string{
|
||||
"tools",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
},
|
||||
"temperature": 1,
|
||||
"topP": 0.95,
|
||||
"topK": 64,
|
||||
"maxTemperature": 2,
|
||||
"thinking": true,
|
||||
},
|
||||
{
|
||||
"id": "gemini-2.5-pro",
|
||||
"object": "model",
|
||||
"version": "2.5",
|
||||
"name": "Gemini 2.5 Pro",
|
||||
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65536,
|
||||
"supported_parameters": []string{
|
||||
"tools",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
},
|
||||
"temperature": 1,
|
||||
"topP": 0.95,
|
||||
"topK": 64,
|
||||
"maxTemperature": 2,
|
||||
"thinking": true,
|
||||
},
|
||||
{
|
||||
"id": "gemini-2.5-flash-preview-04-17",
|
||||
"object": "model",
|
||||
"version": "2.5-preview-04-17",
|
||||
"name": "Gemini 2.5 Flash Preview 04-17",
|
||||
"description": "Preview release (April 17th, 2025) of Gemini 2.5 Flash",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65536,
|
||||
"supported_parameters": []string{
|
||||
"tools",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
},
|
||||
"temperature": 1,
|
||||
"topP": 0.95,
|
||||
"topK": 64,
|
||||
"maxTemperature": 2,
|
||||
"thinking": true,
|
||||
},
|
||||
{
|
||||
"id": "gemini-2.5-flash-preview-05-20",
|
||||
"object": "model",
|
||||
"version": "2.5-preview-05-20",
|
||||
"name": "Gemini 2.5 Flash Preview 05-20",
|
||||
"description": "Preview release (April 17th, 2025) of Gemini 2.5 Flash",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65536,
|
||||
"supported_parameters": []string{
|
||||
"tools",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
},
|
||||
"temperature": 1,
|
||||
"topP": 0.95,
|
||||
"topK": 64,
|
||||
"maxTemperature": 2,
|
||||
"thinking": true,
|
||||
},
|
||||
{
|
||||
"id": "gemini-2.5-flash",
|
||||
"object": "model",
|
||||
"version": "001",
|
||||
"name": "Gemini 2.5 Flash",
|
||||
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65536,
|
||||
"supported_parameters": []string{
|
||||
"tools",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
},
|
||||
"temperature": 1,
|
||||
"topP": 0.95,
|
||||
"topK": 64,
|
||||
"maxTemperature": 2,
|
||||
"thinking": true,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ChatCompletions handles the /v1/chat/completions endpoint.
|
||||
// It determines whether the request is for a streaming or non-streaming response
|
||||
// and calls the appropriate handler.
|
||||
func (h *APIHandlers) ChatCompletions(c *gin.Context) {
|
||||
rawJson, err := c.GetRawData()
|
||||
// If data retrieval fails, return a 400 Bad Request error.
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the client requested a streaming response.
|
||||
streamResult := gjson.GetBytes(rawJson, "stream")
|
||||
if streamResult.Type == gjson.True {
|
||||
h.handleStreamingResponse(c, rawJson)
|
||||
} else {
|
||||
h.handleNonStreamingResponse(c, rawJson)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStreamingResponse handles non-streaming chat completion responses.
|
||||
// It selects a client from the pool, sends the request, and aggregates the response
|
||||
// before sending it back to the client.
|
||||
func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
// Handle streaming manually
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelName, contents, tools := translator.PrepareRequest(rawJson)
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// Lock the mutex to update the last used page index
|
||||
mutex.Lock()
|
||||
startIndex := lastUsedClientIndex
|
||||
currentIndex := (startIndex + 1) % len(h.cliClients)
|
||||
lastUsedClientIndex = currentIndex
|
||||
mutex.Unlock()
|
||||
|
||||
// Reorder the pages to start from the last used index
|
||||
reorderedPages := make([]*client.Client, len(h.cliClients))
|
||||
for i := 0; i < len(h.cliClients); i++ {
|
||||
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
||||
}
|
||||
|
||||
locked := false
|
||||
for i := 0; i < len(reorderedPages); i++ {
|
||||
cliClient = reorderedPages[i]
|
||||
if cliClient.RequestMutex.TryLock() {
|
||||
locked = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !locked {
|
||||
cliClient = h.cliClients[0]
|
||||
cliClient.RequestMutex.Lock()
|
||||
}
|
||||
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
|
||||
respChan := make(chan []byte)
|
||||
errChan := make(chan *client.ErrorMessage)
|
||||
go func() {
|
||||
resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, contents, tools)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
} else {
|
||||
respChan <- resp
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
case respBody := <-respChan:
|
||||
openAIFormat := translator.ConvertCliToOpenAINonStream(respBody)
|
||||
if openAIFormat != "" {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
|
||||
flusher.Flush()
|
||||
}
|
||||
cliCancel()
|
||||
return
|
||||
case err := <-errChan:
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleStreamingResponse handles streaming responses
|
||||
func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare the request for the backend client.
|
||||
modelName, contents, tools := translator.PrepareRequest(rawJson)
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// Use a round-robin approach to select the next available client.
|
||||
// This distributes the load among the available clients.
|
||||
mutex.Lock()
|
||||
startIndex := lastUsedClientIndex
|
||||
currentIndex := (startIndex + 1) % len(h.cliClients)
|
||||
lastUsedClientIndex = currentIndex
|
||||
mutex.Unlock()
|
||||
|
||||
// Reorder the clients to start from the next client in the rotation.
|
||||
reorderedPages := make([]*client.Client, len(h.cliClients))
|
||||
for i := 0; i < len(h.cliClients); i++ {
|
||||
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
||||
}
|
||||
|
||||
// Attempt to lock a client for the request.
|
||||
locked := false
|
||||
for i := 0; i < len(reorderedPages); i++ {
|
||||
cliClient = reorderedPages[i]
|
||||
if cliClient.RequestMutex.TryLock() {
|
||||
locked = true
|
||||
break
|
||||
}
|
||||
}
|
||||
// If no client is available, block and wait for the first client.
|
||||
if !locked {
|
||||
cliClient = h.cliClients[0]
|
||||
cliClient.RequestMutex.Lock()
|
||||
}
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
// Send the message and receive response chunks and errors via channels.
|
||||
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
|
||||
for {
|
||||
select {
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
// Process incoming response chunks.
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
// Stream is closed, send the final [DONE] message.
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
} else {
|
||||
// Convert the chunk to OpenAI format and send it to the client.
|
||||
openAIFormat := translator.ConvertCliToOpenAI(chunk)
|
||||
if openAIFormat != "" {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
// Handle errors from the backend.
|
||||
case err, okError := <-errChan:
|
||||
if okError {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
// Send a keep-alive signal to the client.
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
208
internal/api/handlers/claude/code-handlers.go
Normal file
208
internal/api/handlers/claude/code-handlers.go
Normal file
@@ -0,0 +1,208 @@
|
||||
// Package claude provides HTTP handlers for Claude API code-related functionality.
|
||||
// This package implements Claude-compatible streaming chat completions with sophisticated
|
||||
// client rotation and quota management systems to ensure high availability and optimal
|
||||
// resource utilization across multiple backend clients. It handles request translation
|
||||
// between Claude API format and the underlying Gemini backend, providing seamless
|
||||
// API compatibility while maintaining robust error handling and connection management.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/translator/claude/code"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ClaudeCodeAPIHandlers contains the handlers for Claude API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service.
|
||||
type ClaudeCodeAPIHandlers struct {
|
||||
*handlers.APIHandlers
|
||||
}
|
||||
|
||||
// NewClaudeCodeAPIHandlers creates a new Claude API handlers instance.
|
||||
// It takes an APIHandlers instance as input and returns a ClaudeCodeAPIHandlers.
|
||||
func NewClaudeCodeAPIHandlers(apiHandlers *handlers.APIHandlers) *ClaudeCodeAPIHandlers {
|
||||
return &ClaudeCodeAPIHandlers{
|
||||
APIHandlers: apiHandlers,
|
||||
}
|
||||
}
|
||||
|
||||
// ClaudeMessages handles Claude-compatible streaming chat completions.
|
||||
// This function implements a sophisticated client rotation and quota management system
|
||||
// to ensure high availability and optimal resource utilization across multiple backend clients.
|
||||
func (h *ClaudeCodeAPIHandlers) ClaudeMessages(c *gin.Context) {
|
||||
// Extract raw JSON data from the incoming request
|
||||
rawJSON, err := c.GetRawData()
|
||||
// If data retrieval fails, return a 400 Bad Request error.
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Set up Server-Sent Events (SSE) headers for streaming response
|
||||
// These headers are essential for maintaining a persistent connection
|
||||
// and enabling real-time streaming of chat completions
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
// This is crucial for streaming as it allows immediate sending of data chunks
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse and prepare the Claude request, extracting model name, system instructions,
|
||||
// conversation contents, and available tools from the raw JSON
|
||||
modelName, systemInstruction, contents, tools := code.PrepareClaudeRequest(rawJSON)
|
||||
|
||||
// Map Claude model names to corresponding Gemini models
|
||||
// This allows the proxy to handle Claude API calls using Gemini backend
|
||||
if modelName == "claude-sonnet-4-20250514" {
|
||||
modelName = "gemini-2.5-pro"
|
||||
} else if modelName == "claude-3-5-haiku-20241022" {
|
||||
modelName = "gemini-2.5-flash"
|
||||
}
|
||||
|
||||
// Create a cancellable context for the backend client request
|
||||
// This allows proper cleanup and cancellation of ongoing requests
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
// This prevents deadlocks and ensures proper resource cleanup
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// Main client rotation loop with quota management
|
||||
// This loop implements a sophisticated load balancing and failover mechanism
|
||||
outLoop:
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
// Determine the authentication method being used by the selected client
|
||||
// This affects how responses are formatted and logged
|
||||
isGlAPIKey := false
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
isGlAPIKey = true
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
}
|
||||
// Initiate streaming communication with the backend client
|
||||
// This returns two channels: one for response chunks and one for errors
|
||||
|
||||
includeThoughts := false
|
||||
if userAgent, hasKey := c.Request.Header["User-Agent"]; hasKey {
|
||||
includeThoughts = !strings.Contains(userAgent[0], "claude-cli")
|
||||
}
|
||||
|
||||
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools, includeThoughts)
|
||||
|
||||
// Track response state for proper Claude format conversion
|
||||
hasFirstResponse := false
|
||||
responseType := 0
|
||||
responseIndex := 0
|
||||
|
||||
// Main streaming loop - handles multiple concurrent events using Go channels
|
||||
// This select statement manages four different types of events simultaneously
|
||||
for {
|
||||
select {
|
||||
// Case 1: Handle client disconnection
|
||||
// Detects when the HTTP client has disconnected and cleans up resources
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request to prevent resource leaks
|
||||
return
|
||||
}
|
||||
|
||||
// Case 2: Process incoming response chunks from the backend
|
||||
// This handles the actual streaming data from the AI model
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
// Stream has ended - send the final message_stop event
|
||||
// This follows the Claude API specification for stream termination
|
||||
_, _ = c.Writer.Write([]byte(`event: message_stop`))
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
_, _ = c.Writer.Write([]byte(`data: {"type":"message_stop"}`))
|
||||
_, _ = c.Writer.Write([]byte("\n\n\n"))
|
||||
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
// Convert the backend response to Claude-compatible format
|
||||
// This translation layer ensures API compatibility
|
||||
claudeFormat := code.ConvertCliToClaude(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex)
|
||||
if claudeFormat != "" {
|
||||
_, _ = c.Writer.Write([]byte(claudeFormat))
|
||||
flusher.Flush() // Immediately send the chunk to the client
|
||||
}
|
||||
hasFirstResponse = true
|
||||
|
||||
// Case 3: Handle errors from the backend
|
||||
// This manages various error conditions and implements retry logic
|
||||
case errInfo, okError := <-errChan:
|
||||
if okError {
|
||||
// Special handling for quota exceeded errors
|
||||
// If configured, attempt to switch to a different project/client
|
||||
if errInfo.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
continue outLoop // Restart the client selection process
|
||||
} else {
|
||||
// Forward other errors directly to the client
|
||||
c.Status(errInfo.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errInfo.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Case 4: Send periodic keep-alive signals
|
||||
// Prevents connection timeouts during long-running requests
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
if hasFirstResponse {
|
||||
// Send a ping event to maintain the connection
|
||||
// This is especially important for slow AI model responses
|
||||
output := "event: ping\n"
|
||||
output = output + `data: {"type": "ping"}`
|
||||
output = output + "\n\n\n"
|
||||
_, _ = c.Writer.Write([]byte(output))
|
||||
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
268
internal/api/handlers/gemini/cli/cli-handlers.go
Normal file
268
internal/api/handlers/gemini/cli/cli-handlers.go
Normal file
@@ -0,0 +1,268 @@
|
||||
// Package cli provides HTTP handlers for Gemini CLI API functionality.
|
||||
// This package implements handlers that process CLI-specific requests for Gemini API operations,
|
||||
// including content generation and streaming content generation endpoints.
|
||||
// The handlers restrict access to localhost only and manage communication with the backend service.
|
||||
package cli
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GeminiCLIAPIHandlers contains the handlers for Gemini CLI API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service.
|
||||
type GeminiCLIAPIHandlers struct {
|
||||
*handlers.APIHandlers
|
||||
}
|
||||
|
||||
// NewGeminiCLIAPIHandlers creates a new Gemini CLI API handlers instance.
|
||||
// It takes an APIHandlers instance as input and returns a GeminiCLIAPIHandlers.
|
||||
func NewGeminiCLIAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiCLIAPIHandlers {
|
||||
return &GeminiCLIAPIHandlers{
|
||||
APIHandlers: apiHandlers,
|
||||
}
|
||||
}
|
||||
|
||||
// CLIHandler handles CLI-specific requests for Gemini API operations.
|
||||
// It restricts access to localhost only and routes requests to appropriate internal handlers.
|
||||
func (h *GeminiCLIAPIHandlers) CLIHandler(c *gin.Context) {
|
||||
if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") {
|
||||
c.JSON(http.StatusForbidden, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "CLI reply only allow local access",
|
||||
Type: "forbidden",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
rawJSON, _ := c.GetRawData()
|
||||
requestRawURI := c.Request.URL.Path
|
||||
if requestRawURI == "/v1internal:generateContent" {
|
||||
h.internalGenerateContent(c, rawJSON)
|
||||
} else if requestRawURI == "/v1internal:streamGenerateContent" {
|
||||
h.internalStreamGenerateContent(c, rawJSON)
|
||||
} else {
|
||||
reqBody := bytes.NewBuffer(rawJSON)
|
||||
req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
for key, value := range c.Request.Header {
|
||||
req.Header[key] = value
|
||||
}
|
||||
|
||||
httpClient, err := util.SetProxy(h.Cfg, &http.Client{})
|
||||
if err != nil {
|
||||
log.Fatalf("set proxy failed: %v", err)
|
||||
}
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
defer func() {
|
||||
if err = resp.Body.Close(); err != nil {
|
||||
log.Printf("warn: failed to close response body: %v", err)
|
||||
}
|
||||
}()
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: string(bodyBytes),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
for key, value := range resp.Header {
|
||||
c.Header(key, value[0])
|
||||
}
|
||||
output, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to read response body: %v", err)
|
||||
return
|
||||
}
|
||||
_, _ = c.Writer.Write(output)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GeminiCLIAPIHandlers) internalStreamGenerateContent(c *gin.Context, rawJSON []byte) {
|
||||
alt := h.GetAlt(c)
|
||||
|
||||
if alt == "" {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
modelName := modelResult.String()
|
||||
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
outLoop:
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
}
|
||||
// Send the message and receive response chunks and errors via channels.
|
||||
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, "")
|
||||
hasFirstResponse := false
|
||||
for {
|
||||
select {
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
// Process incoming response chunks.
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
hasFirstResponse = true
|
||||
if cliClient.GetGenerativeLanguageAPIKey() != "" {
|
||||
chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
|
||||
}
|
||||
_, _ = c.Writer.Write([]byte("data: "))
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n\n"))
|
||||
flusher.Flush()
|
||||
// Handle errors from the backend.
|
||||
case err, okError := <-errChan:
|
||||
if okError {
|
||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
continue outLoop
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
}
|
||||
return
|
||||
}
|
||||
// Send a keep-alive signal to the client.
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
if hasFirstResponse {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GeminiCLIAPIHandlers) internalGenerateContent(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
modelName := modelResult.String()
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
}
|
||||
|
||||
resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, "")
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
continue
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||
cliCancel()
|
||||
}
|
||||
break
|
||||
} else {
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
437
internal/api/handlers/gemini/gemini-handlers.go
Normal file
437
internal/api/handlers/gemini/gemini-handlers.go
Normal file
@@ -0,0 +1,437 @@
|
||||
// Package gemini provides HTTP handlers for Gemini API endpoints.
|
||||
// This package implements handlers for managing Gemini model operations including
|
||||
// model listing, content generation, streaming content generation, and token counting.
|
||||
// It serves as a proxy layer between clients and the Gemini backend service,
|
||||
// handling request translation, client management, and response processing.
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/translator/gemini/cli"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// GeminiAPIHandlers contains the handlers for Gemini API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service.
|
||||
type GeminiAPIHandlers struct {
|
||||
*handlers.APIHandlers
|
||||
}
|
||||
|
||||
// NewGeminiAPIHandlers creates a new Gemini API handlers instance.
|
||||
// It takes an APIHandlers instance as input and returns a GeminiAPIHandlers.
|
||||
func NewGeminiAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiAPIHandlers {
|
||||
return &GeminiAPIHandlers{
|
||||
APIHandlers: apiHandlers,
|
||||
}
|
||||
}
|
||||
|
||||
// GeminiModels handles the Gemini models listing endpoint.
|
||||
// It returns a JSON response containing available Gemini models and their specifications.
|
||||
func (h *GeminiAPIHandlers) GeminiModels(c *gin.Context) {
|
||||
c.Status(http.StatusOK)
|
||||
c.Header("Content-Type", "application/json; charset=UTF-8")
|
||||
_, _ = c.Writer.Write([]byte(`{"models":[{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini `))
|
||||
_, _ = c.Writer.Write([]byte(`2.5 Flash","description":"Stable version of Gemini 2.5 Flash, our mid-size multimod`))
|
||||
_, _ = c.Writer.Write([]byte(`al model that supports up to 1 million tokens, released in June of 2025.","inputTok`))
|
||||
_, _ = c.Writer.Write([]byte(`enLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateCo`))
|
||||
_, _ = c.Writer.Write([]byte(`ntent","countTokens","createCachedContent","batchGenerateContent"],"temperature":1,`))
|
||||
_, _ = c.Writer.Write([]byte(`"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true},{"name":"models/gemini-2.`))
|
||||
_, _ = c.Writer.Write([]byte(`5-pro","version":"2.5","displayName":"Gemini 2.5 Pro","description":"Stable release`))
|
||||
_, _ = c.Writer.Write([]byte(` (June 17th, 2025) of Gemini 2.5 Pro","inputTokenLimit":1048576,"outputTokenLimit":`))
|
||||
_, _ = c.Writer.Write([]byte(`65536,"supportedGenerationMethods":["generateContent","countTokens","createCachedCo`))
|
||||
_, _ = c.Writer.Write([]byte(`ntent","batchGenerateContent"],"temperature":1,"topP":0.95,"topK":64,"maxTemperatur`))
|
||||
_, _ = c.Writer.Write([]byte(`e":2,"thinking":true}],"nextPageToken":""}`))
|
||||
}
|
||||
|
||||
// GeminiGetHandler handles GET requests for specific Gemini model information.
|
||||
// It returns detailed information about a specific Gemini model based on the action parameter.
|
||||
func (h *GeminiAPIHandlers) GeminiGetHandler(c *gin.Context) {
|
||||
var request struct {
|
||||
Action string `uri:"action" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindUri(&request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
if request.Action == "gemini-2.5-pro" {
|
||||
c.Status(http.StatusOK)
|
||||
c.Header("Content-Type", "application/json; charset=UTF-8")
|
||||
_, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-pro","version":"2.5","displayName":"Gemini 2.5 Pro",`))
|
||||
_, _ = c.Writer.Write([]byte(`"description":"Stable release (June 17th, 2025) of Gemini 2.5 Pro","inputTokenL`))
|
||||
_, _ = c.Writer.Write([]byte(`imit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateC`))
|
||||
_, _ = c.Writer.Write([]byte(`ontent","countTokens","createCachedContent","batchGenerateContent"],"temperatur`))
|
||||
_, _ = c.Writer.Write([]byte(`e":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`))
|
||||
} else if request.Action == "gemini-2.5-flash" {
|
||||
c.Status(http.StatusOK)
|
||||
c.Header("Content-Type", "application/json; charset=UTF-8")
|
||||
_, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini 2.5 Fla`))
|
||||
_, _ = c.Writer.Write([]byte(`sh","description":"Stable version of Gemini 2.5 Flash, our mid-size multimodal `))
|
||||
_, _ = c.Writer.Write([]byte(`model that supports up to 1 million tokens, released in June of 2025.","inputTo`))
|
||||
_, _ = c.Writer.Write([]byte(`kenLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["gener`))
|
||||
_, _ = c.Writer.Write([]byte(`ateContent","countTokens","createCachedContent","batchGenerateContent"],"temper`))
|
||||
_, _ = c.Writer.Write([]byte(`ature":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`))
|
||||
} else {
|
||||
c.Status(http.StatusNotFound)
|
||||
_, _ = c.Writer.Write([]byte(
|
||||
`{"error":{"message":"Not Found","code":404,"status":"NOT_FOUND"}}`,
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
// GeminiHandler handles POST requests for Gemini API operations.
|
||||
// It routes requests to appropriate handlers based on the action parameter (model:method format).
|
||||
func (h *GeminiAPIHandlers) GeminiHandler(c *gin.Context) {
|
||||
var request struct {
|
||||
Action string `uri:"action" binding:"required"`
|
||||
}
|
||||
if err := c.ShouldBindUri(&request); err != nil {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
action := strings.Split(request.Action, ":")
|
||||
if len(action) != 2 {
|
||||
c.JSON(http.StatusNotFound, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: fmt.Sprintf("%s not found.", c.Request.URL.Path),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelName := action[0]
|
||||
method := action[1]
|
||||
rawJSON, _ := c.GetRawData()
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", []byte(modelName))
|
||||
|
||||
if method == "generateContent" {
|
||||
h.geminiGenerateContent(c, rawJSON)
|
||||
} else if method == "streamGenerateContent" {
|
||||
h.geminiStreamGenerateContent(c, rawJSON)
|
||||
} else if method == "countTokens" {
|
||||
h.geminiCountTokens(c, rawJSON)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GeminiAPIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJSON []byte) {
|
||||
alt := h.GetAlt(c)
|
||||
|
||||
if alt == "" {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
modelName := modelResult.String()
|
||||
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
outLoop:
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
template := ""
|
||||
parsed := gjson.Parse(string(rawJSON))
|
||||
contents := parsed.Get("request.contents")
|
||||
if contents.Exists() {
|
||||
template = string(rawJSON)
|
||||
} else {
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
||||
template, _ = sjson.Delete(template, "request.model")
|
||||
}
|
||||
|
||||
template, errFixCLIToolResponse := cli.FixCLIToolResponse(template)
|
||||
if errFixCLIToolResponse != nil {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: errFixCLIToolResponse.Error(),
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
||||
if systemInstructionResult.Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
||||
template, _ = sjson.Delete(template, "request.system_instruction")
|
||||
}
|
||||
rawJSON = []byte(template)
|
||||
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
}
|
||||
|
||||
// Send the message and receive response chunks and errors via channels.
|
||||
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, alt)
|
||||
for {
|
||||
select {
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
// Process incoming response chunks.
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
if cliClient.GetGenerativeLanguageAPIKey() == "" {
|
||||
if alt == "" {
|
||||
responseResult := gjson.GetBytes(chunk, "response")
|
||||
if responseResult.Exists() {
|
||||
chunk = []byte(responseResult.Raw)
|
||||
}
|
||||
} else {
|
||||
chunkTemplate := "[]"
|
||||
responseResult := gjson.ParseBytes(chunk)
|
||||
if responseResult.IsArray() {
|
||||
responseResultItems := responseResult.Array()
|
||||
for i := 0; i < len(responseResultItems); i++ {
|
||||
responseResultItem := responseResultItems[i]
|
||||
if responseResultItem.Get("response").Exists() {
|
||||
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
chunk = []byte(chunkTemplate)
|
||||
}
|
||||
}
|
||||
if alt == "" {
|
||||
_, _ = c.Writer.Write([]byte("data: "))
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n\n"))
|
||||
} else {
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
}
|
||||
flusher.Flush()
|
||||
// Handle errors from the backend.
|
||||
case err, okError := <-errChan:
|
||||
if okError {
|
||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
log.Debugf("quota exceeded, switch client")
|
||||
continue outLoop
|
||||
} else {
|
||||
log.Debugf("error code :%d, error: %v", err.StatusCode, err.Error.Error())
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
}
|
||||
return
|
||||
}
|
||||
// Send a keep-alive signal to the client.
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GeminiAPIHandlers) geminiCountTokens(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
alt := h.GetAlt(c)
|
||||
// orgrawJSON := rawJSON
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
modelName := modelResult.String()
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName, false)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
|
||||
template := `{"request":{}}`
|
||||
if gjson.GetBytes(rawJSON, "generateContentRequest").Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJSON, "generateContentRequest").Raw)
|
||||
template, _ = sjson.Delete(template, "generateContentRequest")
|
||||
} else if gjson.GetBytes(rawJSON, "contents").Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJSON, "contents").Raw)
|
||||
template, _ = sjson.Delete(template, "contents")
|
||||
}
|
||||
rawJSON = []byte(template)
|
||||
}
|
||||
|
||||
resp, err := cliClient.SendRawTokenCount(cliCtx, rawJSON, alt)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
continue
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||
cliCancel()
|
||||
// log.Debugf(err.Error.Error())
|
||||
// log.Debugf(string(rawJSON))
|
||||
// log.Debugf(string(orgrawJSON))
|
||||
}
|
||||
break
|
||||
} else {
|
||||
if cliClient.GetGenerativeLanguageAPIKey() == "" {
|
||||
responseResult := gjson.GetBytes(resp, "response")
|
||||
if responseResult.Exists() {
|
||||
resp = []byte(responseResult.Raw)
|
||||
}
|
||||
}
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GeminiAPIHandlers) geminiGenerateContent(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
alt := h.GetAlt(c)
|
||||
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
modelName := modelResult.String()
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
template := ""
|
||||
parsed := gjson.Parse(string(rawJSON))
|
||||
contents := parsed.Get("request.contents")
|
||||
if contents.Exists() {
|
||||
template = string(rawJSON)
|
||||
} else {
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
||||
template, _ = sjson.Delete(template, "request.model")
|
||||
}
|
||||
|
||||
template, errFixCLIToolResponse := cli.FixCLIToolResponse(template)
|
||||
if errFixCLIToolResponse != nil {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: errFixCLIToolResponse.Error(),
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
||||
if systemInstructionResult.Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
||||
template, _ = sjson.Delete(template, "request.system_instruction")
|
||||
}
|
||||
rawJSON = []byte(template)
|
||||
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
}
|
||||
resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, alt)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
continue
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||
cliCancel()
|
||||
}
|
||||
break
|
||||
} else {
|
||||
if cliClient.GetGenerativeLanguageAPIKey() == "" {
|
||||
responseResult := gjson.GetBytes(resp, "response")
|
||||
if responseResult.Exists() {
|
||||
resp = []byte(responseResult.Raw)
|
||||
}
|
||||
}
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
122
internal/api/handlers/handlers.go
Normal file
122
internal/api/handlers/handlers.go
Normal file
@@ -0,0 +1,122 @@
|
||||
// Package handlers provides core API handler functionality for the CLI Proxy API server.
|
||||
// It includes common types, client management, load balancing, and error handling
|
||||
// shared across all API endpoint handlers (OpenAI, Claude, Gemini).
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// ErrorResponse represents a standard error response format for the API.
|
||||
// It contains a single ErrorDetail field.
|
||||
type ErrorResponse struct {
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorDetail provides specific information about an error that occurred.
|
||||
// It includes a human-readable message, an error type, and an optional error code.
|
||||
type ErrorDetail struct {
|
||||
// A human-readable message providing more details about the error.
|
||||
Message string `json:"message"`
|
||||
// The type of error that occurred (e.g., "invalid_request_error").
|
||||
Type string `json:"type"`
|
||||
// A short code identifying the error, if applicable.
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
// APIHandlers contains the handlers for API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service.
|
||||
type APIHandlers struct {
|
||||
CliClients []*client.Client
|
||||
Cfg *config.Config
|
||||
Mutex *sync.Mutex
|
||||
LastUsedClientIndex int
|
||||
}
|
||||
|
||||
// NewAPIHandlers creates a new API handlers instance.
|
||||
// It takes a slice of clients and a debug flag as input.
|
||||
func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers {
|
||||
return &APIHandlers{
|
||||
CliClients: cliClients,
|
||||
Cfg: cfg,
|
||||
Mutex: &sync.Mutex{},
|
||||
LastUsedClientIndex: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateClients updates the handlers' client list and configuration
|
||||
func (h *APIHandlers) UpdateClients(clients []*client.Client, cfg *config.Config) {
|
||||
h.CliClients = clients
|
||||
h.Cfg = cfg
|
||||
}
|
||||
|
||||
// GetClient returns an available client from the pool using round-robin load balancing.
|
||||
// It checks for quota limits and tries to find an unlocked client for immediate use.
|
||||
// The modelName parameter is used to check quota status for specific models.
|
||||
func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) {
|
||||
if len(h.CliClients) == 0 {
|
||||
return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")}
|
||||
}
|
||||
|
||||
var cliClient *client.Client
|
||||
|
||||
// Lock the mutex to update the last used client index
|
||||
h.Mutex.Lock()
|
||||
startIndex := h.LastUsedClientIndex
|
||||
if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 {
|
||||
currentIndex := (startIndex + 1) % len(h.CliClients)
|
||||
h.LastUsedClientIndex = currentIndex
|
||||
}
|
||||
h.Mutex.Unlock()
|
||||
|
||||
// Reorder the client to start from the last used index
|
||||
reorderedClients := make([]*client.Client, 0)
|
||||
for i := 0; i < len(h.CliClients); i++ {
|
||||
cliClient = h.CliClients[(startIndex+1+i)%len(h.CliClients)]
|
||||
if cliClient.IsModelQuotaExceeded(modelName) {
|
||||
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
cliClient = nil
|
||||
continue
|
||||
}
|
||||
reorderedClients = append(reorderedClients, cliClient)
|
||||
}
|
||||
|
||||
if len(reorderedClients) == 0 {
|
||||
return nil, &client.ErrorMessage{StatusCode: 429, Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)}
|
||||
}
|
||||
|
||||
locked := false
|
||||
for i := 0; i < len(reorderedClients); i++ {
|
||||
cliClient = reorderedClients[i]
|
||||
if cliClient.RequestMutex.TryLock() {
|
||||
locked = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !locked {
|
||||
cliClient = h.CliClients[0]
|
||||
cliClient.RequestMutex.Lock()
|
||||
}
|
||||
|
||||
return cliClient, nil
|
||||
}
|
||||
|
||||
// GetAlt extracts the 'alt' parameter from the request query string.
|
||||
// It checks both 'alt' and '$alt' parameters and returns the appropriate value.
|
||||
func (h *APIHandlers) GetAlt(c *gin.Context) string {
|
||||
var alt string
|
||||
var hasAlt bool
|
||||
alt, hasAlt = c.GetQuery("alt")
|
||||
if !hasAlt {
|
||||
alt, _ = c.GetQuery("$alt")
|
||||
}
|
||||
if alt == "sse" {
|
||||
return ""
|
||||
}
|
||||
return alt
|
||||
}
|
||||
264
internal/api/handlers/openai/openai-handlers.go
Normal file
264
internal/api/handlers/openai/openai-handlers.go
Normal file
@@ -0,0 +1,264 @@
|
||||
// Package openai provides HTTP handlers for OpenAI API endpoints.
|
||||
// This package implements the OpenAI-compatible API interface, including model listing
|
||||
// and chat completion functionality. It supports both streaming and non-streaming responses,
|
||||
// and manages a pool of clients to interact with backend services.
|
||||
// The handlers translate OpenAI API requests to the appropriate backend format and
|
||||
// convert responses back to OpenAI-compatible format.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/translator/openai"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// OpenAIAPIHandlers contains the handlers for OpenAI API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service.
|
||||
type OpenAIAPIHandlers struct {
|
||||
*handlers.APIHandlers
|
||||
}
|
||||
|
||||
// NewOpenAIAPIHandlers creates a new OpenAI API handlers instance.
|
||||
// It takes an APIHandlers instance as input and returns an OpenAIAPIHandlers.
|
||||
func NewOpenAIAPIHandlers(apiHandlers *handlers.APIHandlers) *OpenAIAPIHandlers {
|
||||
return &OpenAIAPIHandlers{
|
||||
APIHandlers: apiHandlers,
|
||||
}
|
||||
}
|
||||
|
||||
// Models handles the /v1/models endpoint.
|
||||
// It returns a hardcoded list of available AI models.
|
||||
func (h *OpenAIAPIHandlers) Models(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": []map[string]any{
|
||||
{
|
||||
"id": "gemini-2.5-pro",
|
||||
"object": "model",
|
||||
"version": "2.5",
|
||||
"name": "Gemini 2.5 Pro",
|
||||
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65536,
|
||||
"supported_parameters": []string{
|
||||
"tools",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
},
|
||||
"temperature": 1,
|
||||
"topP": 0.95,
|
||||
"topK": 64,
|
||||
"maxTemperature": 2,
|
||||
"thinking": true,
|
||||
},
|
||||
{
|
||||
"id": "gemini-2.5-flash",
|
||||
"object": "model",
|
||||
"version": "001",
|
||||
"name": "Gemini 2.5 Flash",
|
||||
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
||||
"context_length": 1048576,
|
||||
"max_completion_tokens": 65536,
|
||||
"supported_parameters": []string{
|
||||
"tools",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
},
|
||||
"temperature": 1,
|
||||
"topP": 0.95,
|
||||
"topK": 64,
|
||||
"maxTemperature": 2,
|
||||
"thinking": true,
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
// ChatCompletions handles the /v1/chat/completions endpoint.
|
||||
// It determines whether the request is for a streaming or non-streaming response
|
||||
// and calls the appropriate handler.
|
||||
func (h *OpenAIAPIHandlers) ChatCompletions(c *gin.Context) {
|
||||
rawJSON, err := c.GetRawData()
|
||||
// If data retrieval fails, return a 400 Bad Request error.
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the client requested a streaming response.
|
||||
streamResult := gjson.GetBytes(rawJSON, "stream")
|
||||
if streamResult.Type == gjson.True {
|
||||
h.handleStreamingResponse(c, rawJSON)
|
||||
} else {
|
||||
h.handleNonStreamingResponse(c, rawJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// handleNonStreamingResponse handles non-streaming chat completion responses.
|
||||
// It selects a client from the pool, sends the request, and aggregates the response
|
||||
// before sending it back to the client.
|
||||
func (h *OpenAIAPIHandlers) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
modelName, systemInstruction, contents, tools := openai.PrepareRequest(rawJSON)
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
isGlAPIKey := false
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
isGlAPIKey = true
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
}
|
||||
|
||||
resp, err := cliClient.SendMessage(cliCtx, rawJSON, modelName, systemInstruction, contents, tools)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
continue
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||
cliCancel()
|
||||
}
|
||||
break
|
||||
} else {
|
||||
openAIFormat := openai.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey)
|
||||
if openAIFormat != "" {
|
||||
_, _ = c.Writer.Write([]byte(openAIFormat))
|
||||
}
|
||||
cliCancel()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleStreamingResponse handles streaming responses
|
||||
func (h *OpenAIAPIHandlers) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare the request for the backend client.
|
||||
modelName, systemInstruction, contents, tools := openai.PrepareRequest(rawJSON)
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
outLoop:
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
isGlAPIKey := false
|
||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||
isGlAPIKey = true
|
||||
} else {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
}
|
||||
// Send the message and receive response chunks and errors via channels.
|
||||
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools)
|
||||
hasFirstResponse := false
|
||||
for {
|
||||
select {
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
// Process incoming response chunks.
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
// Stream is closed, send the final [DONE] message.
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
// Convert the chunk to OpenAI format and send it to the client.
|
||||
hasFirstResponse = true
|
||||
openAIFormat := openai.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey)
|
||||
if openAIFormat != "" {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
|
||||
flusher.Flush()
|
||||
}
|
||||
// Handle errors from the backend.
|
||||
case err, okError := <-errChan:
|
||||
if okError {
|
||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
continue outLoop
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
}
|
||||
return
|
||||
}
|
||||
// Send a keep-alive signal to the client.
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
if hasFirstResponse {
|
||||
_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package api
|
||||
|
||||
// ErrorResponse represents a standard error response format for the API.
|
||||
// It contains a single ErrorDetail field.
|
||||
type ErrorResponse struct {
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorDetail provides specific information about an error that occurred.
|
||||
// It includes a human-readable message, an error type, and an optional error code.
|
||||
type ErrorDetail struct {
|
||||
// A human-readable message providing more details about the error.
|
||||
Message string `json:"message"`
|
||||
// The type of error that occurred (e.g., "invalid_request_error").
|
||||
Type string `json:"type"`
|
||||
// A short code identifying the error, if applicable.
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
@@ -1,3 +1,7 @@
|
||||
// Package api provides the HTTP API server implementation for the CLI Proxy API.
|
||||
// It includes the main server struct, routing setup, middleware for CORS and authentication,
|
||||
// and integration with various AI API handlers (OpenAI, Claude, Gemini).
|
||||
// The server supports hot-reloading of clients and configuration.
|
||||
package api
|
||||
|
||||
import (
|
||||
@@ -5,7 +9,13 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers/claude"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini/cli"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers/openai"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -16,31 +26,18 @@ import (
|
||||
type Server struct {
|
||||
engine *gin.Engine
|
||||
server *http.Server
|
||||
handlers *APIHandlers
|
||||
cfg *ServerConfig
|
||||
}
|
||||
|
||||
// ServerConfig contains the configuration for the API server.
|
||||
type ServerConfig struct {
|
||||
// Port is the port number the server will listen on.
|
||||
Port string
|
||||
// Debug enables or disables debug mode for the server and Gin.
|
||||
Debug bool
|
||||
// ApiKeys is a list of valid API keys for authentication.
|
||||
ApiKeys []string
|
||||
handlers *handlers.APIHandlers
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewServer creates and initializes a new API server instance.
|
||||
// It sets up the Gin engine, middleware, routes, and handlers.
|
||||
func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
|
||||
func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
|
||||
// Set gin mode
|
||||
if !config.Debug {
|
||||
if !cfg.Debug {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
}
|
||||
|
||||
// Create handlers
|
||||
handlers := NewAPIHandlers(cliClients, config.Debug)
|
||||
|
||||
// Create gin engine
|
||||
engine := gin.New()
|
||||
|
||||
@@ -52,8 +49,8 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
|
||||
// Create server instance
|
||||
s := &Server{
|
||||
engine: engine,
|
||||
handlers: handlers,
|
||||
cfg: config,
|
||||
handlers: handlers.NewAPIHandlers(cliClients, cfg),
|
||||
cfg: cfg,
|
||||
}
|
||||
|
||||
// Setup routes
|
||||
@@ -61,7 +58,7 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
|
||||
|
||||
// Create HTTP server
|
||||
s.server = &http.Server{
|
||||
Addr: ":" + config.Port,
|
||||
Addr: fmt.Sprintf(":%d", cfg.Port),
|
||||
Handler: engine,
|
||||
}
|
||||
|
||||
@@ -71,12 +68,27 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
|
||||
// setupRoutes configures the API routes for the server.
|
||||
// It defines the endpoints and associates them with their respective handlers.
|
||||
func (s *Server) setupRoutes() {
|
||||
openaiHandlers := openai.NewOpenAIAPIHandlers(s.handlers)
|
||||
geminiHandlers := gemini.NewGeminiAPIHandlers(s.handlers)
|
||||
geminiCLIHandlers := cli.NewGeminiCLIAPIHandlers(s.handlers)
|
||||
claudeCodeHandlers := claude.NewClaudeCodeAPIHandlers(s.handlers)
|
||||
|
||||
// OpenAI compatible API routes
|
||||
v1 := s.engine.Group("/v1")
|
||||
v1.Use(AuthMiddleware(s.cfg))
|
||||
{
|
||||
v1.GET("/models", s.handlers.Models)
|
||||
v1.POST("/chat/completions", s.handlers.ChatCompletions)
|
||||
v1.GET("/models", openaiHandlers.Models)
|
||||
v1.POST("/chat/completions", openaiHandlers.ChatCompletions)
|
||||
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
||||
}
|
||||
|
||||
// Gemini compatible API routes
|
||||
v1beta := s.engine.Group("/v1beta")
|
||||
v1beta.Use(AuthMiddleware(s.cfg))
|
||||
{
|
||||
v1beta.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1beta.POST("/models/:action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
|
||||
// Root endpoint
|
||||
@@ -90,6 +102,8 @@ func (s *Server) setupRoutes() {
|
||||
},
|
||||
})
|
||||
})
|
||||
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
|
||||
|
||||
}
|
||||
|
||||
// Start begins listening for and serving HTTP requests.
|
||||
@@ -136,18 +150,31 @@ func corsMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateClients updates the server's client list and configuration
|
||||
func (s *Server) UpdateClients(clients []*client.Client, cfg *config.Config) {
|
||||
s.cfg = cfg
|
||||
s.handlers.UpdateClients(clients, cfg)
|
||||
log.Infof("server clients and configuration updated: %d clients", len(clients))
|
||||
}
|
||||
|
||||
// AuthMiddleware returns a Gin middleware handler that authenticates requests
|
||||
// using API keys. If no API keys are configured, it allows all requests.
|
||||
func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc {
|
||||
func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if len(cfg.ApiKeys) == 0 {
|
||||
if len(cfg.APIKeys) == 0 {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Get the Authorization header
|
||||
authHeader := c.GetHeader("Authorization")
|
||||
if authHeader == "" {
|
||||
authHeaderGoogle := c.GetHeader("X-Goog-Api-Key")
|
||||
authHeaderAnthropic := c.GetHeader("X-Api-Key")
|
||||
|
||||
// Get the API key from the query parameter
|
||||
apiKeyQuery, _ := c.GetQuery("key")
|
||||
|
||||
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && apiKeyQuery == "" {
|
||||
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
|
||||
"error": "Missing API key",
|
||||
})
|
||||
@@ -165,9 +192,9 @@ func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc {
|
||||
|
||||
// Find the API key in the in-memory list
|
||||
var foundKey string
|
||||
for i := range cfg.ApiKeys {
|
||||
if cfg.ApiKeys[i] == apiKey {
|
||||
foundKey = cfg.ApiKeys[i]
|
||||
for i := range cfg.APIKeys {
|
||||
if cfg.APIKeys[i] == apiKey || cfg.APIKeys[i] == authHeaderGoogle || cfg.APIKeys[i] == authHeaderAnthropic || cfg.APIKeys[i] == apiKeyQuery {
|
||||
foundKey = cfg.APIKeys[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
169
internal/api/translator/claude/code/request.go
Normal file
169
internal/api/translator/claude/code/request.go
Normal file
@@ -0,0 +1,169 @@
|
||||
// Package code provides request translation functionality for Claude API.
|
||||
// It handles parsing and transforming Claude API requests into the internal client format,
|
||||
// extracting model information, system instructions, message contents, and tool declarations.
|
||||
// The package also performs JSON data cleaning and transformation to ensure compatibility
|
||||
// between Claude API format and the internal client's expected format.
|
||||
package code
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// PrepareClaudeRequest parses and transforms a Claude API request into internal client format.
|
||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the internal client.
|
||||
func PrepareClaudeRequest(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
|
||||
var pathsToDelete []string
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
walk(root, "", "additionalProperties", &pathsToDelete)
|
||||
walk(root, "", "$schema", &pathsToDelete)
|
||||
|
||||
var err error
|
||||
for _, p := range pathsToDelete {
|
||||
rawJSON, err = sjson.DeleteBytes(rawJSON, p)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
}
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// log.Debug(string(rawJSON))
|
||||
modelName := "gemini-2.5-pro"
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
if modelResult.Type == gjson.String {
|
||||
modelName = modelResult.String()
|
||||
}
|
||||
|
||||
contents := make([]client.Content, 0)
|
||||
|
||||
var systemInstruction *client.Content
|
||||
|
||||
systemResult := gjson.GetBytes(rawJSON, "system")
|
||||
if systemResult.IsArray() {
|
||||
systemResults := systemResult.Array()
|
||||
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemPromptResult := systemResults[i]
|
||||
systemTypePromptResult := systemPromptResult.Get("type")
|
||||
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
|
||||
systemPrompt := systemPromptResult.Get("text").String()
|
||||
systemPart := client.Part{Text: systemPrompt}
|
||||
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
|
||||
}
|
||||
}
|
||||
if len(systemInstruction.Parts) == 0 {
|
||||
systemInstruction = nil
|
||||
}
|
||||
}
|
||||
|
||||
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messageResults := messagesResult.Array()
|
||||
for i := 0; i < len(messageResults); i++ {
|
||||
messageResult := messageResults[i]
|
||||
roleResult := messageResult.Get("role")
|
||||
if roleResult.Type != gjson.String {
|
||||
continue
|
||||
}
|
||||
role := roleResult.String()
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
clientContent := client.Content{Role: role, Parts: []client.Part{}}
|
||||
|
||||
contentsResult := messageResult.Get("content")
|
||||
if contentsResult.IsArray() {
|
||||
contentResults := contentsResult.Array()
|
||||
for j := 0; j < len(contentResults); j++ {
|
||||
contentResult := contentResults[j]
|
||||
contentTypeResult := contentResult.Get("type")
|
||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
prompt := contentResult.Get("text").String()
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
var args map[string]any
|
||||
if err = json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||
FunctionCall: &client.FunctionCall{
|
||||
Name: functionName,
|
||||
Args: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||
toolCallID := contentResult.Get("tool_use_id").String()
|
||||
if toolCallID != "" {
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||
}
|
||||
responseData := contentResult.Get("content").String()
|
||||
functionResponse := client.FunctionResponse{Name: funcName, Response: map[string]interface{}{"result": responseData}}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
|
||||
}
|
||||
}
|
||||
}
|
||||
contents = append(contents, clientContent)
|
||||
} else if contentsResult.Type == gjson.String {
|
||||
prompt := contentsResult.String()
|
||||
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var tools []client.ToolDeclaration
|
||||
toolsResult := gjson.GetBytes(rawJSON, "tools")
|
||||
if toolsResult.IsArray() {
|
||||
tools = make([]client.ToolDeclaration, 1)
|
||||
tools[0].FunctionDeclarations = make([]any, 0)
|
||||
toolsResults := toolsResult.Array()
|
||||
for i := 0; i < len(toolsResults); i++ {
|
||||
toolResult := toolsResults[i]
|
||||
inputSchemaResult := toolResult.Get("input_schema")
|
||||
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
||||
inputSchema := inputSchemaResult.Raw
|
||||
inputSchema, _ = sjson.Delete(inputSchema, "additionalProperties")
|
||||
inputSchema, _ = sjson.Delete(inputSchema, "$schema")
|
||||
|
||||
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
|
||||
tool, _ = sjson.SetRaw(tool, "parameters", inputSchema)
|
||||
var toolDeclaration any
|
||||
if err = json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
|
||||
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tools = make([]client.ToolDeclaration, 0)
|
||||
}
|
||||
|
||||
return modelName, systemInstruction, contents, tools
|
||||
}
|
||||
|
||||
func walk(value gjson.Result, path, field string, pathsToDelete *[]string) {
|
||||
switch value.Type {
|
||||
case gjson.JSON:
|
||||
value.ForEach(func(key, val gjson.Result) bool {
|
||||
var childPath string
|
||||
if path == "" {
|
||||
childPath = key.String()
|
||||
} else {
|
||||
childPath = path + "." + key.String()
|
||||
}
|
||||
if key.String() == field {
|
||||
*pathsToDelete = append(*pathsToDelete, childPath)
|
||||
}
|
||||
walk(val, childPath, field, pathsToDelete)
|
||||
return true
|
||||
})
|
||||
case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null:
|
||||
}
|
||||
}
|
||||
206
internal/api/translator/claude/code/response.go
Normal file
206
internal/api/translator/claude/code/response.go
Normal file
@@ -0,0 +1,206 @@
|
||||
// Package code provides response translation functionality for Claude API.
|
||||
// This package handles the conversion of backend client responses into Claude-compatible
|
||||
// Server-Sent Events (SSE) format, implementing a sophisticated state machine that manages
|
||||
// different response types including text content, thinking processes, and function calls.
|
||||
// The translation ensures proper sequencing of SSE events and maintains state across
|
||||
// multiple response chunks to provide a seamless streaming experience.
|
||||
package code
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"time"
|
||||
)
|
||||
|
||||
// ConvertCliToClaude performs sophisticated streaming response format conversion.
|
||||
// This function implements a complex state machine that translates backend client responses
|
||||
// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types
|
||||
// and handles state transitions between content blocks, thinking processes, and function calls.
|
||||
//
|
||||
// Response type states: 0=none, 1=content, 2=thinking, 3=function
|
||||
// The function maintains state across multiple calls to ensure proper SSE event sequencing.
|
||||
func ConvertCliToClaude(rawJSON []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string {
|
||||
// Normalize the response format for different API key types
|
||||
// Generative Language API keys have a different response structure
|
||||
if isGlAPIKey {
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
|
||||
}
|
||||
|
||||
// Track whether tools are being used in this response chunk
|
||||
usedTool := false
|
||||
output := ""
|
||||
|
||||
// Initialize the streaming session with a message_start event
|
||||
// This is only sent for the very first response chunk
|
||||
if !hasFirstResponse {
|
||||
output = "event: message_start\n"
|
||||
|
||||
// Create the initial message structure with default values
|
||||
// This follows the Claude API specification for streaming message initialization
|
||||
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
|
||||
|
||||
// Override default values with actual response metadata if available
|
||||
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
|
||||
}
|
||||
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
||||
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.id", responseIDResult.String())
|
||||
}
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", messageStartTemplate)
|
||||
}
|
||||
|
||||
// Process the response parts array from the backend client
|
||||
// Each part can contain text content, thinking content, or function calls
|
||||
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
|
||||
if partsResult.IsArray() {
|
||||
partResults := partsResult.Array()
|
||||
for i := 0; i < len(partResults); i++ {
|
||||
partResult := partResults[i]
|
||||
|
||||
// Extract the different types of content from each part
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
|
||||
// Handle text content (both regular content and thinking)
|
||||
if partTextResult.Exists() {
|
||||
// Process thinking content (internal reasoning)
|
||||
if partResult.Get("thought").Bool() {
|
||||
// Continue existing thinking block
|
||||
if *responseType == 2 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
} else {
|
||||
// Transition from another state to thinking
|
||||
// First, close any existing content block
|
||||
if *responseType != 0 {
|
||||
if *responseType == 2 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
}
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
*responseIndex++
|
||||
}
|
||||
|
||||
// Start a new thinking content block
|
||||
output = output + "event: content_block_start\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"thinking","thinking":""}}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, *responseIndex), "delta.thinking", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
*responseType = 2 // Set state to thinking
|
||||
}
|
||||
} else {
|
||||
// Process regular text content (user-visible output)
|
||||
// Continue existing text block
|
||||
if *responseType == 1 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
} else {
|
||||
// Transition from another state to text content
|
||||
// First, close any existing content block
|
||||
if *responseType != 0 {
|
||||
if *responseType == 2 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
}
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
*responseIndex++
|
||||
}
|
||||
|
||||
// Start a new text content block
|
||||
output = output + "event: content_block_start\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, *responseIndex), "delta.text", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
*responseType = 1 // Set state to content
|
||||
}
|
||||
}
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function/tool calls from the AI model
|
||||
// This processes tool usage requests and formats them for Claude API compatibility
|
||||
usedTool = true
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
|
||||
// Handle state transitions when switching to function calls
|
||||
// Close any existing function call block first
|
||||
if *responseType == 3 {
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
*responseIndex++
|
||||
*responseType = 0
|
||||
}
|
||||
|
||||
// Special handling for thinking state transition
|
||||
if *responseType == 2 {
|
||||
output = output + "event: content_block_delta\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":null}}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
}
|
||||
|
||||
// Close any other existing content block
|
||||
if *responseType != 0 {
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
*responseIndex++
|
||||
}
|
||||
|
||||
// Start a new tool use content block
|
||||
// This creates the structure for a function call in Claude format
|
||||
output = output + "event: content_block_start\n"
|
||||
|
||||
// Create the tool use block with unique ID and function details
|
||||
data := fmt.Sprintf(`{"type":"content_block_start","index":%d,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`, *responseIndex)
|
||||
data, _ = sjson.Set(data, "content_block.id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
||||
data, _ = sjson.Set(data, "content_block.name", fcName)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, *responseIndex), "delta.partial_json", fcArgsResult.Raw)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
}
|
||||
*responseType = 3
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata")
|
||||
if usageResult.Exists() && bytes.Contains(rawJSON, []byte(`"finishReason"`)) {
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
output = output + "event: content_block_stop\n"
|
||||
output = output + fmt.Sprintf(`data: {"type":"content_block_stop","index":%d}`, *responseIndex)
|
||||
output = output + "\n\n\n"
|
||||
|
||||
output = output + "event: message_delta\n"
|
||||
output = output + `data: `
|
||||
|
||||
template := `{"type":"message_delta","delta":{"stop_reason":"end_turn","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
if usedTool {
|
||||
template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
}
|
||||
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.output_tokens", candidatesTokenCountResult.Int()+thoughtsTokenCount)
|
||||
template, _ = sjson.Set(template, "usage.input_tokens", usageResult.Get("promptTokenCount").Int())
|
||||
|
||||
output = output + template + "\n\n\n"
|
||||
}
|
||||
}
|
||||
|
||||
return output
|
||||
}
|
||||
185
internal/api/translator/gemini/cli/request.go
Normal file
185
internal/api/translator/gemini/cli/request.go
Normal file
@@ -0,0 +1,185 @@
|
||||
// Package cli provides request translation functionality for Gemini CLI API.
|
||||
// It handles the conversion and formatting of CLI tool responses, specifically
|
||||
// transforming between different JSON formats to ensure proper conversation flow
|
||||
// and API compatibility. The package focuses on intelligently grouping function
|
||||
// calls with their corresponding responses, converting from linear format to
|
||||
// grouped format where function calls and responses are properly associated.
|
||||
package cli
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// FunctionCallGroup represents a group of function calls and their responses
|
||||
type FunctionCallGroup struct {
|
||||
ModelContent map[string]interface{}
|
||||
FunctionCalls []gjson.Result
|
||||
ResponsesNeeded int
|
||||
}
|
||||
|
||||
// FixCLIToolResponse performs sophisticated tool response format conversion and grouping.
|
||||
// This function transforms the CLI tool response format by intelligently grouping function calls
|
||||
// with their corresponding responses, ensuring proper conversation flow and API compatibility.
|
||||
// It converts from a linear format (1.json) to a grouped format (2.json) where function calls
|
||||
// and their responses are properly associated and structured.
|
||||
func FixCLIToolResponse(input string) (string, error) {
|
||||
// Parse the input JSON to extract the conversation structure
|
||||
parsed := gjson.Parse(input)
|
||||
|
||||
// Extract the contents array which contains the conversation messages
|
||||
contents := parsed.Get("request.contents")
|
||||
if !contents.Exists() {
|
||||
// log.Debugf(input)
|
||||
return input, fmt.Errorf("contents not found in input")
|
||||
}
|
||||
|
||||
// Initialize data structures for processing and grouping
|
||||
var newContents []interface{} // Final processed contents array
|
||||
var pendingGroups []*FunctionCallGroup // Groups awaiting completion with responses
|
||||
var collectedResponses []gjson.Result // Standalone responses to be matched
|
||||
|
||||
// Process each content object in the conversation
|
||||
// This iterates through messages and groups function calls with their responses
|
||||
contents.ForEach(func(key, value gjson.Result) bool {
|
||||
role := value.Get("role").String()
|
||||
parts := value.Get("parts")
|
||||
|
||||
// Check if this content has function responses
|
||||
var responsePartsInThisContent []gjson.Result
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("functionResponse").Exists() {
|
||||
responsePartsInThisContent = append(responsePartsInThisContent, part)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// If this content has function responses, collect them
|
||||
if len(responsePartsInThisContent) > 0 {
|
||||
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
|
||||
|
||||
// Check if any pending groups can be satisfied
|
||||
for i := len(pendingGroups) - 1; i >= 0; i-- {
|
||||
group := pendingGroups[i]
|
||||
if len(collectedResponses) >= group.ResponsesNeeded {
|
||||
// Take the needed responses for this group
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
// Create merged function response content
|
||||
var responseParts []interface{}
|
||||
for _, response := range groupResponses {
|
||||
var responseMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||
continue
|
||||
}
|
||||
responseParts = append(responseParts, responseMap)
|
||||
}
|
||||
|
||||
if len(responseParts) > 0 {
|
||||
functionResponseContent := map[string]interface{}{
|
||||
"parts": responseParts,
|
||||
"role": "function",
|
||||
}
|
||||
newContents = append(newContents, functionResponseContent)
|
||||
}
|
||||
|
||||
// Remove this group as it's been satisfied
|
||||
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return true // Skip adding this content, responses are merged
|
||||
}
|
||||
|
||||
// If this is a model with function calls, create a new group
|
||||
if role == "model" {
|
||||
var functionCallsInThisModel []gjson.Result
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
functionCallsInThisModel = append(functionCallsInThisModel, part)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(functionCallsInThisModel) > 0 {
|
||||
// Add the model content
|
||||
var contentMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
|
||||
return true
|
||||
}
|
||||
newContents = append(newContents, contentMap)
|
||||
|
||||
// Create a new group for tracking responses
|
||||
group := &FunctionCallGroup{
|
||||
ModelContent: contentMap,
|
||||
FunctionCalls: functionCallsInThisModel,
|
||||
ResponsesNeeded: len(functionCallsInThisModel),
|
||||
}
|
||||
pendingGroups = append(pendingGroups, group)
|
||||
} else {
|
||||
// Regular model content without function calls
|
||||
var contentMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
|
||||
return true
|
||||
}
|
||||
newContents = append(newContents, contentMap)
|
||||
}
|
||||
} else {
|
||||
// Non-model content (user, etc.)
|
||||
var contentMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
|
||||
return true
|
||||
}
|
||||
newContents = append(newContents, contentMap)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
// Handle any remaining pending groups with remaining responses
|
||||
for _, group := range pendingGroups {
|
||||
if len(collectedResponses) >= group.ResponsesNeeded {
|
||||
groupResponses := collectedResponses[:group.ResponsesNeeded]
|
||||
collectedResponses = collectedResponses[group.ResponsesNeeded:]
|
||||
|
||||
var responseParts []interface{}
|
||||
for _, response := range groupResponses {
|
||||
var responseMap map[string]interface{}
|
||||
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
|
||||
if errUnmarshal != nil {
|
||||
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
|
||||
continue
|
||||
}
|
||||
responseParts = append(responseParts, responseMap)
|
||||
}
|
||||
|
||||
if len(responseParts) > 0 {
|
||||
functionResponseContent := map[string]interface{}{
|
||||
"parts": responseParts,
|
||||
"role": "function",
|
||||
}
|
||||
newContents = append(newContents, functionResponseContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Update the original JSON with the new contents
|
||||
result := input
|
||||
newContentsJSON, _ := json.Marshal(newContents)
|
||||
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
|
||||
|
||||
return result, nil
|
||||
}
|
||||
@@ -1,3 +1,6 @@
|
||||
// Package translator provides data translation and format conversion utilities
|
||||
// for the CLI Proxy API. It includes MIME type mappings and other translation
|
||||
// functions used across different API endpoints.
|
||||
package translator
|
||||
|
||||
// MimeTypes is a comprehensive map of file extensions to their corresponding MIME types.
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
package translator
|
||||
// Package openai provides request translation functionality for OpenAI API.
|
||||
// It handles the conversion of OpenAI-compatible request formats to the internal
|
||||
// format expected by the backend client, including parsing messages, roles,
|
||||
// content types (text, image, file), and tool calls.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
||||
"strings"
|
||||
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
@@ -12,17 +17,60 @@ import (
|
||||
// PrepareRequest translates a raw JSON request from an OpenAI-compatible format
|
||||
// to the internal format expected by the backend client. It parses messages,
|
||||
// roles, content types (text, image, file), and tool calls.
|
||||
func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) {
|
||||
func PrepareRequest(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
|
||||
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
|
||||
modelName := "gemini-2.5-pro"
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
if modelResult.Type == gjson.String {
|
||||
modelName = modelResult.String()
|
||||
}
|
||||
|
||||
// Process the array of messages.
|
||||
// Initialize data structures for processing conversation messages
|
||||
// contents: stores the processed conversation history
|
||||
// systemInstruction: stores system-level instructions separate from conversation
|
||||
contents := make([]client.Content, 0)
|
||||
messagesResult := gjson.GetBytes(rawJson, "messages")
|
||||
var systemInstruction *client.Content
|
||||
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
||||
|
||||
// Pre-process tool responses to create a lookup map
|
||||
// This first pass collects all tool responses so they can be matched with their corresponding calls
|
||||
toolItems := make(map[string]*client.FunctionResponse)
|
||||
if messagesResult.IsArray() {
|
||||
messagesResults := messagesResult.Array()
|
||||
for i := 0; i < len(messagesResults); i++ {
|
||||
messageResult := messagesResults[i]
|
||||
roleResult := messageResult.Get("role")
|
||||
if roleResult.Type != gjson.String {
|
||||
continue
|
||||
}
|
||||
contentResult := messageResult.Get("content")
|
||||
|
||||
// Extract tool responses for later matching with function calls
|
||||
if roleResult.String() == "tool" {
|
||||
toolCallID := messageResult.Get("tool_call_id").String()
|
||||
if toolCallID != "" {
|
||||
var responseData string
|
||||
// Handle both string and object-based tool response formats
|
||||
if contentResult.Type == gjson.String {
|
||||
responseData = contentResult.String()
|
||||
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
|
||||
responseData = contentResult.Get("text").String()
|
||||
}
|
||||
|
||||
// Clean up tool call ID by removing timestamp suffix
|
||||
// This normalizes IDs for consistent matching between calls and responses
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
strings.Join(toolCallIDs, "-")
|
||||
newToolCallID := strings.Join(toolCallIDs[:len(toolCallIDs)-1], "-")
|
||||
|
||||
// Create function response object with normalized ID and response data
|
||||
functionResponse := client.FunctionResponse{Name: newToolCallID, Response: map[string]interface{}{"result": responseData}}
|
||||
toolItems[toolCallID] = &functionResponse
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if messagesResult.IsArray() {
|
||||
messagesResults := messagesResult.Array()
|
||||
for i := 0; i < len(messagesResults); i++ {
|
||||
@@ -37,13 +85,11 @@ func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDecl
|
||||
// System messages are converted to a user message followed by a model's acknowledgment.
|
||||
case "system":
|
||||
if contentResult.Type == gjson.String {
|
||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
|
||||
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}
|
||||
} else if contentResult.IsObject() {
|
||||
// Handle object-based system messages.
|
||||
if contentResult.Get("type").String() == "text" {
|
||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}})
|
||||
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
|
||||
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}}
|
||||
}
|
||||
}
|
||||
// User messages can contain simple text or a multi-part body.
|
||||
@@ -80,7 +126,7 @@ func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDecl
|
||||
if split := strings.Split(filename, "."); len(split) > 1 {
|
||||
ext = split[len(split)-1]
|
||||
}
|
||||
if mimeType, ok := MimeTypes[ext]; ok {
|
||||
if mimeType, ok := translator.MimeTypes[ext]; ok {
|
||||
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
||||
MimeType: mimeType,
|
||||
Data: fileData,
|
||||
@@ -92,53 +138,70 @@ func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDecl
|
||||
}
|
||||
contents = append(contents, client.Content{Role: "user", Parts: parts})
|
||||
}
|
||||
// Assistant messages can contain text or tool calls.
|
||||
// Assistant messages can contain text responses or tool calls
|
||||
// In the internal format, assistant messages are converted to "model" role
|
||||
case "assistant":
|
||||
if contentResult.Type == gjson.String {
|
||||
// Simple text response from the assistant
|
||||
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
|
||||
// Handle tool calls made by the assistant.
|
||||
// Handle complex tool calls made by the assistant
|
||||
// This processes function calls and matches them with their responses
|
||||
functionIDs := make([]string, 0)
|
||||
toolCallsResult := messageResult.Get("tool_calls")
|
||||
if toolCallsResult.IsArray() {
|
||||
parts := make([]client.Part, 0)
|
||||
tcsResult := toolCallsResult.Array()
|
||||
|
||||
// Process each tool call in the assistant's message
|
||||
for j := 0; j < len(tcsResult); j++ {
|
||||
tcResult := tcsResult[j]
|
||||
|
||||
// Extract function call details
|
||||
functionID := tcResult.Get("id").String()
|
||||
functionIDs = append(functionIDs, functionID)
|
||||
|
||||
functionName := tcResult.Get("function.name").String()
|
||||
functionArgs := tcResult.Get("function.arguments").String()
|
||||
|
||||
// Parse function arguments from JSON string to map
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||
contents = append(contents, client.Content{
|
||||
Role: "model", Parts: []client.Part{{
|
||||
FunctionCall: &client.FunctionCall{
|
||||
Name: functionName,
|
||||
Args: args,
|
||||
},
|
||||
}},
|
||||
parts = append(parts, client.Part{
|
||||
FunctionCall: &client.FunctionCall{
|
||||
Name: functionName,
|
||||
Args: args,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Add the model's function calls to the conversation
|
||||
if len(parts) > 0 {
|
||||
contents = append(contents, client.Content{
|
||||
Role: "model", Parts: parts,
|
||||
})
|
||||
|
||||
// Create a separate tool response message with the collected responses
|
||||
// This matches function calls with their corresponding responses
|
||||
toolParts := make([]client.Part, 0)
|
||||
for _, functionID := range functionIDs {
|
||||
if functionResponse, ok := toolItems[functionID]; ok {
|
||||
toolParts = append(toolParts, client.Part{FunctionResponse: functionResponse})
|
||||
}
|
||||
}
|
||||
// Add the tool responses as a separate message in the conversation
|
||||
contents = append(contents, client.Content{Role: "tool", Parts: toolParts})
|
||||
}
|
||||
}
|
||||
}
|
||||
// Tool messages contain the output of a tool call.
|
||||
case "tool":
|
||||
toolCallID := messageResult.Get("tool_call_id").String()
|
||||
if toolCallID != "" {
|
||||
var responseData string
|
||||
if contentResult.Type == gjson.String {
|
||||
responseData = contentResult.String()
|
||||
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
|
||||
responseData = contentResult.Get("text").String()
|
||||
}
|
||||
functionResponse := client.FunctionResponse{Name: toolCallID, Response: map[string]interface{}{"result": responseData}}
|
||||
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Translate the tool declarations from the request.
|
||||
var tools []client.ToolDeclaration
|
||||
toolsResult := gjson.GetBytes(rawJson, "tools")
|
||||
toolsResult := gjson.GetBytes(rawJSON, "tools")
|
||||
if toolsResult.IsArray() {
|
||||
tools = make([]client.ToolDeclaration, 1)
|
||||
tools[0].FunctionDeclarations = make([]any, 0)
|
||||
@@ -159,5 +222,5 @@ func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDecl
|
||||
tools = make([]client.ToolDeclaration, 0)
|
||||
}
|
||||
|
||||
return modelName, contents, tools
|
||||
return modelName, systemInstruction, contents, tools
|
||||
}
|
||||
@@ -1,6 +1,14 @@
|
||||
package translator
|
||||
// Package openai provides response translation functionality for converting between
|
||||
// different API response formats and OpenAI-compatible formats. It handles both
|
||||
// streaming and non-streaming responses, transforming backend client responses
|
||||
// into OpenAI Server-Sent Events (SSE) format and standard JSON response formats.
|
||||
// The package supports content translation, function calls, usage metadata,
|
||||
// and various response attributes while maintaining compatibility with OpenAI API
|
||||
// specifications.
|
||||
package openai
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -10,38 +18,43 @@ import (
|
||||
// ConvertCliToOpenAI translates a single chunk of a streaming response from the
|
||||
// backend client format to the OpenAI Server-Sent Events (SSE) format.
|
||||
// It returns an empty string if the chunk contains no useful data.
|
||||
func ConvertCliToOpenAI(rawJson []byte) string {
|
||||
func ConvertCliToOpenAI(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string {
|
||||
if isGlAPIKey {
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
|
||||
}
|
||||
|
||||
// Initialize the OpenAI SSE template.
|
||||
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
|
||||
// Extract and set the model version.
|
||||
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the creation timestamp.
|
||||
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
|
||||
if createTimeResult := gjson.GetBytes(rawJSON, "response.createTime"); createTimeResult.Exists() {
|
||||
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
||||
unixTimestamp := time.Now().Unix()
|
||||
if err == nil {
|
||||
unixTimestamp = t.Unix()
|
||||
}
|
||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||
}
|
||||
|
||||
// Extract and set the response ID.
|
||||
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
|
||||
template, _ = sjson.Set(template, "id", responseIdResult.String())
|
||||
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the finish reason.
|
||||
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
||||
}
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
@@ -57,32 +70,40 @@ func ConvertCliToOpenAI(rawJson []byte) string {
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
|
||||
if partsResult.IsArray() {
|
||||
partResults := partsResult.Array()
|
||||
for i := 0; i < len(partResults); i++ {
|
||||
partResult := partResults[i]
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
|
||||
if partTextResult.Exists() {
|
||||
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
||||
if partResult.Get("thought").Bool() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
|
||||
if partTextResult.Exists() {
|
||||
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
||||
if partResult.Get("thought").Bool() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function call content.
|
||||
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
||||
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
|
||||
}
|
||||
|
||||
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallTemplate)
|
||||
}
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function call content.
|
||||
functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]`
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcName)
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate)
|
||||
} else {
|
||||
// If no usable content is found, return an empty string.
|
||||
return ""
|
||||
}
|
||||
|
||||
return template
|
||||
@@ -90,29 +111,35 @@ func ConvertCliToOpenAI(rawJson []byte) string {
|
||||
|
||||
// ConvertCliToOpenAINonStream aggregates response from the backend client
|
||||
// convert a single, non-streaming OpenAI-compatible JSON response.
|
||||
func ConvertCliToOpenAINonStream(rawJson []byte) string {
|
||||
func ConvertCliToOpenAINonStream(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string {
|
||||
if isGlAPIKey {
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
|
||||
}
|
||||
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||
}
|
||||
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
|
||||
|
||||
if createTimeResult := gjson.GetBytes(rawJSON, "response.createTime"); createTimeResult.Exists() {
|
||||
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
||||
unixTimestamp := time.Now().Unix()
|
||||
if err == nil {
|
||||
unixTimestamp = t.Unix()
|
||||
}
|
||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||
}
|
||||
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
|
||||
template, _ = sjson.Set(template, "id", responseIdResult.String())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||
}
|
||||
|
||||
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
if responseIDResult := gjson.GetBytes(rawJSON, "response.responseId"); responseIDResult.Exists() {
|
||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||
}
|
||||
|
||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
||||
}
|
||||
|
||||
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
@@ -128,7 +155,7 @@ func ConvertCliToOpenAINonStream(rawJson []byte) string {
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
|
||||
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
|
||||
if partsResult.IsArray() {
|
||||
partsResults := partsResult.Array()
|
||||
for i := 0; i < len(partsResults); i++ {
|
||||
@@ -152,7 +179,7 @@ func ConvertCliToOpenAINonStream(rawJson []byte) string {
|
||||
}
|
||||
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcName)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
@@ -1,3 +1,6 @@
|
||||
// Package auth provides OAuth2 authentication functionality for Google Cloud APIs.
|
||||
// It handles the complete OAuth2 flow including token storage, web-based authentication,
|
||||
// proxy support, and automatic token refresh. The package supports both SOCKS5 and HTTP/HTTPS proxies.
|
||||
package auth
|
||||
|
||||
import (
|
||||
@@ -39,7 +42,7 @@ var (
|
||||
// initiating a new web-based OAuth flow if necessary, and refreshing tokens.
|
||||
func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) {
|
||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||
proxyURL, err := url.Parse(cfg.ProxyUrl)
|
||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||
if err == nil {
|
||||
var transport *http.Transport
|
||||
if proxyURL.Scheme == "socks5" {
|
||||
@@ -168,11 +171,12 @@ func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token,
|
||||
codeChan := make(chan string)
|
||||
errChan := make(chan error)
|
||||
|
||||
// Create a new HTTP server.
|
||||
server := &http.Server{Addr: "localhost:8085"}
|
||||
// Create a new HTTP server with its own multiplexer.
|
||||
mux := http.NewServeMux()
|
||||
server := &http.Server{Addr: ":8085", Handler: mux}
|
||||
config.RedirectURL = "http://localhost:8085/oauth2callback"
|
||||
|
||||
http.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.URL.Query().Get("error"); err != "" {
|
||||
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
|
||||
errChan <- fmt.Errorf("authentication failed via callback: %s", err)
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
// Package client provides HTTP client functionality for interacting with Google Cloud AI APIs.
|
||||
// It handles OAuth2 authentication, token management, request/response processing,
|
||||
// streaming communication, quota management, and automatic model fallback.
|
||||
// The package supports both direct API key authentication and OAuth2 flows.
|
||||
package client
|
||||
|
||||
import (
|
||||
@@ -27,51 +31,84 @@ const (
|
||||
codeAssistEndpoint = "https://cloudcode-pa.googleapis.com"
|
||||
apiVersion = "v1internal"
|
||||
pluginVersion = "0.1.9"
|
||||
|
||||
glEndPoint = "https://generativelanguage.googleapis.com"
|
||||
glAPIVersion = "v1beta"
|
||||
)
|
||||
|
||||
var (
|
||||
previewModels = map[string][]string{
|
||||
"gemini-2.5-pro": {"gemini-2.5-pro-preview-05-06", "gemini-2.5-pro-preview-06-05"},
|
||||
"gemini-2.5-flash": {"gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-preview-05-20"},
|
||||
}
|
||||
)
|
||||
|
||||
// Client is the main client for interacting with the CLI API.
|
||||
type Client struct {
|
||||
httpClient *http.Client
|
||||
RequestMutex sync.Mutex
|
||||
tokenStorage *auth.TokenStorage
|
||||
cfg *config.Config
|
||||
httpClient *http.Client
|
||||
RequestMutex sync.Mutex
|
||||
tokenStorage *auth.TokenStorage
|
||||
cfg *config.Config
|
||||
modelQuotaExceeded map[string]*time.Time
|
||||
glAPIKey string
|
||||
}
|
||||
|
||||
// NewClient creates a new CLI API client.
|
||||
func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Config) *Client {
|
||||
func NewClient(httpClient *http.Client, ts *auth.TokenStorage, cfg *config.Config, glAPIKey ...string) *Client {
|
||||
var glKey string
|
||||
if len(glAPIKey) > 0 {
|
||||
glKey = glAPIKey[0]
|
||||
}
|
||||
return &Client{
|
||||
httpClient: httpClient,
|
||||
tokenStorage: ts,
|
||||
cfg: cfg,
|
||||
httpClient: httpClient,
|
||||
tokenStorage: ts,
|
||||
cfg: cfg,
|
||||
modelQuotaExceeded: make(map[string]*time.Time),
|
||||
glAPIKey: glKey,
|
||||
}
|
||||
}
|
||||
|
||||
// SetProjectID updates the project ID for the client's token storage.
|
||||
func (c *Client) SetProjectID(projectID string) {
|
||||
c.tokenStorage.ProjectID = projectID
|
||||
}
|
||||
|
||||
// SetIsAuto configures whether the client should operate in automatic mode.
|
||||
func (c *Client) SetIsAuto(auto bool) {
|
||||
c.tokenStorage.Auto = auto
|
||||
}
|
||||
|
||||
// SetIsChecked sets the checked status for the client's token storage.
|
||||
func (c *Client) SetIsChecked(checked bool) {
|
||||
c.tokenStorage.Checked = checked
|
||||
}
|
||||
|
||||
// IsChecked returns whether the client's token storage has been checked.
|
||||
func (c *Client) IsChecked() bool {
|
||||
return c.tokenStorage.Checked
|
||||
}
|
||||
|
||||
// IsAuto returns whether the client is operating in automatic mode.
|
||||
func (c *Client) IsAuto() bool {
|
||||
return c.tokenStorage.Auto
|
||||
}
|
||||
|
||||
// GetEmail returns the email address associated with the client's token storage.
|
||||
func (c *Client) GetEmail() string {
|
||||
return c.tokenStorage.Email
|
||||
}
|
||||
|
||||
// GetProjectID returns the Google Cloud project ID from the client's token storage.
|
||||
func (c *Client) GetProjectID() string {
|
||||
return c.tokenStorage.ProjectID
|
||||
if c.tokenStorage != nil {
|
||||
return c.tokenStorage.ProjectID
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetGenerativeLanguageAPIKey returns the generative language API key if configured.
|
||||
func (c *Client) GetGenerativeLanguageAPIKey() string {
|
||||
return c.glAPIKey
|
||||
}
|
||||
|
||||
// SetupUser performs the initial user onboarding and setup.
|
||||
@@ -187,6 +224,7 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
|
||||
metadataStr := getClientMetadataString()
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", getUserAgent())
|
||||
req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0")
|
||||
req.Header.Set("Client-Metadata", metadataStr)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||
|
||||
@@ -214,99 +252,8 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
|
||||
return nil
|
||||
}
|
||||
|
||||
// SendMessageStream handles a single conversational turn, including tool calls.
|
||||
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
|
||||
dataTag := []byte("data: ")
|
||||
errChan := make(chan *ErrorMessage)
|
||||
dataChan := make(chan []byte)
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
defer close(dataChan)
|
||||
|
||||
request := GenerateContentRequest{
|
||||
Contents: contents,
|
||||
GenerationConfig: GenerationConfig{
|
||||
ThinkingConfig: GenerationConfigThinkingConfig{
|
||||
IncludeThoughts: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
request.Tools = tools
|
||||
|
||||
requestBody := map[string]interface{}{
|
||||
"project": c.tokenStorage.ProjectID, // Assuming ProjectID is available
|
||||
"request": request,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
byteRequestBody, _ := json.Marshal(requestBody)
|
||||
|
||||
// log.Debug(string(byteRequestBody))
|
||||
|
||||
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
|
||||
if reasoningEffortResult.String() == "none" {
|
||||
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
} else if reasoningEffortResult.String() == "auto" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
} else if reasoningEffortResult.String() == "low" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
} else if reasoningEffortResult.String() == "medium" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
} else if reasoningEffortResult.String() == "high" {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
|
||||
} else {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
}
|
||||
|
||||
temperatureResult := gjson.GetBytes(rawJson, "temperature")
|
||||
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
|
||||
}
|
||||
|
||||
topPResult := gjson.GetBytes(rawJson, "top_p")
|
||||
if topPResult.Exists() && topPResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
|
||||
}
|
||||
|
||||
topKResult := gjson.GetBytes(rawJson, "top_k")
|
||||
if topKResult.Exists() && topKResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
||||
}
|
||||
|
||||
// log.Debug(string(byteRequestBody))
|
||||
|
||||
stream, err := c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
|
||||
if err != nil {
|
||||
// log.Println(err)
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(stream)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
// log.Printf("Received stream chunk: %s", line)
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
dataChan <- line[6:]
|
||||
}
|
||||
}
|
||||
|
||||
if errScanner := scanner.Err(); errScanner != nil {
|
||||
// log.Println(err)
|
||||
errChan <- &ErrorMessage{500, errScanner}
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
// APIRequest handles making requests to the CLI API endpoints.
|
||||
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) {
|
||||
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, alt string, stream bool) (io.ReadCloser, *ErrorMessage) {
|
||||
var jsonBody []byte
|
||||
var err error
|
||||
if byteBody, ok := body.([]byte); ok {
|
||||
@@ -317,35 +264,70 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
|
||||
return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)}
|
||||
}
|
||||
}
|
||||
// log.Debug(string(jsonBody))
|
||||
reqBody := bytes.NewBuffer(jsonBody)
|
||||
|
||||
// Add alt=sse for streaming
|
||||
url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint)
|
||||
if stream {
|
||||
url = url + "?alt=sse"
|
||||
var url string
|
||||
if c.glAPIKey == "" {
|
||||
// Add alt=sse for streaming
|
||||
url = fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint)
|
||||
if alt == "" && stream {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
if alt != "" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", alt)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if endpoint == "countTokens" {
|
||||
modelResult := gjson.GetBytes(jsonBody, "model")
|
||||
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint)
|
||||
} else {
|
||||
modelResult := gjson.GetBytes(jsonBody, "model")
|
||||
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint)
|
||||
if alt == "" && stream {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
if alt != "" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", alt)
|
||||
}
|
||||
}
|
||||
jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
|
||||
systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction")
|
||||
if systemInstructionResult.Exists() {
|
||||
jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw))
|
||||
jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction")
|
||||
jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// log.Debug(string(jsonBody))
|
||||
// log.Debug(url)
|
||||
reqBody := bytes.NewBuffer(jsonBody)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
|
||||
if err != nil {
|
||||
return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %w", err)}
|
||||
}
|
||||
|
||||
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
|
||||
if err != nil {
|
||||
return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %w", err)}
|
||||
return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %v", err)}
|
||||
}
|
||||
|
||||
// Set headers
|
||||
metadataStr := getClientMetadataString()
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", getUserAgent())
|
||||
req.Header.Set("Client-Metadata", metadataStr)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||
if c.glAPIKey == "" {
|
||||
token, errToken := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
|
||||
if errToken != nil {
|
||||
return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %v", errToken)}
|
||||
}
|
||||
req.Header.Set("User-Agent", getUserAgent())
|
||||
req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0")
|
||||
req.Header.Set("Client-Metadata", metadataStr)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||
} else {
|
||||
req.Header.Set("x-goog-api-key", c.glAPIKey)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %w", err)}
|
||||
return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %v", err)}
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
@@ -355,15 +337,15 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
|
||||
}
|
||||
}()
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
|
||||
// log.Debug(string(jsonBody))
|
||||
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))}
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
// SendMessageStream handles a single conversational turn, including tool calls.
|
||||
func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
|
||||
// SendMessage handles a single conversational turn, including tool calls.
|
||||
func (c *Client) SendMessage(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
|
||||
request := GenerateContentRequest{
|
||||
Contents: contents,
|
||||
GenerationConfig: GenerationConfig{
|
||||
@@ -372,10 +354,13 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
request.SystemInstruction = systemInstruction
|
||||
|
||||
request.Tools = tools
|
||||
|
||||
requestBody := map[string]interface{}{
|
||||
"project": c.tokenStorage.ProjectID, // Assuming ProjectID is available
|
||||
"project": c.GetProjectID(), // Assuming ProjectID is available
|
||||
"request": request,
|
||||
"model": model,
|
||||
}
|
||||
@@ -384,7 +369,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
|
||||
|
||||
// log.Debug(string(byteRequestBody))
|
||||
|
||||
reasoningEffortResult := gjson.GetBytes(rawJson, "reasoning_effort")
|
||||
reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
if reasoningEffortResult.String() == "none" {
|
||||
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
@@ -400,32 +385,443 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
}
|
||||
|
||||
temperatureResult := gjson.GetBytes(rawJson, "temperature")
|
||||
temperatureResult := gjson.GetBytes(rawJSON, "temperature")
|
||||
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
|
||||
}
|
||||
|
||||
topPResult := gjson.GetBytes(rawJson, "top_p")
|
||||
topPResult := gjson.GetBytes(rawJSON, "top_p")
|
||||
if topPResult.Exists() && topPResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
|
||||
}
|
||||
|
||||
topKResult := gjson.GetBytes(rawJson, "top_k")
|
||||
topKResult := gjson.GetBytes(rawJSON, "top_k")
|
||||
if topKResult.Exists() && topKResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
||||
}
|
||||
|
||||
modelName := model
|
||||
// log.Debug(string(byteRequestBody))
|
||||
for {
|
||||
if c.isModelQuotaExceeded(modelName) {
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
modelName = c.getPreviewModel(model)
|
||||
if modelName != "" {
|
||||
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, &ErrorMessage{
|
||||
StatusCode: 429,
|
||||
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||
}
|
||||
}
|
||||
|
||||
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, "", false)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||
if errReadAll != nil {
|
||||
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
return bodyBytes, nil
|
||||
}
|
||||
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||
if errReadAll != nil {
|
||||
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
|
||||
// SendMessageStream handles streaming conversational turns with comprehensive parameter management.
|
||||
// This function implements a sophisticated streaming system that supports tool calls, reasoning modes,
|
||||
// quota management, and automatic model fallback. It returns two channels for asynchronous communication:
|
||||
// one for streaming response data and another for error handling.
|
||||
func (c *Client) SendMessageStream(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) {
|
||||
// Define the data prefix used in Server-Sent Events streaming format
|
||||
dataTag := []byte("data: ")
|
||||
|
||||
// Create channels for asynchronous communication
|
||||
// errChan: delivers error messages during streaming
|
||||
// dataChan: delivers response data chunks
|
||||
errChan := make(chan *ErrorMessage)
|
||||
dataChan := make(chan []byte)
|
||||
|
||||
// Launch a goroutine to handle the streaming process asynchronously
|
||||
// This allows the function to return immediately while processing continues in the background
|
||||
go func() {
|
||||
// Ensure channels are properly closed when the goroutine exits
|
||||
defer close(errChan)
|
||||
defer close(dataChan)
|
||||
|
||||
// Configure thinking/reasoning capabilities
|
||||
// Default to including thoughts unless explicitly disabled
|
||||
includeThoughtsFlag := true
|
||||
if len(includeThoughts) > 0 {
|
||||
includeThoughtsFlag = includeThoughts[0]
|
||||
}
|
||||
|
||||
// Build the base request structure for the Gemini API
|
||||
// This includes conversation contents and generation configuration
|
||||
request := GenerateContentRequest{
|
||||
Contents: contents,
|
||||
GenerationConfig: GenerationConfig{
|
||||
ThinkingConfig: GenerationConfigThinkingConfig{
|
||||
IncludeThoughts: includeThoughtsFlag,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// Add system instructions if provided
|
||||
// System instructions guide the AI's behavior and response style
|
||||
request.SystemInstruction = systemInstruction
|
||||
|
||||
// Add available tools for function calling capabilities
|
||||
// Tools allow the AI to perform actions beyond text generation
|
||||
request.Tools = tools
|
||||
|
||||
// Construct the complete request body with project context
|
||||
// The project ID is essential for proper API routing and billing
|
||||
requestBody := map[string]interface{}{
|
||||
"project": c.GetProjectID(), // Project ID for API routing and quota management
|
||||
"request": request,
|
||||
"model": model,
|
||||
}
|
||||
|
||||
// Serialize the request body to JSON for API transmission
|
||||
byteRequestBody, _ := json.Marshal(requestBody)
|
||||
|
||||
// Parse and configure reasoning effort levels from the original request
|
||||
// This maps Claude-style reasoning effort parameters to Gemini's thinking budget system
|
||||
reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
if reasoningEffortResult.String() == "none" {
|
||||
// Disable thinking entirely for fastest responses
|
||||
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||
} else if reasoningEffortResult.String() == "auto" {
|
||||
// Let the model decide the appropriate thinking budget automatically
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
} else if reasoningEffortResult.String() == "low" {
|
||||
// Minimal thinking for simple tasks (1KB thinking budget)
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||
} else if reasoningEffortResult.String() == "medium" {
|
||||
// Moderate thinking for complex tasks (8KB thinking budget)
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||
} else if reasoningEffortResult.String() == "high" {
|
||||
// Maximum thinking for very complex tasks (24KB thinking budget)
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
|
||||
} else {
|
||||
// Default to automatic thinking budget if no specific level is provided
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
}
|
||||
|
||||
// Configure temperature parameter for response randomness control
|
||||
// Temperature affects the creativity vs consistency trade-off in responses
|
||||
temperatureResult := gjson.GetBytes(rawJSON, "temperature")
|
||||
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
|
||||
}
|
||||
|
||||
// Configure top-p parameter for nucleus sampling
|
||||
// Controls the cumulative probability threshold for token selection
|
||||
topPResult := gjson.GetBytes(rawJSON, "top_p")
|
||||
if topPResult.Exists() && topPResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
|
||||
}
|
||||
|
||||
// Configure top-k parameter for limiting token candidates
|
||||
// Restricts the model to consider only the top K most likely tokens
|
||||
topKResult := gjson.GetBytes(rawJSON, "top_k")
|
||||
if topKResult.Exists() && topKResult.Type == gjson.Number {
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
||||
}
|
||||
|
||||
// Initialize model name for quota management and potential fallback
|
||||
modelName := model
|
||||
var stream io.ReadCloser
|
||||
|
||||
// Quota management and model fallback loop
|
||||
// This loop handles quota exceeded scenarios and automatic model switching
|
||||
for {
|
||||
// Check if the current model has exceeded its quota
|
||||
if c.isModelQuotaExceeded(modelName) {
|
||||
// Attempt to switch to a preview model if configured and using account auth
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
modelName = c.getPreviewModel(model)
|
||||
if modelName != "" {
|
||||
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||
// Update the request body with the new model name
|
||||
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
|
||||
continue // Retry with the preview model
|
||||
}
|
||||
}
|
||||
// If no fallback is available, return a quota exceeded error
|
||||
errChan <- &ErrorMessage{
|
||||
StatusCode: 429,
|
||||
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Attempt to establish a streaming connection with the API
|
||||
var err *ErrorMessage
|
||||
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, "", true)
|
||||
if err != nil {
|
||||
// Handle quota exceeded errors by marking the model and potentially retrying
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now // Mark model as quota exceeded
|
||||
// If preview model switching is enabled, retry the loop
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
// Forward other errors to the error channel
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
// Clear any previous quota exceeded status for this model
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
break // Successfully established connection, exit the retry loop
|
||||
}
|
||||
|
||||
// Process the streaming response using a scanner
|
||||
// This handles the Server-Sent Events format from the API
|
||||
scanner := bufio.NewScanner(stream)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
// Filter and forward only data lines (those prefixed with "data: ")
|
||||
// This extracts the actual JSON content from the SSE format
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
dataChan <- line[6:] // Remove "data: " prefix and send the JSON content
|
||||
}
|
||||
}
|
||||
|
||||
// Handle any scanning errors that occurred during stream processing
|
||||
if errScanner := scanner.Err(); errScanner != nil {
|
||||
// Send a 500 Internal Server Error for scanning failures
|
||||
errChan <- &ErrorMessage{500, errScanner}
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure the stream is properly closed to prevent resource leaks
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
// Return the channels immediately for asynchronous communication
|
||||
// The caller can read from these channels while the goroutine processes the request
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
// SendRawTokenCount handles a token count.
|
||||
func (c *Client) SendRawTokenCount(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) {
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
model := modelResult.String()
|
||||
modelName := model
|
||||
for {
|
||||
if c.isModelQuotaExceeded(modelName) {
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
modelName = c.getPreviewModel(model)
|
||||
if modelName != "" {
|
||||
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, &ErrorMessage{
|
||||
StatusCode: 429,
|
||||
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||
}
|
||||
}
|
||||
|
||||
respBody, err := c.APIRequest(ctx, "countTokens", rawJSON, alt, false)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||
if errReadAll != nil {
|
||||
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
return bodyBytes, nil
|
||||
}
|
||||
return bodyBytes, nil
|
||||
}
|
||||
|
||||
// SendRawMessage handles a single conversational turn, including tool calls.
|
||||
func (c *Client) SendRawMessage(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) {
|
||||
if c.glAPIKey == "" {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID())
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
model := modelResult.String()
|
||||
modelName := model
|
||||
for {
|
||||
if c.isModelQuotaExceeded(modelName) {
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
modelName = c.getPreviewModel(model)
|
||||
if modelName != "" {
|
||||
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, &ErrorMessage{
|
||||
StatusCode: 429,
|
||||
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||
}
|
||||
}
|
||||
|
||||
respBody, err := c.APIRequest(ctx, "generateContent", rawJSON, alt, false)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||
if errReadAll != nil {
|
||||
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
return bodyBytes, nil
|
||||
}
|
||||
}
|
||||
|
||||
// SendRawMessageStream handles a single conversational turn, including tool calls.
|
||||
func (c *Client) SendRawMessageStream(ctx context.Context, rawJSON []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) {
|
||||
dataTag := []byte("data: ")
|
||||
errChan := make(chan *ErrorMessage)
|
||||
dataChan := make(chan []byte)
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
defer close(dataChan)
|
||||
|
||||
if c.glAPIKey == "" {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID())
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
model := modelResult.String()
|
||||
modelName := model
|
||||
var stream io.ReadCloser
|
||||
for {
|
||||
if c.isModelQuotaExceeded(modelName) {
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
modelName = c.getPreviewModel(model)
|
||||
if modelName != "" {
|
||||
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
|
||||
continue
|
||||
}
|
||||
}
|
||||
errChan <- &ErrorMessage{
|
||||
StatusCode: 429,
|
||||
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||
}
|
||||
return
|
||||
}
|
||||
var err *ErrorMessage
|
||||
stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJSON, alt, true)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||
continue
|
||||
}
|
||||
}
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
break
|
||||
}
|
||||
|
||||
if alt == "" {
|
||||
scanner := bufio.NewScanner(stream)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
dataChan <- line[6:]
|
||||
}
|
||||
}
|
||||
|
||||
if errScanner := scanner.Err(); errScanner != nil {
|
||||
errChan <- &ErrorMessage{500, errScanner}
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
data, err := io.ReadAll(stream)
|
||||
if err != nil {
|
||||
errChan <- &ErrorMessage{500, err}
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
dataChan <- data
|
||||
}
|
||||
_ = stream.Close()
|
||||
|
||||
}()
|
||||
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
// isModelQuotaExceeded checks if the specified model has exceeded its quota
|
||||
// within the last 30 minutes.
|
||||
func (c *Client) isModelQuotaExceeded(model string) bool {
|
||||
if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
|
||||
duration := time.Now().Sub(*lastExceededTime)
|
||||
if duration > 30*time.Minute {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getPreviewModel returns an available preview model for the given base model,
|
||||
// or an empty string if no preview models are available or all are quota exceeded.
|
||||
func (c *Client) getPreviewModel(model string) string {
|
||||
if models, hasKey := previewModels[model]; hasKey {
|
||||
for i := 0; i < len(models); i++ {
|
||||
if !c.isModelQuotaExceeded(models[i]) {
|
||||
return models[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsModelQuotaExceeded returns true if the specified model has exceeded its quota
|
||||
// and no fallback options are available.
|
||||
func (c *Client) IsModelQuotaExceeded(model string) bool {
|
||||
if c.isModelQuotaExceeded(model) {
|
||||
if c.cfg.QuotaExceeded.SwitchPreviewModel {
|
||||
return c.getPreviewModel(model) == ""
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CheckCloudAPIIsEnabled sends a simple test request to the API to verify
|
||||
@@ -442,23 +838,24 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
|
||||
// A simple request to test the API endpoint.
|
||||
requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID)
|
||||
|
||||
stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), true)
|
||||
stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), "", true)
|
||||
if err != nil {
|
||||
// If a 403 Forbidden error occurs, it likely means the API is not enabled.
|
||||
if err.StatusCode == 403 {
|
||||
errJson := err.Error.Error()
|
||||
errJSON := err.Error.Error()
|
||||
// Check for a specific error code and extract the activation URL.
|
||||
if gjson.Get(errJson, "error.code").Int() == 403 {
|
||||
activationUrl := gjson.Get(errJson, "error.details.0.metadata.activationUrl").String()
|
||||
if activationUrl != "" {
|
||||
if gjson.Get(errJSON, "0.error.code").Int() == 403 {
|
||||
activationURL := gjson.Get(errJSON, "0.error.details.0.metadata.activationUrl").String()
|
||||
if activationURL != "" {
|
||||
log.Warnf(
|
||||
"\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s",
|
||||
activationUrl,
|
||||
"\n\nPlease activate your account with this url:\n\n%s\n\n And execute this command again:\n%s --login --project_id %s",
|
||||
activationURL,
|
||||
os.Args[0],
|
||||
c.tokenStorage.ProjectID,
|
||||
)
|
||||
}
|
||||
}
|
||||
log.Warnf("\n\nPlease copy this message and create an issue.\n\n%s\n\n", errJSON)
|
||||
return false, nil
|
||||
}
|
||||
return false, err.Error
|
||||
@@ -536,10 +933,10 @@ func (c *Client) SaveTokenToFile() error {
|
||||
// such as IDE type, platform, and plugin version.
|
||||
func getClientMetadata() map[string]string {
|
||||
return map[string]string{
|
||||
"ideType": "IDE_UNSPECIFIED",
|
||||
"platform": getPlatform(),
|
||||
"pluginType": "GEMINI",
|
||||
"pluginVersion": pluginVersion,
|
||||
"ideType": "IDE_UNSPECIFIED",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
// "pluginVersion": pluginVersion,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -556,7 +953,8 @@ func getClientMetadataString() string {
|
||||
|
||||
// getUserAgent constructs the User-Agent string for HTTP requests.
|
||||
func getUserAgent() string {
|
||||
return fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH)
|
||||
// return fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH)
|
||||
return "google-api-nodejs-client/9.15.1"
|
||||
}
|
||||
|
||||
// getPlatform determines the operating system and architecture and formats
|
||||
|
||||
@@ -64,9 +64,10 @@ type FunctionResponse struct {
|
||||
|
||||
// GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint.
|
||||
type GenerateContentRequest struct {
|
||||
Contents []Content `json:"contents"`
|
||||
Tools []ToolDeclaration `json:"tools,omitempty"`
|
||||
GenerationConfig `json:"generationConfig"`
|
||||
SystemInstruction *Content `json:"systemInstruction,omitempty"`
|
||||
Contents []Content `json:"contents"`
|
||||
Tools []ToolDeclaration `json:"tools,omitempty"`
|
||||
GenerationConfig `json:"generationConfig"`
|
||||
}
|
||||
|
||||
// GenerationConfig defines parameters that control the model's generation behavior.
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
// Package cmd provides command-line interface functionality for the CLI Proxy API.
|
||||
// It implements the main application commands including login/authentication
|
||||
// and server startup, handling the complete user onboarding and service lifecycle.
|
||||
package cmd
|
||||
|
||||
import (
|
||||
@@ -73,6 +76,7 @@ func DoLogin(cfg *config.Config, projectID string) {
|
||||
// If the check fails (returns false), the CheckCloudAPIIsEnabled function
|
||||
// will have already printed instructions, so we can just exit.
|
||||
if !isChecked {
|
||||
log.Fatal("Failed to check if Cloud AI API is enabled. If you encounter an error message, please create an issue.")
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
// Package cmd provides the main service execution functionality for the CLIProxyAPI.
|
||||
// It contains the core logic for starting and managing the API proxy service,
|
||||
// including authentication client management, server initialization, and graceful shutdown handling.
|
||||
// The package handles loading authentication tokens, creating client pools, starting the API server,
|
||||
// and monitoring configuration changes through file watchers.
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api"
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||
"github.com/luispater/CLIProxyAPI/internal/watcher"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"os/signal"
|
||||
"path/filepath"
|
||||
@@ -21,14 +28,7 @@ import (
|
||||
// StartService initializes and starts the main API proxy service.
|
||||
// It loads all available authentication tokens, creates a pool of clients,
|
||||
// starts the API server, and handles graceful shutdown signals.
|
||||
func StartService(cfg *config.Config) {
|
||||
// Configure the API server based on the main application config.
|
||||
apiConfig := &api.ServerConfig{
|
||||
Port: fmt.Sprintf("%d", cfg.Port),
|
||||
Debug: cfg.Debug,
|
||||
ApiKeys: cfg.ApiKeys,
|
||||
}
|
||||
|
||||
func StartService(cfg *config.Config, configPath string) {
|
||||
// Create a pool of API clients, one for each token file found.
|
||||
cliClients := make([]*client.Client, 0)
|
||||
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
|
||||
@@ -72,13 +72,62 @@ func StartService(cfg *config.Config) {
|
||||
log.Fatalf("Error walking auth directory: %v", err)
|
||||
}
|
||||
|
||||
// Create and start the API server with the pool of clients.
|
||||
apiServer := api.NewServer(apiConfig, cliClients)
|
||||
log.Infof("Starting API server on port %s", apiConfig.Port)
|
||||
if err = apiServer.Start(); err != nil {
|
||||
log.Fatalf("API server failed to start: %v", err)
|
||||
if len(cfg.GlAPIKey) > 0 {
|
||||
for i := 0; i < len(cfg.GlAPIKey); i++ {
|
||||
httpClient, errSetProxy := util.SetProxy(cfg, &http.Client{})
|
||||
if errSetProxy != nil {
|
||||
log.Fatalf("set proxy failed: %v", errSetProxy)
|
||||
}
|
||||
|
||||
log.Debug("Initializing with Generative Language API key...")
|
||||
cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
|
||||
cliClients = append(cliClients, cliClient)
|
||||
}
|
||||
}
|
||||
|
||||
// Create and start the API server with the pool of clients.
|
||||
apiServer := api.NewServer(cfg, cliClients)
|
||||
log.Infof("Starting API server on port %d", cfg.Port)
|
||||
|
||||
// Start the API server in a goroutine so it doesn't block the main thread
|
||||
go func() {
|
||||
if err = apiServer.Start(); err != nil {
|
||||
log.Fatalf("API server failed to start: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Give the server a moment to start up
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
log.Info("API server started successfully")
|
||||
|
||||
// Setup file watcher for config and auth directory changes
|
||||
fileWatcher, errNewWatcher := watcher.NewWatcher(configPath, cfg.AuthDir, func(newClients []*client.Client, newCfg *config.Config) {
|
||||
// Update the API server with new clients and configuration
|
||||
apiServer.UpdateClients(newClients, newCfg)
|
||||
})
|
||||
if errNewWatcher != nil {
|
||||
log.Fatalf("failed to create file watcher: %v", errNewWatcher)
|
||||
}
|
||||
|
||||
// Set initial state for the watcher
|
||||
fileWatcher.SetConfig(cfg)
|
||||
fileWatcher.SetClients(cliClients)
|
||||
|
||||
// Start the file watcher
|
||||
watcherCtx, watcherCancel := context.WithCancel(context.Background())
|
||||
if errStartWatcher := fileWatcher.Start(watcherCtx); errStartWatcher != nil {
|
||||
log.Fatalf("failed to start file watcher: %v", errStartWatcher)
|
||||
}
|
||||
log.Info("file watcher started for config and auth directory changes")
|
||||
|
||||
defer func() {
|
||||
watcherCancel()
|
||||
errStopWatcher := fileWatcher.Stop()
|
||||
if errStopWatcher != nil {
|
||||
log.Errorf("error stopping file watcher: %v", errStopWatcher)
|
||||
}
|
||||
}()
|
||||
|
||||
// Set up a channel to listen for OS signals for graceful shutdown.
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
// Package config provides configuration management for the CLI Proxy API server.
|
||||
// It handles loading and parsing YAML configuration files, and provides structured
|
||||
// access to application settings including server port, authentication directory,
|
||||
// debug settings, proxy configuration, and API keys.
|
||||
package config
|
||||
|
||||
import (
|
||||
@@ -11,13 +15,26 @@ type Config struct {
|
||||
// Port is the network port on which the API server will listen.
|
||||
Port int `yaml:"port"`
|
||||
// AuthDir is the directory where authentication token files are stored.
|
||||
AuthDir string `yaml:"auth_dir"`
|
||||
AuthDir string `yaml:"auth-dir"`
|
||||
// Debug enables or disables debug-level logging and other debug features.
|
||||
Debug bool `yaml:"debug"`
|
||||
// ProxyUrl is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyUrl string `yaml:"proxy-url"`
|
||||
// ApiKeys is a list of keys for authenticating clients to this proxy server.
|
||||
ApiKeys []string `yaml:"api_keys"`
|
||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyURL string `yaml:"proxy-url"`
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys"`
|
||||
// QuotaExceeded defines the behavior when a quota is exceeded.
|
||||
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded"`
|
||||
// GlAPIKey is the API key for the generative language API.
|
||||
GlAPIKey []string `yaml:"generative-language-api-key"`
|
||||
}
|
||||
|
||||
// QuotaExceeded defines the behavior when API quota limits are exceeded.
|
||||
// It provides configuration options for automatic failover mechanisms.
|
||||
type QuotaExceeded struct {
|
||||
// SwitchProject indicates whether to automatically switch to another project when a quota is exceeded.
|
||||
SwitchProject bool `yaml:"switch-project"`
|
||||
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
|
||||
SwitchPreviewModel bool `yaml:"switch-preview-model"`
|
||||
}
|
||||
|
||||
// LoadConfig reads a YAML configuration file from the given path,
|
||||
|
||||
43
internal/util/proxy.go
Normal file
43
internal/util/proxy.go
Normal file
@@ -0,0 +1,43 @@
|
||||
// Package util provides utility functions for the CLI Proxy API server.
|
||||
// It includes helper functions for proxy configuration, HTTP client setup,
|
||||
// and other common operations used across the application.
|
||||
package util
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
"golang.org/x/net/proxy"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// SetProxy configures the provided HTTP client with proxy settings from the configuration.
|
||||
// It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport
|
||||
// to route requests through the configured proxy server.
|
||||
func SetProxy(cfg *config.Config, httpClient *http.Client) (*http.Client, error) {
|
||||
var transport *http.Transport
|
||||
proxyURL, errParse := url.Parse(cfg.ProxyURL)
|
||||
if errParse == nil {
|
||||
if proxyURL.Scheme == "socks5" {
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
proxyAuth := &proxy.Auth{User: username, Password: password}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
return nil, errSOCKS5
|
||||
}
|
||||
transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||
}
|
||||
}
|
||||
if transport != nil {
|
||||
httpClient.Transport = transport
|
||||
}
|
||||
return httpClient, nil
|
||||
}
|
||||
286
internal/watcher/watcher.go
Normal file
286
internal/watcher/watcher.go
Normal file
@@ -0,0 +1,286 @@
|
||||
// Package watcher provides file system monitoring functionality for the CLI Proxy API.
|
||||
// It watches configuration files and authentication directories for changes,
|
||||
// automatically reloading clients and configuration when files are modified.
|
||||
// The package handles cross-platform file system events and supports hot-reloading.
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"io/fs"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Watcher manages file watching for configuration and authentication files
|
||||
type Watcher struct {
|
||||
configPath string
|
||||
authDir string
|
||||
config *config.Config
|
||||
clients []*client.Client
|
||||
clientsMutex sync.RWMutex
|
||||
reloadCallback func([]*client.Client, *config.Config)
|
||||
watcher *fsnotify.Watcher
|
||||
}
|
||||
|
||||
// NewWatcher creates a new file watcher instance
|
||||
func NewWatcher(configPath, authDir string, reloadCallback func([]*client.Client, *config.Config)) (*Watcher, error) {
|
||||
watcher, errNewWatcher := fsnotify.NewWatcher()
|
||||
if errNewWatcher != nil {
|
||||
return nil, errNewWatcher
|
||||
}
|
||||
|
||||
return &Watcher{
|
||||
configPath: configPath,
|
||||
authDir: authDir,
|
||||
reloadCallback: reloadCallback,
|
||||
watcher: watcher,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Start begins watching the configuration file and authentication directory
|
||||
func (w *Watcher) Start(ctx context.Context) error {
|
||||
// Watch the config file
|
||||
if errAddConfig := w.watcher.Add(w.configPath); errAddConfig != nil {
|
||||
log.Errorf("failed to watch config file %s: %v", w.configPath, errAddConfig)
|
||||
return errAddConfig
|
||||
}
|
||||
log.Debugf("watching config file: %s", w.configPath)
|
||||
|
||||
// Watch the auth directory
|
||||
if errAddAuthDir := w.watcher.Add(w.authDir); errAddAuthDir != nil {
|
||||
log.Errorf("failed to watch auth directory %s: %v", w.authDir, errAddAuthDir)
|
||||
return errAddAuthDir
|
||||
}
|
||||
log.Debugf("watching auth directory: %s", w.authDir)
|
||||
|
||||
// Start the event processing goroutine
|
||||
go w.processEvents(ctx)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop stops the file watcher
|
||||
func (w *Watcher) Stop() error {
|
||||
return w.watcher.Close()
|
||||
}
|
||||
|
||||
// SetConfig updates the current configuration
|
||||
func (w *Watcher) SetConfig(cfg *config.Config) {
|
||||
w.clientsMutex.Lock()
|
||||
defer w.clientsMutex.Unlock()
|
||||
w.config = cfg
|
||||
}
|
||||
|
||||
// SetClients updates the current client list
|
||||
func (w *Watcher) SetClients(clients []*client.Client) {
|
||||
w.clientsMutex.Lock()
|
||||
defer w.clientsMutex.Unlock()
|
||||
w.clients = clients
|
||||
}
|
||||
|
||||
// processEvents handles file system events
|
||||
func (w *Watcher) processEvents(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case event, ok := <-w.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
w.handleEvent(event)
|
||||
case errWatch, ok := <-w.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Errorf("file watcher error: %v", errWatch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleEvent processes individual file system events
|
||||
func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
now := time.Now()
|
||||
|
||||
log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
|
||||
|
||||
// Handle config file changes
|
||||
if event.Name == w.configPath && (event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create) {
|
||||
log.Infof("config file changed, reloading: %s", w.configPath)
|
||||
log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000"))
|
||||
w.reloadConfig()
|
||||
return
|
||||
}
|
||||
|
||||
// Handle auth directory changes (only for .json files)
|
||||
// Simplified: reload on any change to .json files in auth directory
|
||||
if strings.HasPrefix(event.Name, w.authDir) && strings.HasSuffix(event.Name, ".json") {
|
||||
log.Infof("auth file changed (%s): %s, reloading clients", event.Op.String(), filepath.Base(event.Name))
|
||||
log.Debugf("auth file change details - operation: %s, file: %s, timestamp: %s",
|
||||
event.Op.String(), filepath.Base(event.Name), now.Format("2006-01-02 15:04:05.000"))
|
||||
w.reloadClients()
|
||||
}
|
||||
}
|
||||
|
||||
// reloadConfig reloads the configuration and triggers a full reload
|
||||
func (w *Watcher) reloadConfig() {
|
||||
log.Debugf("starting config reload from: %s", w.configPath)
|
||||
|
||||
newConfig, errLoadConfig := config.LoadConfig(w.configPath)
|
||||
if errLoadConfig != nil {
|
||||
log.Errorf("failed to reload config: %v", errLoadConfig)
|
||||
return
|
||||
}
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
oldConfig := w.config
|
||||
w.config = newConfig
|
||||
w.clientsMutex.Unlock()
|
||||
|
||||
// Log configuration changes in debug mode
|
||||
if oldConfig != nil {
|
||||
log.Debugf("config changes detected:")
|
||||
if oldConfig.Port != newConfig.Port {
|
||||
log.Debugf(" port: %d -> %d", oldConfig.Port, newConfig.Port)
|
||||
}
|
||||
if oldConfig.AuthDir != newConfig.AuthDir {
|
||||
log.Debugf(" auth-dir: %s -> %s", oldConfig.AuthDir, newConfig.AuthDir)
|
||||
}
|
||||
if oldConfig.Debug != newConfig.Debug {
|
||||
log.Debugf(" debug: %t -> %t", oldConfig.Debug, newConfig.Debug)
|
||||
}
|
||||
if oldConfig.ProxyURL != newConfig.ProxyURL {
|
||||
log.Debugf(" proxy-url: %s -> %s", oldConfig.ProxyURL, newConfig.ProxyURL)
|
||||
}
|
||||
if len(oldConfig.APIKeys) != len(newConfig.APIKeys) {
|
||||
log.Debugf(" api-keys count: %d -> %d", len(oldConfig.APIKeys), len(newConfig.APIKeys))
|
||||
}
|
||||
if len(oldConfig.GlAPIKey) != len(newConfig.GlAPIKey) {
|
||||
log.Debugf(" generative-language-api-key count: %d -> %d", len(oldConfig.GlAPIKey), len(newConfig.GlAPIKey))
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("config successfully reloaded, triggering client reload")
|
||||
// Reload clients with new config
|
||||
w.reloadClients()
|
||||
}
|
||||
|
||||
// reloadClients reloads all authentication clients
|
||||
func (w *Watcher) reloadClients() {
|
||||
log.Debugf("starting client reload process")
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
cfg := w.config
|
||||
oldClientCount := len(w.clients)
|
||||
w.clientsMutex.RUnlock()
|
||||
|
||||
if cfg == nil {
|
||||
log.Error("config is nil, cannot reload clients")
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("scanning auth directory: %s", cfg.AuthDir)
|
||||
|
||||
// Create new client list
|
||||
newClients := make([]*client.Client, 0)
|
||||
authFileCount := 0
|
||||
successfulAuthCount := 0
|
||||
|
||||
// Load clients from auth directory
|
||||
errWalk := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
log.Debugf("error accessing path %s: %v", path, err)
|
||||
return err
|
||||
}
|
||||
|
||||
// Process only JSON files in the auth directory
|
||||
if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") {
|
||||
authFileCount++
|
||||
log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path))
|
||||
|
||||
f, errOpen := os.Open(path)
|
||||
if errOpen != nil {
|
||||
log.Errorf("failed to open token file %s: %v", path, errOpen)
|
||||
return nil // Continue processing other files
|
||||
}
|
||||
defer func() {
|
||||
errClose := f.Close()
|
||||
if errClose != nil {
|
||||
log.Errorf("failed to close token file %s: %v", path, errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
// Decode the token storage file
|
||||
var ts auth.TokenStorage
|
||||
if errDecode := json.NewDecoder(f).Decode(&ts); errDecode == nil {
|
||||
// For each valid token, create an authenticated client
|
||||
clientCtx := context.Background()
|
||||
log.Debugf(" initializing authentication for token from %s...", filepath.Base(path))
|
||||
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
||||
if errGetClient != nil {
|
||||
log.Errorf(" failed to get authenticated client for token %s: %v", path, errGetClient)
|
||||
return nil // Continue processing other files
|
||||
}
|
||||
log.Debugf(" authentication successful for token from %s", filepath.Base(path))
|
||||
|
||||
// Add the new client to the pool
|
||||
cliClient := client.NewClient(httpClient, &ts, cfg)
|
||||
newClients = append(newClients, cliClient)
|
||||
successfulAuthCount++
|
||||
} else {
|
||||
log.Errorf(" failed to decode token file %s: %v", path, errDecode)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if errWalk != nil {
|
||||
log.Errorf("error walking auth directory: %v", errWalk)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("auth directory scan complete - found %d .json files, %d successful authentications", authFileCount, successfulAuthCount)
|
||||
|
||||
// Add clients for Generative Language API keys if configured
|
||||
glAPIKeyCount := 0
|
||||
if len(cfg.GlAPIKey) > 0 {
|
||||
log.Debugf("processing %d Generative Language API keys", len(cfg.GlAPIKey))
|
||||
for i := 0; i < len(cfg.GlAPIKey); i++ {
|
||||
httpClient, errSetProxy := util.SetProxy(cfg, &http.Client{})
|
||||
if errSetProxy != nil {
|
||||
log.Errorf("set proxy failed for GL API key %d: %v", i+1, errSetProxy)
|
||||
continue
|
||||
}
|
||||
|
||||
log.Debugf(" initializing with Generative Language API key %d...", i+1)
|
||||
cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
|
||||
newClients = append(newClients, cliClient)
|
||||
glAPIKeyCount++
|
||||
}
|
||||
log.Debugf("successfully initialized %d Generative Language API key clients", glAPIKeyCount)
|
||||
}
|
||||
|
||||
// Update the client list
|
||||
w.clientsMutex.Lock()
|
||||
w.clients = newClients
|
||||
w.clientsMutex.Unlock()
|
||||
|
||||
log.Infof("client reload complete - old: %d clients, new: %d clients (%d auth files + %d GL API keys)",
|
||||
oldClientCount, len(newClients), successfulAuthCount, glAPIKeyCount)
|
||||
|
||||
// Trigger the callback to update the server
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback")
|
||||
w.reloadCallback(newClients, cfg)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user