mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 04:20:50 +08:00
Add openai codex support
This commit is contained in:
101
README.md
101
README.md
@@ -2,25 +2,31 @@
|
|||||||
|
|
||||||
English | [中文](README_CN.md)
|
English | [中文](README_CN.md)
|
||||||
|
|
||||||
A proxy server that provides an OpenAI/Gemini/Claude compatible API interface for CLI. This allows you to use CLI models with tools and libraries designed for the OpenAI/Gemini/Claude API.
|
A proxy server that provides OpenAI/Gemini/Claude compatible API interfaces for CLI.
|
||||||
|
|
||||||
|
It now also supports OpenAI Codex (GPT models) via OAuth.
|
||||||
|
|
||||||
|
so you can use local or multi‑account CLI access with OpenAI‑compatible clients and SDKs.
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
- OpenAI/Gemini/Claude compatible API endpoints for CLI models
|
- OpenAI/Gemini/Claude compatible API endpoints for CLI models
|
||||||
- Support for both streaming and non-streaming responses
|
- OpenAI Codex support (GPT models) via OAuth login
|
||||||
|
- Streaming and non-streaming responses
|
||||||
- Function calling/tools support
|
- Function calling/tools support
|
||||||
- Multimodal input support (text and images)
|
- Multimodal input support (text and images)
|
||||||
- Multiple account support with load balancing
|
- Multiple accounts with round‑robin load balancing (Gemini and OpenAI)
|
||||||
- Simple CLI authentication flow
|
- Simple CLI authentication flows (Gemini and OpenAI)
|
||||||
- Support for Generative Language API Key
|
- Generative Language API Key support
|
||||||
- Support Gemini CLI with multiple account load balancing
|
- Gemini CLI multi‑account load balancing
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
### Prerequisites
|
### Prerequisites
|
||||||
|
|
||||||
- Go 1.24 or higher
|
- Go 1.24 or higher
|
||||||
- A Google account with access to CLI models
|
- A Google account with access to Gemini CLI models (optional)
|
||||||
|
- An OpenAI account for Codex/GPT access (optional)
|
||||||
|
|
||||||
### Building from Source
|
### Building from Source
|
||||||
|
|
||||||
@@ -39,17 +45,23 @@ A proxy server that provides an OpenAI/Gemini/Claude compatible API interface fo
|
|||||||
|
|
||||||
### Authentication
|
### Authentication
|
||||||
|
|
||||||
Before using the API, you need to authenticate with your Google account:
|
You can authenticate for Gemini and/or OpenAI. Both can coexist in the same `auth-dir` and will be load balanced.
|
||||||
|
|
||||||
```bash
|
- Gemini (Google):
|
||||||
./cli-proxy-api --login
|
```bash
|
||||||
```
|
./cli-proxy-api --login
|
||||||
|
```
|
||||||
|
If you are an old gemini code user, you may need to specify a project ID:
|
||||||
|
```bash
|
||||||
|
./cli-proxy-api --login --project_id <your_project_id>
|
||||||
|
```
|
||||||
|
The local OAuth callback uses port `8085`.
|
||||||
|
|
||||||
If you are an old gemini code user, you may need to specify a project ID:
|
- OpenAI (Codex/GPT via OAuth):
|
||||||
|
```bash
|
||||||
```bash
|
./cli-proxy-api --codex-login
|
||||||
./cli-proxy-api --login --project_id <your_project_id>
|
```
|
||||||
```
|
Options: add `--no-browser` to print the login URL instead of opening a browser. The local OAuth callback uses port `1455`.
|
||||||
|
|
||||||
### Starting the Server
|
### Starting the Server
|
||||||
|
|
||||||
@@ -90,6 +102,15 @@ Request body example:
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
- Use a `gemini-*` model for Gemini (e.g., `gemini-2.5-pro`) or a `gpt-*` model for OpenAI (e.g., `gpt-5`). The proxy will route to the correct provider automatically.
|
||||||
|
|
||||||
|
#### Claude Messages (SSE-compatible)
|
||||||
|
|
||||||
|
```
|
||||||
|
POST http://localhost:8317/v1/messages
|
||||||
|
```
|
||||||
|
|
||||||
### Using with OpenAI Libraries
|
### Using with OpenAI Libraries
|
||||||
|
|
||||||
You can use this proxy with any OpenAI-compatible library by setting the base URL to your local server:
|
You can use this proxy with any OpenAI-compatible library by setting the base URL to your local server:
|
||||||
@@ -104,14 +125,19 @@ client = OpenAI(
|
|||||||
base_url="http://localhost:8317/v1"
|
base_url="http://localhost:8317/v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
# Gemini example
|
||||||
|
gemini = client.chat.completions.create(
|
||||||
model="gemini-2.5-pro",
|
model="gemini-2.5-pro",
|
||||||
messages=[
|
messages=[{"role": "user", "content": "Hello, how are you?"}]
|
||||||
{"role": "user", "content": "Hello, how are you?"}
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(response.choices[0].message.content)
|
# Codex/GPT example
|
||||||
|
gpt = client.chat.completions.create(
|
||||||
|
model="gpt-5",
|
||||||
|
messages=[{"role": "user", "content": "Summarize this project in one sentence."}]
|
||||||
|
)
|
||||||
|
print(gemini.choices[0].message.content)
|
||||||
|
print(gpt.choices[0].message.content)
|
||||||
```
|
```
|
||||||
|
|
||||||
#### JavaScript/TypeScript
|
#### JavaScript/TypeScript
|
||||||
@@ -124,28 +150,35 @@ const openai = new OpenAI({
|
|||||||
baseURL: 'http://localhost:8317/v1',
|
baseURL: 'http://localhost:8317/v1',
|
||||||
});
|
});
|
||||||
|
|
||||||
const response = await openai.chat.completions.create({
|
// Gemini
|
||||||
|
const gemini = await openai.chat.completions.create({
|
||||||
model: 'gemini-2.5-pro',
|
model: 'gemini-2.5-pro',
|
||||||
messages: [
|
messages: [{ role: 'user', content: 'Hello, how are you?' }],
|
||||||
{ role: 'user', content: 'Hello, how are you?' }
|
|
||||||
],
|
|
||||||
});
|
});
|
||||||
|
|
||||||
console.log(response.choices[0].message.content);
|
// Codex/GPT
|
||||||
|
const gpt = await openai.chat.completions.create({
|
||||||
|
model: 'gpt-5',
|
||||||
|
messages: [{ role: 'user', content: 'Summarize this project in one sentence.' }],
|
||||||
|
});
|
||||||
|
|
||||||
|
console.log(gemini.choices[0].message.content);
|
||||||
|
console.log(gpt.choices[0].message.content);
|
||||||
```
|
```
|
||||||
|
|
||||||
## Supported Models
|
## Supported Models
|
||||||
|
|
||||||
- gemini-2.5-pro
|
- gemini-2.5-pro
|
||||||
- gemini-2.5-flash
|
- gemini-2.5-flash
|
||||||
- And it automates switching to various preview versions
|
- gpt-5
|
||||||
|
- Gemini models auto‑switch to preview variants when needed
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
The server uses a YAML configuration file (`config.yaml`) located in the project root directory by default. You can specify a different configuration file path using the `--config` flag:
|
The server uses a YAML configuration file (`config.yaml`) located in the project root directory by default. You can specify a different configuration file path using the `--config` flag:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./cli-proxy --config /path/to/your/config.yaml
|
./cli-proxy-api --config /path/to/your/config.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### Configuration Options
|
### Configuration Options
|
||||||
@@ -211,6 +244,10 @@ Authorization: Bearer your-api-key-1
|
|||||||
|
|
||||||
The `generative-language-api-key` parameter allows you to define a list of API keys that can be used to authenticate requests to the official Generative Language API.
|
The `generative-language-api-key` parameter allows you to define a list of API keys that can be used to authenticate requests to the official Generative Language API.
|
||||||
|
|
||||||
|
## Hot Reloading
|
||||||
|
|
||||||
|
The server watches the config file and the `auth-dir` for changes and reloads clients and settings automatically. You can add or remove Gemini/OpenAI token JSON files while the server is running; no restart is required.
|
||||||
|
|
||||||
## Gemini CLI with multiple account load balancing
|
## Gemini CLI with multiple account load balancing
|
||||||
|
|
||||||
Start CLI Proxy API server, and then set the `CODE_ASSIST_ENDPOINT` environment variable to the URL of the CLI Proxy API server.
|
Start CLI Proxy API server, and then set the `CODE_ASSIST_ENDPOINT` environment variable to the URL of the CLI Proxy API server.
|
||||||
@@ -227,12 +264,18 @@ The server will relay the `loadCodeAssist`, `onboardUser`, and `countTokens` req
|
|||||||
|
|
||||||
## Run with Docker
|
## Run with Docker
|
||||||
|
|
||||||
Run the following command to login:
|
Run the following command to login (Gemini OAuth on port 8085):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login
|
docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Run the following command to login (OpenAI OAuth on port 1455):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker run --rm -p 1455:1455 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --codex-login
|
||||||
|
```
|
||||||
|
|
||||||
Run the following command to start the server:
|
Run the following command to start the server:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
102
README_CN.md
102
README_CN.md
@@ -2,16 +2,21 @@
|
|||||||
|
|
||||||
[English](README.md) | 中文
|
[English](README.md) | 中文
|
||||||
|
|
||||||
一个为 CLI 提供 OpenAI/Gemini/Claude 兼容 API 接口的代理服务器。这让您可以摆脱终端界面的束缚,将 Gemini 的强大能力以 API 的形式轻松接入到任何您喜爱的客户端或应用中。
|
一个为 CLI 提供 OpenAI/Gemini/Claude 兼容 API 接口的代理服务器。
|
||||||
|
|
||||||
|
现已支持通过 OAuth 登录接入 OpenAI Codex(GPT 系列)。
|
||||||
|
|
||||||
|
可与本地或多账户方式配合,使用任何 OpenAI 兼容的客户端与 SDK。
|
||||||
|
|
||||||
## 功能特性
|
## 功能特性
|
||||||
|
|
||||||
- 为 CLI 模型提供 OpenAI/Gemini/Claude 兼容的 API 端点
|
- 为 CLI 模型提供 OpenAI/Gemini/Claude 兼容的 API 端点
|
||||||
- 支持流式和非流式响应
|
- 新增 OpenAI Codex(GPT 系列)支持(OAuth 登录)
|
||||||
|
- 支持流式与非流式响应
|
||||||
- 函数调用/工具支持
|
- 函数调用/工具支持
|
||||||
- 多模态输入支持(文本和图像)
|
- 多模态输入(文本、图片)
|
||||||
- 多账户支持与负载均衡
|
- 多账户支持与轮询负载均衡(Gemini 与 OpenAI)
|
||||||
- 简单的 CLI 身份验证流程
|
- 简单的 CLI 身份验证流程(Gemini 与 OpenAI)
|
||||||
- 支持 Gemini AIStudio API 密钥
|
- 支持 Gemini AIStudio API 密钥
|
||||||
- 支持 Gemini CLI 多账户轮询
|
- 支持 Gemini CLI 多账户轮询
|
||||||
|
|
||||||
@@ -20,7 +25,8 @@
|
|||||||
### 前置要求
|
### 前置要求
|
||||||
|
|
||||||
- Go 1.24 或更高版本
|
- Go 1.24 或更高版本
|
||||||
- 有权访问 CLI 模型的 Google 账户
|
- 有权访问 Gemini CLI 模型的 Google 账户(可选)
|
||||||
|
- 有权访问 OpenAI Codex/GPT 的 OpenAI 账户(可选)
|
||||||
|
|
||||||
### 从源码构建
|
### 从源码构建
|
||||||
|
|
||||||
@@ -39,17 +45,23 @@
|
|||||||
|
|
||||||
### 身份验证
|
### 身份验证
|
||||||
|
|
||||||
在使用 API 之前,您需要使用 Google 账户进行身份验证:
|
您可以分别为 Gemini 和 OpenAI 进行身份验证,二者可同时存在于同一个 `auth-dir` 中并参与负载均衡。
|
||||||
|
|
||||||
```bash
|
- Gemini(Google):
|
||||||
./cli-proxy-api --login
|
```bash
|
||||||
```
|
./cli-proxy-api --login
|
||||||
|
```
|
||||||
|
如果您是旧版 gemini code 用户,可能需要指定项目 ID:
|
||||||
|
```bash
|
||||||
|
./cli-proxy-api --login --project_id <your_project_id>
|
||||||
|
```
|
||||||
|
本地 OAuth 回调端口为 `8085`。
|
||||||
|
|
||||||
如果您是旧版 gemini code 用户,可能需要指定项目 ID:
|
- OpenAI(Codex/GPT,OAuth):
|
||||||
|
```bash
|
||||||
```bash
|
./cli-proxy-api --codex-login
|
||||||
./cli-proxy-api --login --project_id <your_project_id>
|
```
|
||||||
```
|
选项:加上 `--no-browser` 可打印登录地址而不自动打开浏览器。本地 OAuth 回调端口为 `1455`。
|
||||||
|
|
||||||
### 启动服务器
|
### 启动服务器
|
||||||
|
|
||||||
@@ -90,6 +102,15 @@ POST http://localhost:8317/v1/chat/completions
|
|||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
说明:
|
||||||
|
- 使用 `gemini-*` 模型(如 `gemini-2.5-pro`)走 Gemini,使用 `gpt-*` 模型(如 `gpt-5`)走 OpenAI,服务会自动路由到对应提供商。
|
||||||
|
|
||||||
|
#### Claude 消息(SSE 兼容)
|
||||||
|
|
||||||
|
```
|
||||||
|
POST http://localhost:8317/v1/messages
|
||||||
|
```
|
||||||
|
|
||||||
### 与 OpenAI 库一起使用
|
### 与 OpenAI 库一起使用
|
||||||
|
|
||||||
您可以通过将基础 URL 设置为本地服务器来将此代理与任何 OpenAI 兼容的库一起使用:
|
您可以通过将基础 URL 设置为本地服务器来将此代理与任何 OpenAI 兼容的库一起使用:
|
||||||
@@ -104,14 +125,20 @@ client = OpenAI(
|
|||||||
base_url="http://localhost:8317/v1"
|
base_url="http://localhost:8317/v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
response = client.chat.completions.create(
|
# Gemini 示例
|
||||||
|
gemini = client.chat.completions.create(
|
||||||
model="gemini-2.5-pro",
|
model="gemini-2.5-pro",
|
||||||
messages=[
|
messages=[{"role": "user", "content": "你好,你好吗?"}]
|
||||||
{"role": "user", "content": "你好,你好吗?"}
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
print(response.choices[0].message.content)
|
# Codex/GPT 示例
|
||||||
|
gpt = client.chat.completions.create(
|
||||||
|
model="gpt-5",
|
||||||
|
messages=[{"role": "user", "content": "用一句话总结这个项目"}]
|
||||||
|
)
|
||||||
|
|
||||||
|
print(gemini.choices[0].message.content)
|
||||||
|
print(gpt.choices[0].message.content)
|
||||||
```
|
```
|
||||||
|
|
||||||
#### JavaScript/TypeScript
|
#### JavaScript/TypeScript
|
||||||
@@ -124,28 +151,35 @@ const openai = new OpenAI({
|
|||||||
baseURL: 'http://localhost:8317/v1',
|
baseURL: 'http://localhost:8317/v1',
|
||||||
});
|
});
|
||||||
|
|
||||||
const response = await openai.chat.completions.create({
|
// Gemini
|
||||||
|
const gemini = await openai.chat.completions.create({
|
||||||
model: 'gemini-2.5-pro',
|
model: 'gemini-2.5-pro',
|
||||||
messages: [
|
messages: [{ role: 'user', content: '你好,你好吗?' }],
|
||||||
{ role: 'user', content: '你好,你好吗?' }
|
|
||||||
],
|
|
||||||
});
|
});
|
||||||
|
|
||||||
console.log(response.choices[0].message.content);
|
// Codex/GPT
|
||||||
|
const gpt = await openai.chat.completions.create({
|
||||||
|
model: 'gpt-5',
|
||||||
|
messages: [{ role: 'user', content: '用一句话总结这个项目' }],
|
||||||
|
});
|
||||||
|
|
||||||
|
console.log(gemini.choices[0].message.content);
|
||||||
|
console.log(gpt.choices[0].message.content);
|
||||||
```
|
```
|
||||||
|
|
||||||
## 支持的模型
|
## 支持的模型
|
||||||
|
|
||||||
- gemini-2.5-pro
|
- gemini-2.5-pro
|
||||||
- gemini-2.5-flash
|
- gemini-2.5-flash
|
||||||
- 并且自动切换到之前的预览版本
|
- gpt-5
|
||||||
|
- Gemini 模型在需要时自动切换到对应的 preview 版本
|
||||||
|
|
||||||
## 配置
|
## 配置
|
||||||
|
|
||||||
服务器默认使用位于项目根目录的 YAML 配置文件(`config.yaml`)。您可以使用 `--config` 标志指定不同的配置文件路径:
|
服务器默认使用位于项目根目录的 YAML 配置文件(`config.yaml`)。您可以使用 `--config` 标志指定不同的配置文件路径:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
./cli-proxy --config /path/to/your/config.yaml
|
./cli-proxy-api --config /path/to/your/config.yaml
|
||||||
```
|
```
|
||||||
|
|
||||||
### 配置选项
|
### 配置选项
|
||||||
@@ -211,6 +245,10 @@ Authorization: Bearer your-api-key-1
|
|||||||
|
|
||||||
`generative-language-api-key` 参数允许您定义可用于验证对官方 AIStudio Gemini API 请求的 API 密钥列表。
|
`generative-language-api-key` 参数允许您定义可用于验证对官方 AIStudio Gemini API 请求的 API 密钥列表。
|
||||||
|
|
||||||
|
## 热更新
|
||||||
|
|
||||||
|
服务会监听配置文件与 `auth-dir` 目录的变化并自动重新加载客户端与配置。您可以在运行中新增/移除 Gemini/OpenAI 的令牌 JSON 文件,无需重启服务。
|
||||||
|
|
||||||
## Gemini CLI 多账户负载均衡
|
## Gemini CLI 多账户负载均衡
|
||||||
|
|
||||||
启动 CLI 代理 API 服务器,然后将 `CODE_ASSIST_ENDPOINT` 环境变量设置为 CLI 代理 API 服务器的 URL。
|
启动 CLI 代理 API 服务器,然后将 `CODE_ASSIST_ENDPOINT` 环境变量设置为 CLI 代理 API 服务器的 URL。
|
||||||
@@ -227,12 +265,18 @@ export CODE_ASSIST_ENDPOINT="http://127.0.0.1:8317"
|
|||||||
|
|
||||||
## 使用 Docker 运行
|
## 使用 Docker 运行
|
||||||
|
|
||||||
运行以下命令进行登录:
|
运行以下命令进行登录(Gemini OAuth,端口 8085):
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login
|
docker run --rm -p 8085:8085 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --login
|
||||||
```
|
```
|
||||||
|
|
||||||
|
运行以下命令进行登录(OpenAI OAuth,端口 1455):
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker run --rm -p 1455:1455 -v /path/to/your/config.yaml:/CLIProxyAPI/config.yaml -v /path/to/your/auth-dir:/root/.cli-proxy-api eceasy/cli-proxy-api:latest /CLIProxyAPI/CLIProxyAPI --codex-login
|
||||||
|
```
|
||||||
|
|
||||||
运行以下命令启动服务器:
|
运行以下命令启动服务器:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
@@ -251,4 +295,4 @@ docker run --rm -p 8317:8317 -v /path/to/your/config.yaml:/CLIProxyAPI/config.ya
|
|||||||
|
|
||||||
## 许可证
|
## 许可证
|
||||||
|
|
||||||
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
此项目根据 MIT 许可证授权 - 有关详细信息,请参阅 [LICENSE](LICENSE) 文件。
|
||||||
|
|||||||
@@ -7,19 +7,23 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/cmd"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/cmd"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// LogFormatter defines a custom log format for logrus.
|
// LogFormatter defines a custom log format for logrus.
|
||||||
|
// This formatter adds timestamp, log level, and source location information
|
||||||
|
// to each log entry for better debugging and monitoring.
|
||||||
type LogFormatter struct {
|
type LogFormatter struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Format renders a single log entry.
|
// Format renders a single log entry with custom formatting.
|
||||||
|
// It includes timestamp, log level, source file and line number, and the log message.
|
||||||
func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
||||||
var b *bytes.Buffer
|
var b *bytes.Buffer
|
||||||
if entry.Buffer != nil {
|
if entry.Buffer != nil {
|
||||||
@@ -38,6 +42,8 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// init initializes the logger configuration.
|
// init initializes the logger configuration.
|
||||||
|
// It sets up the custom log formatter, enables caller reporting,
|
||||||
|
// and configures the log output destination.
|
||||||
func init() {
|
func init() {
|
||||||
// Set logger output to standard output.
|
// Set logger output to standard output.
|
||||||
log.SetOutput(os.Stdout)
|
log.SetOutput(os.Stdout)
|
||||||
@@ -48,14 +54,20 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// main is the entry point of the application.
|
// main is the entry point of the application.
|
||||||
|
// It parses command-line flags, loads configuration, and starts the appropriate
|
||||||
|
// service based on the provided flags (login, codex-login, or server mode).
|
||||||
func main() {
|
func main() {
|
||||||
var login bool
|
var login bool
|
||||||
|
var codexLogin bool
|
||||||
|
var noBrowser bool
|
||||||
var projectID string
|
var projectID string
|
||||||
var configPath string
|
var configPath string
|
||||||
|
|
||||||
// Define command-line flags.
|
// Define command-line flags for different operation modes.
|
||||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||||
flag.StringVar(&projectID, "project_id", "", "Project ID")
|
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||||
|
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||||
|
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||||
flag.StringVar(&configPath, "config", "", "Configure File Path")
|
flag.StringVar(&configPath, "config", "", "Configure File Path")
|
||||||
|
|
||||||
// Parse the command-line flags.
|
// Parse the command-line flags.
|
||||||
@@ -104,10 +116,19 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Either perform login or start the service based on the 'login' flag.
|
// Handle different command modes based on the provided flags.
|
||||||
|
options := &cmd.LoginOptions{
|
||||||
|
NoBrowser: noBrowser,
|
||||||
|
}
|
||||||
|
|
||||||
if login {
|
if login {
|
||||||
cmd.DoLogin(cfg, projectID)
|
// Handle Google/Gemini login
|
||||||
|
cmd.DoLogin(cfg, projectID, options)
|
||||||
|
} else if codexLogin {
|
||||||
|
// Handle Codex login
|
||||||
|
cmd.DoCodexLogin(cfg, options)
|
||||||
} else {
|
} else {
|
||||||
|
// Start the main proxy service
|
||||||
cmd.StartService(cfg, configFilePath)
|
cmd.StartService(cfg, configFilePath)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,15 +1,22 @@
|
|||||||
|
# Server configuration
|
||||||
port: 8317
|
port: 8317
|
||||||
auth-dir: "~/.cli-proxy-api"
|
auth-dir: "~/.cli-proxy-api"
|
||||||
debug: true
|
debug: true
|
||||||
proxy-url: ""
|
proxy-url: ""
|
||||||
|
|
||||||
|
# Quota exceeded behavior
|
||||||
quota-exceeded:
|
quota-exceeded:
|
||||||
switch-project: true
|
switch-project: true
|
||||||
switch-preview-model: true
|
switch-preview-model: true
|
||||||
|
|
||||||
|
# API keys for client authentication
|
||||||
api-keys:
|
api-keys:
|
||||||
- "12345"
|
- "12345"
|
||||||
- "23456"
|
- "23456"
|
||||||
|
|
||||||
|
# Generative language API keys
|
||||||
generative-language-api-key:
|
generative-language-api-key:
|
||||||
- "AIzaSy...01"
|
- "AIzaSy...01"
|
||||||
- "AIzaSy...02"
|
- "AIzaSy...02"
|
||||||
- "AIzaSy...03"
|
- "AIzaSy...03"
|
||||||
- "AIzaSy...04"
|
- "AIzaSy...04"
|
||||||
3
go.mod
3
go.mod
@@ -3,7 +3,9 @@ module github.com/luispater/CLIProxyAPI
|
|||||||
go 1.24
|
go 1.24
|
||||||
|
|
||||||
require (
|
require (
|
||||||
|
github.com/fsnotify/fsnotify v1.9.0
|
||||||
github.com/gin-gonic/gin v1.10.1
|
github.com/gin-gonic/gin v1.10.1
|
||||||
|
github.com/google/uuid v1.6.0
|
||||||
github.com/sirupsen/logrus v1.9.3
|
github.com/sirupsen/logrus v1.9.3
|
||||||
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
|
||||||
github.com/tidwall/gjson v1.18.0
|
github.com/tidwall/gjson v1.18.0
|
||||||
@@ -19,7 +21,6 @@ require (
|
|||||||
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
github.com/bytedance/sonic/loader v0.1.1 // indirect
|
||||||
github.com/cloudwego/base64x v0.1.4 // indirect
|
github.com/cloudwego/base64x v0.1.4 // indirect
|
||||||
github.com/cloudwego/iasm v0.2.0 // indirect
|
github.com/cloudwego/iasm v0.2.0 // indirect
|
||||||
github.com/fsnotify/fsnotify v1.9.0 // indirect
|
|
||||||
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
|
||||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||||
github.com/go-playground/locales v0.14.1 // indirect
|
github.com/go-playground/locales v0.14.1 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -32,6 +32,8 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
|
|||||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
|
|||||||
@@ -1,208 +0,0 @@
|
|||||||
// Package claude provides HTTP handlers for Claude API code-related functionality.
|
|
||||||
// This package implements Claude-compatible streaming chat completions with sophisticated
|
|
||||||
// client rotation and quota management systems to ensure high availability and optimal
|
|
||||||
// resource utilization across multiple backend clients. It handles request translation
|
|
||||||
// between Claude API format and the underlying Gemini backend, providing seamless
|
|
||||||
// API compatibility while maintaining robust error handling and connection management.
|
|
||||||
package claude
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/translator/claude/code"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ClaudeCodeAPIHandlers contains the handlers for Claude API endpoints.
|
|
||||||
// It holds a pool of clients to interact with the backend service.
|
|
||||||
type ClaudeCodeAPIHandlers struct {
|
|
||||||
*handlers.APIHandlers
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewClaudeCodeAPIHandlers creates a new Claude API handlers instance.
|
|
||||||
// It takes an APIHandlers instance as input and returns a ClaudeCodeAPIHandlers.
|
|
||||||
func NewClaudeCodeAPIHandlers(apiHandlers *handlers.APIHandlers) *ClaudeCodeAPIHandlers {
|
|
||||||
return &ClaudeCodeAPIHandlers{
|
|
||||||
APIHandlers: apiHandlers,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClaudeMessages handles Claude-compatible streaming chat completions.
|
|
||||||
// This function implements a sophisticated client rotation and quota management system
|
|
||||||
// to ensure high availability and optimal resource utilization across multiple backend clients.
|
|
||||||
func (h *ClaudeCodeAPIHandlers) ClaudeMessages(c *gin.Context) {
|
|
||||||
// Extract raw JSON data from the incoming request
|
|
||||||
rawJSON, err := c.GetRawData()
|
|
||||||
// If data retrieval fails, return a 400 Bad Request error.
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
|
||||||
Error: handlers.ErrorDetail{
|
|
||||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
|
||||||
Type: "invalid_request_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set up Server-Sent Events (SSE) headers for streaming response
|
|
||||||
// These headers are essential for maintaining a persistent connection
|
|
||||||
// and enabling real-time streaming of chat completions
|
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
|
||||||
|
|
||||||
// Get the http.Flusher interface to manually flush the response.
|
|
||||||
// This is crucial for streaming as it allows immediate sending of data chunks
|
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
|
||||||
Error: handlers.ErrorDetail{
|
|
||||||
Message: "Streaming not supported",
|
|
||||||
Type: "server_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Parse and prepare the Claude request, extracting model name, system instructions,
|
|
||||||
// conversation contents, and available tools from the raw JSON
|
|
||||||
modelName, systemInstruction, contents, tools := code.PrepareClaudeRequest(rawJSON)
|
|
||||||
|
|
||||||
// Map Claude model names to corresponding Gemini models
|
|
||||||
// This allows the proxy to handle Claude API calls using Gemini backend
|
|
||||||
if modelName == "claude-sonnet-4-20250514" {
|
|
||||||
modelName = "gemini-2.5-pro"
|
|
||||||
} else if modelName == "claude-3-5-haiku-20241022" {
|
|
||||||
modelName = "gemini-2.5-flash"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a cancellable context for the backend client request
|
|
||||||
// This allows proper cleanup and cancellation of ongoing requests
|
|
||||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
|
||||||
var cliClient *client.Client
|
|
||||||
defer func() {
|
|
||||||
// Ensure the client's mutex is unlocked on function exit.
|
|
||||||
// This prevents deadlocks and ensures proper resource cleanup
|
|
||||||
if cliClient != nil {
|
|
||||||
cliClient.RequestMutex.Unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Main client rotation loop with quota management
|
|
||||||
// This loop implements a sophisticated load balancing and failover mechanism
|
|
||||||
outLoop:
|
|
||||||
for {
|
|
||||||
var errorResponse *client.ErrorMessage
|
|
||||||
cliClient, errorResponse = h.GetClient(modelName)
|
|
||||||
if errorResponse != nil {
|
|
||||||
c.Status(errorResponse.StatusCode)
|
|
||||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
|
||||||
flusher.Flush()
|
|
||||||
cliCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine the authentication method being used by the selected client
|
|
||||||
// This affects how responses are formatted and logged
|
|
||||||
isGlAPIKey := false
|
|
||||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
|
||||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
|
||||||
isGlAPIKey = true
|
|
||||||
} else {
|
|
||||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
|
||||||
}
|
|
||||||
// Initiate streaming communication with the backend client
|
|
||||||
// This returns two channels: one for response chunks and one for errors
|
|
||||||
|
|
||||||
includeThoughts := false
|
|
||||||
if userAgent, hasKey := c.Request.Header["User-Agent"]; hasKey {
|
|
||||||
includeThoughts = !strings.Contains(userAgent[0], "claude-cli")
|
|
||||||
}
|
|
||||||
|
|
||||||
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools, includeThoughts)
|
|
||||||
|
|
||||||
// Track response state for proper Claude format conversion
|
|
||||||
hasFirstResponse := false
|
|
||||||
responseType := 0
|
|
||||||
responseIndex := 0
|
|
||||||
|
|
||||||
// Main streaming loop - handles multiple concurrent events using Go channels
|
|
||||||
// This select statement manages four different types of events simultaneously
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
// Case 1: Handle client disconnection
|
|
||||||
// Detects when the HTTP client has disconnected and cleans up resources
|
|
||||||
case <-c.Request.Context().Done():
|
|
||||||
if c.Request.Context().Err().Error() == "context canceled" {
|
|
||||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
|
||||||
cliCancel() // Cancel the backend request to prevent resource leaks
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Case 2: Process incoming response chunks from the backend
|
|
||||||
// This handles the actual streaming data from the AI model
|
|
||||||
case chunk, okStream := <-respChan:
|
|
||||||
if !okStream {
|
|
||||||
// Stream has ended - send the final message_stop event
|
|
||||||
// This follows the Claude API specification for stream termination
|
|
||||||
_, _ = c.Writer.Write([]byte(`event: message_stop`))
|
|
||||||
_, _ = c.Writer.Write([]byte("\n"))
|
|
||||||
_, _ = c.Writer.Write([]byte(`data: {"type":"message_stop"}`))
|
|
||||||
_, _ = c.Writer.Write([]byte("\n\n\n"))
|
|
||||||
|
|
||||||
flusher.Flush()
|
|
||||||
cliCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Convert the backend response to Claude-compatible format
|
|
||||||
// This translation layer ensures API compatibility
|
|
||||||
claudeFormat := code.ConvertCliToClaude(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex)
|
|
||||||
if claudeFormat != "" {
|
|
||||||
_, _ = c.Writer.Write([]byte(claudeFormat))
|
|
||||||
flusher.Flush() // Immediately send the chunk to the client
|
|
||||||
}
|
|
||||||
hasFirstResponse = true
|
|
||||||
|
|
||||||
// Case 3: Handle errors from the backend
|
|
||||||
// This manages various error conditions and implements retry logic
|
|
||||||
case errInfo, okError := <-errChan:
|
|
||||||
if okError {
|
|
||||||
// Special handling for quota exceeded errors
|
|
||||||
// If configured, attempt to switch to a different project/client
|
|
||||||
if errInfo.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
|
||||||
continue outLoop // Restart the client selection process
|
|
||||||
} else {
|
|
||||||
// Forward other errors directly to the client
|
|
||||||
c.Status(errInfo.StatusCode)
|
|
||||||
_, _ = fmt.Fprint(c.Writer, errInfo.Error.Error())
|
|
||||||
flusher.Flush()
|
|
||||||
cliCancel()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Case 4: Send periodic keep-alive signals
|
|
||||||
// Prevents connection timeouts during long-running requests
|
|
||||||
case <-time.After(500 * time.Millisecond):
|
|
||||||
if hasFirstResponse {
|
|
||||||
// Send a ping event to maintain the connection
|
|
||||||
// This is especially important for slow AI model responses
|
|
||||||
output := "event: ping\n"
|
|
||||||
output = output + `data: {"type": "ping"}`
|
|
||||||
output = output + "\n\n\n"
|
|
||||||
_, _ = c.Writer.Write([]byte(output))
|
|
||||||
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
382
internal/api/handlers/claude/code_handlers.go
Normal file
382
internal/api/handlers/claude/code_handlers.go
Normal file
@@ -0,0 +1,382 @@
|
|||||||
|
// Package claude provides HTTP handlers for Claude API code-related functionality.
|
||||||
|
// This package implements Claude-compatible streaming chat completions with sophisticated
|
||||||
|
// client rotation and quota management systems to ensure high availability and optimal
|
||||||
|
// resource utilization across multiple backend clients. It handles request translation
|
||||||
|
// between Claude API format and the underlying Gemini backend, providing seamless
|
||||||
|
// API compatibility while maintaining robust error handling and connection management.
|
||||||
|
package claude
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
translatorClaudeCodeToCodex "github.com/luispater/CLIProxyAPI/internal/translator/codex/claude/code"
|
||||||
|
translatorClaudeCodeToGeminiCli "github.com/luispater/CLIProxyAPI/internal/translator/gemini-cli/claude/code"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ClaudeCodeAPIHandlers contains the handlers for Claude API endpoints.
|
||||||
|
// It holds a pool of clients to interact with the backend service.
|
||||||
|
type ClaudeCodeAPIHandlers struct {
	// Embedded shared handler state; provides client selection (GetClient)
	// and server configuration (Cfg) used by the streaming handlers below.
	*handlers.APIHandlers
}
|
||||||
|
|
||||||
|
// NewClaudeCodeAPIHandlers creates a new Claude API handlers instance.
|
||||||
|
// It takes an APIHandlers instance as input and returns a ClaudeCodeAPIHandlers.
|
||||||
|
func NewClaudeCodeAPIHandlers(apiHandlers *handlers.APIHandlers) *ClaudeCodeAPIHandlers {
|
||||||
|
return &ClaudeCodeAPIHandlers{
|
||||||
|
APIHandlers: apiHandlers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClaudeMessages handles Claude-compatible streaming chat completions.
|
||||||
|
// This function implements a sophisticated client rotation and quota management system
|
||||||
|
// to ensure high availability and optimal resource utilization across multiple backend clients.
|
||||||
|
func (h *ClaudeCodeAPIHandlers) ClaudeMessages(c *gin.Context) {
|
||||||
|
// Extract raw JSON data from the incoming request
|
||||||
|
rawJSON, err := c.GetRawData()
|
||||||
|
// If data retrieval fails, return a 400 Bad Request error.
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// h.handleGeminiStreamingResponse(c, rawJSON)
|
||||||
|
// h.handleCodexStreamingResponse(c, rawJSON)
|
||||||
|
modelName := gjson.GetBytes(rawJSON, "model")
|
||||||
|
provider := util.GetProviderName(modelName.String())
|
||||||
|
if provider == "gemini" {
|
||||||
|
h.handleGeminiStreamingResponse(c, rawJSON)
|
||||||
|
} else if provider == "gpt" {
|
||||||
|
h.handleCodexStreamingResponse(c, rawJSON)
|
||||||
|
} else {
|
||||||
|
h.handleGeminiStreamingResponse(c, rawJSON)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleGeminiStreamingResponse streams Claude-compatible responses backed by Gemini.
|
||||||
|
// It sets up SSE, selects a backend client with rotation/quota logic,
|
||||||
|
// forwards chunks, and translates them to Claude CLI format.
|
||||||
|
func (h *ClaudeCodeAPIHandlers) handleGeminiStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||||
|
// Set up Server-Sent Events (SSE) headers for streaming response
|
||||||
|
// These headers are essential for maintaining a persistent connection
|
||||||
|
// and enabling real-time streaming of chat completions
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
|
||||||
|
// Get the http.Flusher interface to manually flush the response.
|
||||||
|
// This is crucial for streaming as it allows immediate sending of data chunks
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: "Streaming not supported",
|
||||||
|
Type: "server_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse and prepare the Claude request, extracting model name, system instructions,
|
||||||
|
// conversation contents, and available tools from the raw JSON
|
||||||
|
modelName, systemInstruction, contents, tools := translatorClaudeCodeToGeminiCli.ConvertClaudeCodeRequestToCli(rawJSON)
|
||||||
|
|
||||||
|
// Map Claude model names to corresponding Gemini models
|
||||||
|
// This allows the proxy to handle Claude API calls using Gemini backend
|
||||||
|
if modelName == "claude-sonnet-4-20250514" {
|
||||||
|
modelName = "gemini-2.5-pro"
|
||||||
|
} else if modelName == "claude-3-5-haiku-20241022" {
|
||||||
|
modelName = "gemini-2.5-flash"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a cancellable context for the backend client request
|
||||||
|
// This allows proper cleanup and cancellation of ongoing requests
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient client.Client
|
||||||
|
cliClient = client.NewGeminiClient(nil, nil, nil)
|
||||||
|
defer func() {
|
||||||
|
// Ensure the client's mutex is unlocked on function exit.
|
||||||
|
// This prevents deadlocks and ensures proper resource cleanup
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.GetRequestMutex().Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Main client rotation loop with quota management
|
||||||
|
// This loop implements a sophisticated load balancing and failover mechanism
|
||||||
|
outLoop:
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.GetClient(modelName)
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine the authentication method being used by the selected client
|
||||||
|
// This affects how responses are formatted and logged
|
||||||
|
isGlAPIKey := false
|
||||||
|
if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||||
|
log.Debugf("Request use gemini generative language API Key: %s", glAPIKey)
|
||||||
|
isGlAPIKey = true
|
||||||
|
} else {
|
||||||
|
log.Debugf("Request use gemini account: %s, project id: %s", cliClient.GetEmail(), cliClient.(*client.GeminiClient).GetProjectID())
|
||||||
|
}
|
||||||
|
// Initiate streaming communication with the backend client
|
||||||
|
// This returns two channels: one for response chunks and one for errors
|
||||||
|
|
||||||
|
includeThoughts := false
|
||||||
|
if userAgent, hasKey := c.Request.Header["User-Agent"]; hasKey {
|
||||||
|
includeThoughts = !strings.Contains(userAgent[0], "claude-cli")
|
||||||
|
}
|
||||||
|
|
||||||
|
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools, includeThoughts)
|
||||||
|
|
||||||
|
// Track response state for proper Claude format conversion
|
||||||
|
hasFirstResponse := false
|
||||||
|
responseType := 0
|
||||||
|
responseIndex := 0
|
||||||
|
|
||||||
|
// Main streaming loop - handles multiple concurrent events using Go channels
|
||||||
|
// This select statement manages four different types of events simultaneously
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// Case 1: Handle client disconnection
|
||||||
|
// Detects when the HTTP client has disconnected and cleans up resources
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
if c.Request.Context().Err().Error() == "context canceled" {
|
||||||
|
log.Debugf("GeminiClient disconnected: %v", c.Request.Context().Err())
|
||||||
|
cliCancel() // Cancel the backend request to prevent resource leaks
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 2: Process incoming response chunks from the backend
|
||||||
|
// This handles the actual streaming data from the AI model
|
||||||
|
case chunk, okStream := <-respChan:
|
||||||
|
if !okStream {
|
||||||
|
// Stream has ended - send the final message_stop event
|
||||||
|
// This follows the Claude API specification for stream termination
|
||||||
|
_, _ = c.Writer.Write([]byte(`event: message_stop`))
|
||||||
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
_, _ = c.Writer.Write([]byte(`data: {"type":"message_stop"}`))
|
||||||
|
_, _ = c.Writer.Write([]byte("\n\n\n"))
|
||||||
|
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Convert the backend response to Claude-compatible format
|
||||||
|
// This translation layer ensures API compatibility
|
||||||
|
claudeFormat := translatorClaudeCodeToGeminiCli.ConvertCliResponseToClaudeCode(chunk, isGlAPIKey, hasFirstResponse, &responseType, &responseIndex)
|
||||||
|
if claudeFormat != "" {
|
||||||
|
_, _ = c.Writer.Write([]byte(claudeFormat))
|
||||||
|
flusher.Flush() // Immediately send the chunk to the client
|
||||||
|
}
|
||||||
|
hasFirstResponse = true
|
||||||
|
|
||||||
|
// Case 3: Handle errors from the backend
|
||||||
|
// This manages various error conditions and implements retry logic
|
||||||
|
case errInfo, okError := <-errChan:
|
||||||
|
if okError {
|
||||||
|
// Special handling for quota exceeded errors
|
||||||
|
// If configured, attempt to switch to a different project/client
|
||||||
|
if errInfo.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue outLoop // Restart the client selection process
|
||||||
|
} else {
|
||||||
|
// Forward other errors directly to the client
|
||||||
|
c.Status(errInfo.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errInfo.Error.Error())
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 4: Send periodic keep-alive signals
|
||||||
|
// Prevents connection timeouts during long-running requests
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
if hasFirstResponse {
|
||||||
|
// Send a ping event to maintain the connection
|
||||||
|
// This is especially important for slow AI model responses
|
||||||
|
output := "event: ping\n"
|
||||||
|
output = output + `data: {"type": "ping"}`
|
||||||
|
output = output + "\n\n\n"
|
||||||
|
_, _ = c.Writer.Write([]byte(output))
|
||||||
|
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleCodexStreamingResponse streams Claude-compatible responses backed by OpenAI.
|
||||||
|
// It converts the Claude request into Codex/OpenAI responses format, establishes SSE,
|
||||||
|
// and translates streaming chunks back into Claude CLI events.
|
||||||
|
func (h *ClaudeCodeAPIHandlers) handleCodexStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||||
|
// Set up Server-Sent Events (SSE) headers for streaming response
|
||||||
|
// These headers are essential for maintaining a persistent connection
|
||||||
|
// and enabling real-time streaming of chat completions
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
|
||||||
|
// Get the http.Flusher interface to manually flush the response.
|
||||||
|
// This is crucial for streaming as it allows immediate sending of data chunks
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: "Streaming not supported",
|
||||||
|
Type: "server_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse and prepare the Claude request, extracting model name, system instructions,
|
||||||
|
// conversation contents, and available tools from the raw JSON
|
||||||
|
newRequestJSON := translatorClaudeCodeToCodex.ConvertClaudeCodeRequestToCodex(rawJSON)
|
||||||
|
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||||
|
// Map Claude model names to corresponding Gemini models
|
||||||
|
// This allows the proxy to handle Claude API calls using Gemini backend
|
||||||
|
if modelName == "claude-sonnet-4-20250514" {
|
||||||
|
modelName = "gpt-5"
|
||||||
|
} else if modelName == "claude-3-5-haiku-20241022" {
|
||||||
|
modelName = "gpt-5"
|
||||||
|
}
|
||||||
|
newRequestJSON, _ = sjson.Set(newRequestJSON, "model", modelName)
|
||||||
|
// log.Debugf(string(rawJSON))
|
||||||
|
// log.Debugf(newRequestJSON)
|
||||||
|
// return
|
||||||
|
// Create a cancellable context for the backend client request
|
||||||
|
// This allows proper cleanup and cancellation of ongoing requests
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient client.Client
|
||||||
|
defer func() {
|
||||||
|
// Ensure the client's mutex is unlocked on function exit.
|
||||||
|
// This prevents deadlocks and ensures proper resource cleanup
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.GetRequestMutex().Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Main client rotation loop with quota management
|
||||||
|
// This loop implements a sophisticated load balancing and failover mechanism
|
||||||
|
outLoop:
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.GetClient(modelName)
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Request use codex account: %s", cliClient.GetEmail())
|
||||||
|
|
||||||
|
// Initiate streaming communication with the backend client
|
||||||
|
// This returns two channels: one for response chunks and one for errors
|
||||||
|
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "")
|
||||||
|
|
||||||
|
// Track response state for proper Claude format conversion
|
||||||
|
hasFirstResponse := false
|
||||||
|
hasToolCall := false
|
||||||
|
|
||||||
|
// Main streaming loop - handles multiple concurrent events using Go channels
|
||||||
|
// This select statement manages four different types of events simultaneously
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// Case 1: Handle client disconnection
|
||||||
|
// Detects when the HTTP client has disconnected and cleans up resources
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
if c.Request.Context().Err().Error() == "context canceled" {
|
||||||
|
log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err())
|
||||||
|
cliCancel() // Cancel the backend request to prevent resource leaks
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 2: Process incoming response chunks from the backend
|
||||||
|
// This handles the actual streaming data from the AI model
|
||||||
|
case chunk, okStream := <-respChan:
|
||||||
|
if !okStream {
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Convert the backend response to Claude-compatible format
|
||||||
|
// This translation layer ensures API compatibility
|
||||||
|
if bytes.HasPrefix(chunk, []byte("data: ")) {
|
||||||
|
jsonData := chunk[6:]
|
||||||
|
var claudeFormat string
|
||||||
|
claudeFormat, hasToolCall = translatorClaudeCodeToCodex.ConvertCodexResponseToClaude(jsonData, hasToolCall)
|
||||||
|
// log.Debugf("claudeFormat: %s", claudeFormat)
|
||||||
|
if claudeFormat != "" {
|
||||||
|
_, _ = c.Writer.Write([]byte(claudeFormat))
|
||||||
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
}
|
||||||
|
flusher.Flush() // Immediately send the chunk to the client
|
||||||
|
hasFirstResponse = true
|
||||||
|
} else {
|
||||||
|
// log.Debugf("chunk: %s", string(chunk))
|
||||||
|
}
|
||||||
|
// Case 3: Handle errors from the backend
|
||||||
|
// This manages various error conditions and implements retry logic
|
||||||
|
case errInfo, okError := <-errChan:
|
||||||
|
if okError {
|
||||||
|
// log.Debugf("Code: %d, Error: %v", errInfo.StatusCode, errInfo.Error)
|
||||||
|
// Special handling for quota exceeded errors
|
||||||
|
// If configured, attempt to switch to a different project/client
|
||||||
|
if errInfo.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||||
|
log.Debugf("quota exceeded, switch client")
|
||||||
|
continue outLoop // Restart the client selection process
|
||||||
|
} else {
|
||||||
|
// Forward other errors directly to the client
|
||||||
|
c.Status(errInfo.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errInfo.Error.Error())
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Case 4: Send periodic keep-alive signals
|
||||||
|
// Prevents connection timeouts during long-running requests
|
||||||
|
case <-time.After(3000 * time.Millisecond):
|
||||||
|
if hasFirstResponse {
|
||||||
|
// Send a ping event to maintain the connection
|
||||||
|
// This is especially important for slow AI model responses
|
||||||
|
output := "event: ping\n"
|
||||||
|
output = output + `data: {"type": "ping"}`
|
||||||
|
output = output + "\n\n"
|
||||||
|
_, _ = c.Writer.Write([]byte(output))
|
||||||
|
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,268 +0,0 @@
|
|||||||
// Package cli provides HTTP handlers for Gemini CLI API functionality.
|
|
||||||
// This package implements handlers that process CLI-specific requests for Gemini API operations,
|
|
||||||
// including content generation and streaming content generation endpoints.
|
|
||||||
// The handlers restrict access to localhost only and manage communication with the backend service.
|
|
||||||
package cli
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
"github.com/tidwall/sjson"
|
|
||||||
"io"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GeminiCLIAPIHandlers contains the handlers for Gemini CLI API endpoints.
|
|
||||||
// It holds a pool of clients to interact with the backend service.
|
|
||||||
type GeminiCLIAPIHandlers struct {
|
|
||||||
*handlers.APIHandlers
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewGeminiCLIAPIHandlers creates a new Gemini CLI API handlers instance.
|
|
||||||
// It takes an APIHandlers instance as input and returns a GeminiCLIAPIHandlers.
|
|
||||||
func NewGeminiCLIAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiCLIAPIHandlers {
|
|
||||||
return &GeminiCLIAPIHandlers{
|
|
||||||
APIHandlers: apiHandlers,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// CLIHandler handles CLI-specific requests for Gemini API operations.
|
|
||||||
// It restricts access to localhost only and routes requests to appropriate internal handlers.
|
|
||||||
func (h *GeminiCLIAPIHandlers) CLIHandler(c *gin.Context) {
|
|
||||||
if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") {
|
|
||||||
c.JSON(http.StatusForbidden, handlers.ErrorResponse{
|
|
||||||
Error: handlers.ErrorDetail{
|
|
||||||
Message: "CLI reply only allow local access",
|
|
||||||
Type: "forbidden",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
rawJSON, _ := c.GetRawData()
|
|
||||||
requestRawURI := c.Request.URL.Path
|
|
||||||
if requestRawURI == "/v1internal:generateContent" {
|
|
||||||
h.internalGenerateContent(c, rawJSON)
|
|
||||||
} else if requestRawURI == "/v1internal:streamGenerateContent" {
|
|
||||||
h.internalStreamGenerateContent(c, rawJSON)
|
|
||||||
} else {
|
|
||||||
reqBody := bytes.NewBuffer(rawJSON)
|
|
||||||
req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
|
||||||
Error: handlers.ErrorDetail{
|
|
||||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
|
||||||
Type: "invalid_request_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
for key, value := range c.Request.Header {
|
|
||||||
req.Header[key] = value
|
|
||||||
}
|
|
||||||
|
|
||||||
httpClient, err := util.SetProxy(h.Cfg, &http.Client{})
|
|
||||||
if err != nil {
|
|
||||||
log.Fatalf("set proxy failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := httpClient.Do(req)
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
|
||||||
Error: handlers.ErrorDetail{
|
|
||||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
|
||||||
Type: "invalid_request_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
||||||
defer func() {
|
|
||||||
if err = resp.Body.Close(); err != nil {
|
|
||||||
log.Printf("warn: failed to close response body: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
||||||
|
|
||||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
|
||||||
Error: handlers.ErrorDetail{
|
|
||||||
Message: string(bodyBytes),
|
|
||||||
Type: "invalid_request_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
_ = resp.Body.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
for key, value := range resp.Header {
|
|
||||||
c.Header(key, value[0])
|
|
||||||
}
|
|
||||||
output, err := io.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to read response body: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, _ = c.Writer.Write(output)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *GeminiCLIAPIHandlers) internalStreamGenerateContent(c *gin.Context, rawJSON []byte) {
|
|
||||||
alt := h.GetAlt(c)
|
|
||||||
|
|
||||||
if alt == "" {
|
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the http.Flusher interface to manually flush the response.
|
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
|
||||||
Error: handlers.ErrorDetail{
|
|
||||||
Message: "Streaming not supported",
|
|
||||||
Type: "server_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
|
||||||
modelName := modelResult.String()
|
|
||||||
|
|
||||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
|
||||||
var cliClient *client.Client
|
|
||||||
defer func() {
|
|
||||||
// Ensure the client's mutex is unlocked on function exit.
|
|
||||||
if cliClient != nil {
|
|
||||||
cliClient.RequestMutex.Unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
outLoop:
|
|
||||||
for {
|
|
||||||
var errorResponse *client.ErrorMessage
|
|
||||||
cliClient, errorResponse = h.GetClient(modelName)
|
|
||||||
if errorResponse != nil {
|
|
||||||
c.Status(errorResponse.StatusCode)
|
|
||||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
|
||||||
flusher.Flush()
|
|
||||||
cliCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
|
||||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
|
||||||
} else {
|
|
||||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
|
||||||
}
|
|
||||||
// Send the message and receive response chunks and errors via channels.
|
|
||||||
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, "")
|
|
||||||
hasFirstResponse := false
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
// Handle client disconnection.
|
|
||||||
case <-c.Request.Context().Done():
|
|
||||||
if c.Request.Context().Err().Error() == "context canceled" {
|
|
||||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
|
||||||
cliCancel() // Cancel the backend request.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Process incoming response chunks.
|
|
||||||
case chunk, okStream := <-respChan:
|
|
||||||
if !okStream {
|
|
||||||
cliCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
hasFirstResponse = true
|
|
||||||
if cliClient.GetGenerativeLanguageAPIKey() != "" {
|
|
||||||
chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
|
|
||||||
}
|
|
||||||
_, _ = c.Writer.Write([]byte("data: "))
|
|
||||||
_, _ = c.Writer.Write(chunk)
|
|
||||||
_, _ = c.Writer.Write([]byte("\n\n"))
|
|
||||||
flusher.Flush()
|
|
||||||
// Handle errors from the backend.
|
|
||||||
case err, okError := <-errChan:
|
|
||||||
if okError {
|
|
||||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
|
||||||
continue outLoop
|
|
||||||
} else {
|
|
||||||
c.Status(err.StatusCode)
|
|
||||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
|
||||||
flusher.Flush()
|
|
||||||
cliCancel()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Send a keep-alive signal to the client.
|
|
||||||
case <-time.After(500 * time.Millisecond):
|
|
||||||
if hasFirstResponse {
|
|
||||||
_, _ = c.Writer.Write([]byte("\n"))
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *GeminiCLIAPIHandlers) internalGenerateContent(c *gin.Context, rawJSON []byte) {
|
|
||||||
c.Header("Content-Type", "application/json")
|
|
||||||
|
|
||||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
|
||||||
modelName := modelResult.String()
|
|
||||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
|
||||||
var cliClient *client.Client
|
|
||||||
defer func() {
|
|
||||||
if cliClient != nil {
|
|
||||||
cliClient.RequestMutex.Unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
var errorResponse *client.ErrorMessage
|
|
||||||
cliClient, errorResponse = h.GetClient(modelName)
|
|
||||||
if errorResponse != nil {
|
|
||||||
c.Status(errorResponse.StatusCode)
|
|
||||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
|
||||||
cliCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
|
||||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
|
||||||
} else {
|
|
||||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, "")
|
|
||||||
if err != nil {
|
|
||||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
c.Status(err.StatusCode)
|
|
||||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
|
||||||
cliCancel()
|
|
||||||
}
|
|
||||||
break
|
|
||||||
} else {
|
|
||||||
_, _ = c.Writer.Write(resp)
|
|
||||||
cliCancel()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
491
internal/api/handlers/gemini/cli/cli_handlers.go
Normal file
491
internal/api/handlers/gemini/cli/cli_handlers.go
Normal file
@@ -0,0 +1,491 @@
|
|||||||
|
// Package cli provides HTTP handlers for Gemini CLI API functionality.
|
||||||
|
// This package implements handlers that process CLI-specific requests for Gemini API operations,
|
||||||
|
// including content generation and streaming content generation endpoints.
|
||||||
|
// The handlers restrict access to localhost only and manage communication with the backend service.
|
||||||
|
package cli
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
translatorGeminiToCodex "github.com/luispater/CLIProxyAPI/internal/translator/codex/gemini"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GeminiCLIAPIHandlers contains the handlers for Gemini CLI API endpoints.
|
||||||
|
// It holds a pool of clients to interact with the backend service.
|
||||||
|
type GeminiCLIAPIHandlers struct {
|
||||||
|
*handlers.APIHandlers
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGeminiCLIAPIHandlers creates a new Gemini CLI API handlers instance.
|
||||||
|
// It takes an APIHandlers instance as input and returns a GeminiCLIAPIHandlers.
|
||||||
|
func NewGeminiCLIAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiCLIAPIHandlers {
|
||||||
|
return &GeminiCLIAPIHandlers{
|
||||||
|
APIHandlers: apiHandlers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CLIHandler handles CLI-specific requests for Gemini API operations.
|
||||||
|
// It restricts access to localhost only and routes requests to appropriate internal handlers.
|
||||||
|
func (h *GeminiCLIAPIHandlers) CLIHandler(c *gin.Context) {
|
||||||
|
if !strings.HasPrefix(c.Request.RemoteAddr, "127.0.0.1:") {
|
||||||
|
c.JSON(http.StatusForbidden, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: "CLI reply only allow local access",
|
||||||
|
Type: "forbidden",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawJSON, _ := c.GetRawData()
|
||||||
|
requestRawURI := c.Request.URL.Path
|
||||||
|
|
||||||
|
modelName := gjson.GetBytes(rawJSON, "model")
|
||||||
|
provider := util.GetProviderName(modelName.String())
|
||||||
|
|
||||||
|
if requestRawURI == "/v1internal:generateContent" {
|
||||||
|
if provider == "gemini" || provider == "unknow" {
|
||||||
|
h.handleInternalGenerateContent(c, rawJSON)
|
||||||
|
} else if provider == "gpt" {
|
||||||
|
h.handleCodexInternalGenerateContent(c, rawJSON)
|
||||||
|
}
|
||||||
|
} else if requestRawURI == "/v1internal:streamGenerateContent" {
|
||||||
|
if provider == "gemini" || provider == "unknow" {
|
||||||
|
h.handleInternalStreamGenerateContent(c, rawJSON)
|
||||||
|
} else if provider == "gpt" {
|
||||||
|
h.handleCodexInternalStreamGenerateContent(c, rawJSON)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
reqBody := bytes.NewBuffer(rawJSON)
|
||||||
|
req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
for key, value := range c.Request.Header {
|
||||||
|
req.Header[key] = value
|
||||||
|
}
|
||||||
|
|
||||||
|
httpClient := util.SetProxy(h.Cfg, &http.Client{})
|
||||||
|
|
||||||
|
resp, err := httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
defer func() {
|
||||||
|
if err = resp.Body.Close(); err != nil {
|
||||||
|
log.Printf("warn: failed to close response body: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: string(bodyBytes),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
for key, value := range resp.Header {
|
||||||
|
c.Header(key, value[0])
|
||||||
|
}
|
||||||
|
output, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to read response body: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
_, _ = c.Writer.Write(output)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *GeminiCLIAPIHandlers) handleInternalStreamGenerateContent(c *gin.Context, rawJSON []byte) {
|
||||||
|
alt := h.GetAlt(c)
|
||||||
|
|
||||||
|
if alt == "" {
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get the http.Flusher interface to manually flush the response.
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: "Streaming not supported",
|
||||||
|
Type: "server_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||||
|
modelName := modelResult.String()
|
||||||
|
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient client.Client
|
||||||
|
defer func() {
|
||||||
|
// Ensure the client's mutex is unlocked on function exit.
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.GetRequestMutex().Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
outLoop:
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.GetClient(modelName)
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||||
|
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Request cli use account: %s, project id: %s", cliClient.(*client.GeminiClient).GetEmail(), cliClient.(*client.GeminiClient).GetProjectID())
|
||||||
|
}
|
||||||
|
// Send the message and receive response chunks and errors via channels.
|
||||||
|
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, "")
|
||||||
|
hasFirstResponse := false
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// Handle client disconnection.
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
if c.Request.Context().Err().Error() == "context canceled" {
|
||||||
|
log.Debugf("GeminiClient disconnected: %v", c.Request.Context().Err())
|
||||||
|
cliCancel() // Cancel the backend request.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Process incoming response chunks.
|
||||||
|
case chunk, okStream := <-respChan:
|
||||||
|
if !okStream {
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
hasFirstResponse = true
|
||||||
|
if cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey() != "" {
|
||||||
|
chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
|
||||||
|
}
|
||||||
|
_, _ = c.Writer.Write([]byte("data: "))
|
||||||
|
_, _ = c.Writer.Write(chunk)
|
||||||
|
_, _ = c.Writer.Write([]byte("\n\n"))
|
||||||
|
|
||||||
|
flusher.Flush()
|
||||||
|
// Handle errors from the backend.
|
||||||
|
case err, okError := <-errChan:
|
||||||
|
if okError {
|
||||||
|
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue outLoop
|
||||||
|
} else {
|
||||||
|
c.Status(err.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Send a keep-alive signal to the client.
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
if hasFirstResponse {
|
||||||
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *GeminiCLIAPIHandlers) handleInternalGenerateContent(c *gin.Context, rawJSON []byte) {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
// log.Debugf("GenerateContent: %s", string(rawJSON))
|
||||||
|
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||||
|
modelName := modelResult.String()
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient client.Client
|
||||||
|
defer func() {
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.GetRequestMutex().Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.GetClient(modelName)
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||||
|
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||||
|
} else {
|
||||||
|
log.Debugf("Request cli use account: %s, project id: %s", cliClient.(*client.GeminiClient).GetEmail(), cliClient.(*client.GeminiClient).GetProjectID())
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, "")
|
||||||
|
if err != nil {
|
||||||
|
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
c.Status(err.StatusCode)
|
||||||
|
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||||
|
log.Debugf("code: %d, error: %s", err.StatusCode, err.Error.Error())
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
_, _ = c.Writer.Write(resp)
|
||||||
|
cliCancel()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *GeminiCLIAPIHandlers) handleCodexInternalStreamGenerateContent(c *gin.Context, rawJSON []byte) {
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
|
||||||
|
// Get the http.Flusher interface to manually flush the response.
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: "Streaming not supported",
|
||||||
|
Type: "server_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||||
|
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||||
|
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String())
|
||||||
|
rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(gjson.GetBytes(rawJSON, "systemInstruction").Raw))
|
||||||
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction")
|
||||||
|
|
||||||
|
// log.Debugf("Request: %s", string(rawJSON))
|
||||||
|
// return
|
||||||
|
|
||||||
|
// Prepare the request for the backend client.
|
||||||
|
newRequestJSON := translatorGeminiToCodex.ConvertGeminiRequestToCodex(rawJSON)
|
||||||
|
// log.Debugf("Request: %s", newRequestJSON)
|
||||||
|
|
||||||
|
modelName := gjson.GetBytes(rawJSON, "model")
|
||||||
|
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient client.Client
|
||||||
|
defer func() {
|
||||||
|
// Ensure the client's mutex is unlocked on function exit.
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.GetRequestMutex().Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
outLoop:
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.GetClient(modelName.String())
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Request codex use account: %s", cliClient.GetEmail())
|
||||||
|
|
||||||
|
// Send the message and receive response chunks and errors via channels.
|
||||||
|
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "")
|
||||||
|
|
||||||
|
params := &translatorGeminiToCodex.ConvertCodexResponseToGeminiParams{
|
||||||
|
Model: modelName.String(),
|
||||||
|
CreatedAt: 0,
|
||||||
|
ResponseID: "",
|
||||||
|
LastStorageOutput: "",
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// Handle client disconnection.
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
if c.Request.Context().Err().Error() == "context canceled" {
|
||||||
|
log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err())
|
||||||
|
cliCancel() // Cancel the backend request.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Process incoming response chunks.
|
||||||
|
case chunk, okStream := <-respChan:
|
||||||
|
if !okStream {
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// _, _ = logFile.Write(chunk)
|
||||||
|
// _, _ = logFile.Write([]byte("\n"))
|
||||||
|
|
||||||
|
if bytes.HasPrefix(chunk, []byte("data: ")) {
|
||||||
|
jsonData := chunk[6:]
|
||||||
|
data := gjson.ParseBytes(jsonData)
|
||||||
|
typeResult := data.Get("type")
|
||||||
|
if typeResult.String() != "" {
|
||||||
|
outputs := translatorGeminiToCodex.ConvertCodexResponseToGemini(jsonData, params)
|
||||||
|
if len(outputs) > 0 {
|
||||||
|
for i := 0; i < len(outputs); i++ {
|
||||||
|
outputs[i], _ = sjson.SetRaw("{}", "response", outputs[i])
|
||||||
|
_, _ = c.Writer.Write([]byte("data: "))
|
||||||
|
_, _ = c.Writer.Write([]byte(outputs[i]))
|
||||||
|
_, _ = c.Writer.Write([]byte("\n\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
// Handle errors from the backend.
|
||||||
|
case errMessage, okError := <-errChan:
|
||||||
|
if okError {
|
||||||
|
if errMessage.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue outLoop
|
||||||
|
} else {
|
||||||
|
log.Debugf("code: %d, error: %s", errMessage.StatusCode, errMessage.Error.Error())
|
||||||
|
c.Status(errMessage.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errMessage.Error.Error())
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Send a keep-alive signal to the client.
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *GeminiCLIAPIHandlers) handleCodexInternalGenerateContent(c *gin.Context, rawJSON []byte) {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
orgRawJSON := rawJSON
|
||||||
|
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||||
|
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||||
|
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String())
|
||||||
|
rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(gjson.GetBytes(rawJSON, "systemInstruction").Raw))
|
||||||
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction")
|
||||||
|
|
||||||
|
// Prepare the request for the backend client.
|
||||||
|
newRequestJSON := translatorGeminiToCodex.ConvertGeminiRequestToCodex(rawJSON)
|
||||||
|
// log.Debugf("Request: %s", newRequestJSON)
|
||||||
|
|
||||||
|
modelName := gjson.GetBytes(rawJSON, "model")
|
||||||
|
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient client.Client
|
||||||
|
defer func() {
|
||||||
|
// Ensure the client's mutex is unlocked on function exit.
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.GetRequestMutex().Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
outLoop:
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.GetClient(modelName.String())
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Request codex use account: %s", cliClient.GetEmail())
|
||||||
|
|
||||||
|
// Send the message and receive response chunks and errors via channels.
|
||||||
|
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "")
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// Handle client disconnection.
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
if c.Request.Context().Err().Error() == "context canceled" {
|
||||||
|
log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err())
|
||||||
|
cliCancel() // Cancel the backend request.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Process incoming response chunks.
|
||||||
|
case chunk, okStream := <-respChan:
|
||||||
|
if !okStream {
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.HasPrefix(chunk, []byte("data: ")) {
|
||||||
|
jsonData := chunk[6:]
|
||||||
|
data := gjson.ParseBytes(jsonData)
|
||||||
|
typeResult := data.Get("type")
|
||||||
|
if typeResult.String() != "" {
|
||||||
|
var geminiStr string
|
||||||
|
geminiStr = translatorGeminiToCodex.ConvertCodexResponseToGeminiNonStream(jsonData, modelName.String())
|
||||||
|
if geminiStr != "" {
|
||||||
|
_, _ = c.Writer.Write([]byte(geminiStr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Handle errors from the backend.
|
||||||
|
case err, okError := <-errChan:
|
||||||
|
if okError {
|
||||||
|
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue outLoop
|
||||||
|
} else {
|
||||||
|
c.Status(err.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||||
|
log.Debugf("org: %s", string(orgRawJSON))
|
||||||
|
log.Debugf("raw: %s", string(rawJSON))
|
||||||
|
log.Debugf("newRequestJSON: %s", newRequestJSON)
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Send a keep-alive signal to the client.
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,437 +0,0 @@
|
|||||||
// Package gemini provides HTTP handlers for Gemini API endpoints.
|
|
||||||
// This package implements handlers for managing Gemini model operations including
|
|
||||||
// model listing, content generation, streaming content generation, and token counting.
|
|
||||||
// It serves as a proxy layer between clients and the Gemini backend service,
|
|
||||||
// handling request translation, client management, and response processing.
|
|
||||||
package gemini
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/translator/gemini/cli"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
"github.com/tidwall/sjson"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// GeminiAPIHandlers contains the handlers for Gemini API endpoints.
|
|
||||||
// It holds a pool of clients to interact with the backend service.
|
|
||||||
type GeminiAPIHandlers struct {
|
|
||||||
*handlers.APIHandlers
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewGeminiAPIHandlers creates a new Gemini API handlers instance.
|
|
||||||
// It takes an APIHandlers instance as input and returns a GeminiAPIHandlers.
|
|
||||||
func NewGeminiAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiAPIHandlers {
|
|
||||||
return &GeminiAPIHandlers{
|
|
||||||
APIHandlers: apiHandlers,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GeminiModels handles the Gemini models listing endpoint.
|
|
||||||
// It returns a JSON response containing available Gemini models and their specifications.
|
|
||||||
func (h *GeminiAPIHandlers) GeminiModels(c *gin.Context) {
|
|
||||||
c.Status(http.StatusOK)
|
|
||||||
c.Header("Content-Type", "application/json; charset=UTF-8")
|
|
||||||
_, _ = c.Writer.Write([]byte(`{"models":[{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini `))
|
|
||||||
_, _ = c.Writer.Write([]byte(`2.5 Flash","description":"Stable version of Gemini 2.5 Flash, our mid-size multimod`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`al model that supports up to 1 million tokens, released in June of 2025.","inputTok`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`enLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateCo`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`ntent","countTokens","createCachedContent","batchGenerateContent"],"temperature":1,`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true},{"name":"models/gemini-2.`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`5-pro","version":"2.5","displayName":"Gemini 2.5 Pro","description":"Stable release`))
|
|
||||||
_, _ = c.Writer.Write([]byte(` (June 17th, 2025) of Gemini 2.5 Pro","inputTokenLimit":1048576,"outputTokenLimit":`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`65536,"supportedGenerationMethods":["generateContent","countTokens","createCachedCo`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`ntent","batchGenerateContent"],"temperature":1,"topP":0.95,"topK":64,"maxTemperatur`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`e":2,"thinking":true}],"nextPageToken":""}`))
|
|
||||||
}
|
|
||||||
|
|
||||||
// GeminiGetHandler handles GET requests for specific Gemini model information.
|
|
||||||
// It returns detailed information about a specific Gemini model based on the action parameter.
|
|
||||||
func (h *GeminiAPIHandlers) GeminiGetHandler(c *gin.Context) {
|
|
||||||
var request struct {
|
|
||||||
Action string `uri:"action" binding:"required"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindUri(&request); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
|
||||||
Error: handlers.ErrorDetail{
|
|
||||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
|
||||||
Type: "invalid_request_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if request.Action == "gemini-2.5-pro" {
|
|
||||||
c.Status(http.StatusOK)
|
|
||||||
c.Header("Content-Type", "application/json; charset=UTF-8")
|
|
||||||
_, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-pro","version":"2.5","displayName":"Gemini 2.5 Pro",`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`"description":"Stable release (June 17th, 2025) of Gemini 2.5 Pro","inputTokenL`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`imit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateC`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`ontent","countTokens","createCachedContent","batchGenerateContent"],"temperatur`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`e":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`))
|
|
||||||
} else if request.Action == "gemini-2.5-flash" {
|
|
||||||
c.Status(http.StatusOK)
|
|
||||||
c.Header("Content-Type", "application/json; charset=UTF-8")
|
|
||||||
_, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini 2.5 Fla`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`sh","description":"Stable version of Gemini 2.5 Flash, our mid-size multimodal `))
|
|
||||||
_, _ = c.Writer.Write([]byte(`model that supports up to 1 million tokens, released in June of 2025.","inputTo`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`kenLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["gener`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`ateContent","countTokens","createCachedContent","batchGenerateContent"],"temper`))
|
|
||||||
_, _ = c.Writer.Write([]byte(`ature":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`))
|
|
||||||
} else {
|
|
||||||
c.Status(http.StatusNotFound)
|
|
||||||
_, _ = c.Writer.Write([]byte(
|
|
||||||
`{"error":{"message":"Not Found","code":404,"status":"NOT_FOUND"}}`,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// GeminiHandler handles POST requests for Gemini API operations.
|
|
||||||
// It routes requests to appropriate handlers based on the action parameter (model:method format).
|
|
||||||
func (h *GeminiAPIHandlers) GeminiHandler(c *gin.Context) {
|
|
||||||
var request struct {
|
|
||||||
Action string `uri:"action" binding:"required"`
|
|
||||||
}
|
|
||||||
if err := c.ShouldBindUri(&request); err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
|
||||||
Error: handlers.ErrorDetail{
|
|
||||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
|
||||||
Type: "invalid_request_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
action := strings.Split(request.Action, ":")
|
|
||||||
if len(action) != 2 {
|
|
||||||
c.JSON(http.StatusNotFound, handlers.ErrorResponse{
|
|
||||||
Error: handlers.ErrorDetail{
|
|
||||||
Message: fmt.Sprintf("%s not found.", c.Request.URL.Path),
|
|
||||||
Type: "invalid_request_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
modelName := action[0]
|
|
||||||
method := action[1]
|
|
||||||
rawJSON, _ := c.GetRawData()
|
|
||||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", []byte(modelName))
|
|
||||||
|
|
||||||
if method == "generateContent" {
|
|
||||||
h.geminiGenerateContent(c, rawJSON)
|
|
||||||
} else if method == "streamGenerateContent" {
|
|
||||||
h.geminiStreamGenerateContent(c, rawJSON)
|
|
||||||
} else if method == "countTokens" {
|
|
||||||
h.geminiCountTokens(c, rawJSON)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *GeminiAPIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJSON []byte) {
|
|
||||||
alt := h.GetAlt(c)
|
|
||||||
|
|
||||||
if alt == "" {
|
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get the http.Flusher interface to manually flush the response.
|
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
|
||||||
Error: handlers.ErrorDetail{
|
|
||||||
Message: "Streaming not supported",
|
|
||||||
Type: "server_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
|
||||||
modelName := modelResult.String()
|
|
||||||
|
|
||||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
|
||||||
var cliClient *client.Client
|
|
||||||
defer func() {
|
|
||||||
// Ensure the client's mutex is unlocked on function exit.
|
|
||||||
if cliClient != nil {
|
|
||||||
cliClient.RequestMutex.Unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
outLoop:
|
|
||||||
for {
|
|
||||||
var errorResponse *client.ErrorMessage
|
|
||||||
cliClient, errorResponse = h.GetClient(modelName)
|
|
||||||
if errorResponse != nil {
|
|
||||||
c.Status(errorResponse.StatusCode)
|
|
||||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
|
||||||
flusher.Flush()
|
|
||||||
cliCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
template := ""
|
|
||||||
parsed := gjson.Parse(string(rawJSON))
|
|
||||||
contents := parsed.Get("request.contents")
|
|
||||||
if contents.Exists() {
|
|
||||||
template = string(rawJSON)
|
|
||||||
} else {
|
|
||||||
template = `{"project":"","request":{},"model":""}`
|
|
||||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
|
||||||
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
|
|
||||||
template, _ = sjson.Delete(template, "request.model")
|
|
||||||
}
|
|
||||||
|
|
||||||
template, errFixCLIToolResponse := cli.FixCLIToolResponse(template)
|
|
||||||
if errFixCLIToolResponse != nil {
|
|
||||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
|
||||||
Error: handlers.ErrorDetail{
|
|
||||||
Message: errFixCLIToolResponse.Error(),
|
|
||||||
Type: "server_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
cliCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
systemInstructionResult := gjson.Get(template, "request.system_instruction")
|
|
||||||
if systemInstructionResult.Exists() {
|
|
||||||
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
|
|
||||||
template, _ = sjson.Delete(template, "request.system_instruction")
|
|
||||||
}
|
|
||||||
rawJSON = []byte(template)
|
|
||||||
|
|
||||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
|
||||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
|
||||||
} else {
|
|
||||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send the message and receive response chunks and errors via channels.
|
|
||||||
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, alt)
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
// Handle client disconnection.
|
|
||||||
case <-c.Request.Context().Done():
|
|
||||||
if c.Request.Context().Err().Error() == "context canceled" {
|
|
||||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
|
||||||
cliCancel() // Cancel the backend request.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Process incoming response chunks.
|
|
||||||
case chunk, okStream := <-respChan:
|
|
||||||
if !okStream {
|
|
||||||
cliCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if cliClient.GetGenerativeLanguageAPIKey() == "" {
|
|
||||||
if alt == "" {
|
|
||||||
responseResult := gjson.GetBytes(chunk, "response")
|
|
||||||
if responseResult.Exists() {
|
|
||||||
chunk = []byte(responseResult.Raw)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
chunkTemplate := "[]"
|
|
||||||
responseResult := gjson.ParseBytes(chunk)
|
|
||||||
if responseResult.IsArray() {
|
|
||||||
responseResultItems := responseResult.Array()
|
|
||||||
for i := 0; i < len(responseResultItems); i++ {
|
|
||||||
responseResultItem := responseResultItems[i]
|
|
||||||
if responseResultItem.Get("response").Exists() {
|
|
||||||
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
chunk = []byte(chunkTemplate)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if alt == "" {
|
|
||||||
_, _ = c.Writer.Write([]byte("data: "))
|
|
||||||
_, _ = c.Writer.Write(chunk)
|
|
||||||
_, _ = c.Writer.Write([]byte("\n\n"))
|
|
||||||
} else {
|
|
||||||
_, _ = c.Writer.Write(chunk)
|
|
||||||
}
|
|
||||||
flusher.Flush()
|
|
||||||
// Handle errors from the backend.
|
|
||||||
case err, okError := <-errChan:
|
|
||||||
if okError {
|
|
||||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
|
||||||
log.Debugf("quota exceeded, switch client")
|
|
||||||
continue outLoop
|
|
||||||
} else {
|
|
||||||
log.Debugf("error code :%d, error: %v", err.StatusCode, err.Error.Error())
|
|
||||||
c.Status(err.StatusCode)
|
|
||||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
|
||||||
flusher.Flush()
|
|
||||||
cliCancel()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Send a keep-alive signal to the client.
|
|
||||||
case <-time.After(500 * time.Millisecond):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// geminiCountTokens proxies a Gemini countTokens request to the backend.
// It selects a client from the pool for the requested model, rewraps the
// payload into the CLI backend's {"request": ...} envelope when the client
// authenticates with an OAuth account (rather than a Generative Language
// API key), and retries with another client on HTTP 429 when the
// quota-exceeded switch-project option is enabled.
func (h *GeminiAPIHandlers) geminiCountTokens(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "application/json")

	alt := h.GetAlt(c)
	// orgrawJSON := rawJSON
	modelResult := gjson.GetBytes(rawJSON, "model")
	modelName := modelResult.String()
	cliCtx, cliCancel := context.WithCancel(context.Background())
	var cliClient *client.Client
	// GetClient hands back a client whose request mutex is held; always
	// release it on exit so the client returns to the pool.
	defer func() {
		if cliClient != nil {
			cliClient.RequestMutex.Unlock()
		}
	}()

	// Retry loop: each iteration picks a (possibly different) client.
	for {
		var errorResponse *client.ErrorMessage
		cliClient, errorResponse = h.GetClient(modelName, false)
		if errorResponse != nil {
			c.Status(errorResponse.StatusCode)
			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
			cliCancel()
			return
		}

		if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
			// API-key clients accept the payload as-is.
			log.Debugf("Request use generative language API Key: %s", glAPIKey)
		} else {
			log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())

			// OAuth/CLI backend expects the body wrapped in a "request"
			// envelope; accept either of the two inbound shapes.
			template := `{"request":{}}`
			if gjson.GetBytes(rawJSON, "generateContentRequest").Exists() {
				template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJSON, "generateContentRequest").Raw)
				template, _ = sjson.Delete(template, "generateContentRequest")
			} else if gjson.GetBytes(rawJSON, "contents").Exists() {
				template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJSON, "contents").Raw)
				template, _ = sjson.Delete(template, "contents")
			}
			// NOTE: rawJSON is rebuilt here, so a retry after 429 re-wraps
			// the already-wrapped body — TODO confirm this is intended.
			rawJSON = []byte(template)
		}

		resp, err := cliClient.SendRawTokenCount(cliCtx, rawJSON, alt)
		if err != nil {
			if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
				// Quota exhausted: try another client/project.
				continue
			} else {
				c.Status(err.StatusCode)
				_, _ = c.Writer.Write([]byte(err.Error.Error()))
				cliCancel()
				// log.Debugf(err.Error.Error())
				// log.Debugf(string(rawJSON))
				// log.Debugf(string(orgrawJSON))
			}
			break
		} else {
			// OAuth responses are wrapped in a "response" envelope; unwrap
			// before returning to the caller.
			if cliClient.GetGenerativeLanguageAPIKey() == "" {
				responseResult := gjson.GetBytes(resp, "response")
				if responseResult.Exists() {
					resp = []byte(responseResult.Raw)
				}
			}
			_, _ = c.Writer.Write(resp)
			cliCancel()
			break
		}
	}
}
|
|
||||||
|
|
||||||
// geminiGenerateContent proxies a non-streaming Gemini generateContent
// request. It normalizes the body into the CLI backend envelope
// ({"project":...,"request":...,"model":...}), repairs tool responses via
// cli.FixCLIToolResponse, converts snake_case system_instruction to
// camelCase, and retries with another client on HTTP 429 when configured.
func (h *GeminiAPIHandlers) geminiGenerateContent(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "application/json")

	alt := h.GetAlt(c)

	modelResult := gjson.GetBytes(rawJSON, "model")
	modelName := modelResult.String()
	cliCtx, cliCancel := context.WithCancel(context.Background())
	var cliClient *client.Client
	// Release the selected client's request mutex on exit.
	defer func() {
		if cliClient != nil {
			cliClient.RequestMutex.Unlock()
		}
	}()

	for {
		var errorResponse *client.ErrorMessage
		cliClient, errorResponse = h.GetClient(modelName)
		if errorResponse != nil {
			c.Status(errorResponse.StatusCode)
			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
			cliCancel()
			return
		}

		// Wrap the body in the backend envelope unless it already has the
		// "request.contents" shape; hoist "model" out of the request.
		template := ""
		parsed := gjson.Parse(string(rawJSON))
		contents := parsed.Get("request.contents")
		if contents.Exists() {
			template = string(rawJSON)
		} else {
			template = `{"project":"","request":{},"model":""}`
			template, _ = sjson.SetRaw(template, "request", string(rawJSON))
			template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
			template, _ = sjson.Delete(template, "request.model")
		}

		// Repair malformed tool responses before forwarding; a failure here
		// is a server-side error.
		template, errFixCLIToolResponse := cli.FixCLIToolResponse(template)
		if errFixCLIToolResponse != nil {
			c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
				Error: handlers.ErrorDetail{
					Message: errFixCLIToolResponse.Error(),
					Type:    "server_error",
				},
			})
			cliCancel()
			return
		}

		// The backend expects camelCase "systemInstruction".
		systemInstructionResult := gjson.Get(template, "request.system_instruction")
		if systemInstructionResult.Exists() {
			template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
			template, _ = sjson.Delete(template, "request.system_instruction")
		}
		rawJSON = []byte(template)

		if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
			log.Debugf("Request use generative language API Key: %s", glAPIKey)
		} else {
			log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
		}
		resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, alt)
		if err != nil {
			if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
				// Quota exhausted: retry with another client/project.
				continue
			} else {
				c.Status(err.StatusCode)
				_, _ = c.Writer.Write([]byte(err.Error.Error()))
				cliCancel()
			}
			break
		} else {
			// OAuth responses arrive wrapped in a "response" envelope.
			if cliClient.GetGenerativeLanguageAPIKey() == "" {
				responseResult := gjson.GetBytes(resp, "response")
				if responseResult.Exists() {
					resp = []byte(responseResult.Raw)
				}
			}
			_, _ = c.Writer.Write(resp)
			cliCancel()
			break
		}
	}
}
|
|
||||||
735
internal/api/handlers/gemini/gemini_handlers.go
Normal file
735
internal/api/handlers/gemini/gemini_handlers.go
Normal file
@@ -0,0 +1,735 @@
|
|||||||
|
// Package gemini provides HTTP handlers for Gemini API endpoints.
|
||||||
|
// This package implements handlers for managing Gemini model operations including
|
||||||
|
// model listing, content generation, streaming content generation, and token counting.
|
||||||
|
// It serves as a proxy layer between clients and the Gemini backend service,
|
||||||
|
// handling request translation, client management, and response processing.
|
||||||
|
package gemini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
translatorGeminiToCodex "github.com/luispater/CLIProxyAPI/internal/translator/codex/gemini"
|
||||||
|
translatorGeminiToGeminiCli "github.com/luispater/CLIProxyAPI/internal/translator/gemini-cli/gemini/cli"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GeminiAPIHandlers contains the handlers for Gemini API endpoints.
|
||||||
|
// It holds a pool of clients to interact with the backend service.
|
||||||
|
type GeminiAPIHandlers struct {
|
||||||
|
*handlers.APIHandlers
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewGeminiAPIHandlers creates a new Gemini API handlers instance.
|
||||||
|
// It takes an APIHandlers instance as input and returns a GeminiAPIHandlers.
|
||||||
|
func NewGeminiAPIHandlers(apiHandlers *handlers.APIHandlers) *GeminiAPIHandlers {
|
||||||
|
return &GeminiAPIHandlers{
|
||||||
|
APIHandlers: apiHandlers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GeminiModels handles the Gemini models listing endpoint.
|
||||||
|
// It returns a JSON response containing available Gemini models and their specifications.
|
||||||
|
func (h *GeminiAPIHandlers) GeminiModels(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"data": []map[string]any{
|
||||||
|
{
|
||||||
|
"id": "gemini-2.5-flash",
|
||||||
|
"object": "model",
|
||||||
|
"version": "001",
|
||||||
|
"name": "Gemini 2.5 Flash",
|
||||||
|
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
||||||
|
"context_length": 1_048_576,
|
||||||
|
"max_completion_tokens": 65_536,
|
||||||
|
"supported_parameters": []string{
|
||||||
|
"tools",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"top_k",
|
||||||
|
},
|
||||||
|
"temperature": 1,
|
||||||
|
"topP": 0.95,
|
||||||
|
"topK": 64,
|
||||||
|
"maxTemperature": 2,
|
||||||
|
"thinking": true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gemini-2.5-pro",
|
||||||
|
"object": "model",
|
||||||
|
"version": "2.5",
|
||||||
|
"name": "Gemini 2.5 Pro",
|
||||||
|
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
||||||
|
"context_length": 1_048_576,
|
||||||
|
"max_completion_tokens": 65_536,
|
||||||
|
"supported_parameters": []string{
|
||||||
|
"tools",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"top_k",
|
||||||
|
},
|
||||||
|
"temperature": 1,
|
||||||
|
"topP": 0.95,
|
||||||
|
"topK": 64,
|
||||||
|
"maxTemperature": 2,
|
||||||
|
"thinking": true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gpt-5",
|
||||||
|
"object": "model",
|
||||||
|
"version": "gpt-5-2025-08-07",
|
||||||
|
"name": "GPT 5",
|
||||||
|
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
||||||
|
"context_length": 400_000,
|
||||||
|
"max_completion_tokens": 128_000,
|
||||||
|
"supported_parameters": []string{
|
||||||
|
"tools",
|
||||||
|
},
|
||||||
|
"temperature": 1,
|
||||||
|
"topP": 0.95,
|
||||||
|
"topK": 64,
|
||||||
|
"maxTemperature": 2,
|
||||||
|
"thinking": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GeminiGetHandler handles GET requests for specific Gemini model information.
|
||||||
|
// It returns detailed information about a specific Gemini model based on the action parameter.
|
||||||
|
func (h *GeminiAPIHandlers) GeminiGetHandler(c *gin.Context) {
|
||||||
|
var request struct {
|
||||||
|
Action string `uri:"action" binding:"required"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindUri(&request); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch request.Action {
|
||||||
|
case "gemini-2.5-pro":
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"id": "gemini-2.5-pro",
|
||||||
|
"object": "model",
|
||||||
|
"version": "2.5",
|
||||||
|
"name": "Gemini 2.5 Pro",
|
||||||
|
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
||||||
|
"context_length": 1_048_576,
|
||||||
|
"max_completion_tokens": 65_536,
|
||||||
|
"supported_parameters": []string{
|
||||||
|
"tools",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"top_k",
|
||||||
|
},
|
||||||
|
"temperature": 1,
|
||||||
|
"topP": 0.95,
|
||||||
|
"topK": 64,
|
||||||
|
"maxTemperature": 2,
|
||||||
|
"thinking": true,
|
||||||
|
})
|
||||||
|
case "gemini-2.5-flash":
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"id": "gemini-2.5-flash",
|
||||||
|
"object": "model",
|
||||||
|
"version": "001",
|
||||||
|
"name": "Gemini 2.5 Flash",
|
||||||
|
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
||||||
|
"context_length": 1_048_576,
|
||||||
|
"max_completion_tokens": 65_536,
|
||||||
|
"supported_parameters": []string{
|
||||||
|
"tools",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"top_k",
|
||||||
|
},
|
||||||
|
"temperature": 1,
|
||||||
|
"topP": 0.95,
|
||||||
|
"topK": 64,
|
||||||
|
"maxTemperature": 2,
|
||||||
|
"thinking": true,
|
||||||
|
})
|
||||||
|
case "gpt-5":
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"id": "gpt-5",
|
||||||
|
"object": "model",
|
||||||
|
"version": "gpt-5-2025-08-07",
|
||||||
|
"name": "GPT 5",
|
||||||
|
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
||||||
|
"context_length": 400_000,
|
||||||
|
"max_completion_tokens": 128_000,
|
||||||
|
"supported_parameters": []string{
|
||||||
|
"tools",
|
||||||
|
},
|
||||||
|
"temperature": 1,
|
||||||
|
"topP": 0.95,
|
||||||
|
"topK": 64,
|
||||||
|
"maxTemperature": 2,
|
||||||
|
"thinking": true,
|
||||||
|
})
|
||||||
|
default:
|
||||||
|
c.JSON(http.StatusNotFound, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: "Not Found",
|
||||||
|
Type: "not_found",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GeminiHandler handles POST requests for Gemini API operations.
|
||||||
|
// It routes requests to appropriate handlers based on the action parameter (model:method format).
|
||||||
|
func (h *GeminiAPIHandlers) GeminiHandler(c *gin.Context) {
|
||||||
|
var request struct {
|
||||||
|
Action string `uri:"action" binding:"required"`
|
||||||
|
}
|
||||||
|
if err := c.ShouldBindUri(&request); err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
action := strings.Split(request.Action, ":")
|
||||||
|
if len(action) != 2 {
|
||||||
|
c.JSON(http.StatusNotFound, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: fmt.Sprintf("%s not found.", c.Request.URL.Path),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
modelName := action[0]
|
||||||
|
method := action[1]
|
||||||
|
rawJSON, _ := c.GetRawData()
|
||||||
|
rawJSON, _ = sjson.SetBytes(rawJSON, "model", []byte(modelName))
|
||||||
|
|
||||||
|
provider := util.GetProviderName(modelName)
|
||||||
|
if provider == "gemini" || provider == "unknow" {
|
||||||
|
switch method {
|
||||||
|
case "generateContent":
|
||||||
|
h.handleGeminiGenerateContent(c, rawJSON)
|
||||||
|
case "streamGenerateContent":
|
||||||
|
h.handleGeminiStreamGenerateContent(c, rawJSON)
|
||||||
|
case "countTokens":
|
||||||
|
h.handleGeminiCountTokens(c, rawJSON)
|
||||||
|
}
|
||||||
|
} else if provider == "gpt" {
|
||||||
|
switch method {
|
||||||
|
case "generateContent":
|
||||||
|
h.handleCodexGenerateContent(c, rawJSON)
|
||||||
|
case "streamGenerateContent":
|
||||||
|
h.handleCodexStreamGenerateContent(c, rawJSON)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleGeminiStreamGenerateContent proxies a streaming generateContent
// request to the Gemini CLI backend. When no "alt" query parameter is set
// it emits SSE ("data: ..." frames); otherwise it writes raw chunks. On
// HTTP 429 with the switch-project option enabled it restarts the outer
// loop with another client.
func (h *GeminiAPIHandlers) handleGeminiStreamGenerateContent(c *gin.Context, rawJSON []byte) {
	alt := h.GetAlt(c)

	// SSE headers are only appropriate for the default (alt == "") mode.
	if alt == "" {
		c.Header("Content-Type", "text/event-stream")
		c.Header("Cache-Control", "no-cache")
		c.Header("Connection", "keep-alive")
		c.Header("Access-Control-Allow-Origin", "*")
	}

	// Get the http.Flusher interface to manually flush the response.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	modelResult := gjson.GetBytes(rawJSON, "model")
	modelName := modelResult.String()

	cliCtx, cliCancel := context.WithCancel(context.Background())
	var cliClient client.Client
	defer func() {
		// Ensure the client's mutex is unlocked on function exit.
		if cliClient != nil {
			cliClient.GetRequestMutex().Unlock()
		}
	}()

outLoop:
	for {
		var errorResponse *client.ErrorMessage
		cliClient, errorResponse = h.GetClient(modelName)
		if errorResponse != nil {
			c.Status(errorResponse.StatusCode)
			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
			flusher.Flush()
			cliCancel()
			return
		}

		// Wrap the body in the CLI backend envelope unless it already has
		// the "request.contents" shape; hoist "model" out of the request.
		template := ""
		parsed := gjson.Parse(string(rawJSON))
		contents := parsed.Get("request.contents")
		if contents.Exists() {
			template = string(rawJSON)
		} else {
			template = `{"project":"","request":{},"model":""}`
			template, _ = sjson.SetRaw(template, "request", string(rawJSON))
			template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
			template, _ = sjson.Delete(template, "request.model")
		}

		// Repair malformed tool responses before forwarding.
		template, errFixCLIToolResponse := translatorGeminiToGeminiCli.FixCLIToolResponse(template)
		if errFixCLIToolResponse != nil {
			c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
				Error: handlers.ErrorDetail{
					Message: errFixCLIToolResponse.Error(),
					Type:    "server_error",
				},
			})
			cliCancel()
			return
		}

		// The backend expects camelCase "systemInstruction".
		systemInstructionResult := gjson.Get(template, "request.system_instruction")
		if systemInstructionResult.Exists() {
			template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
			template, _ = sjson.Delete(template, "request.system_instruction")
		}
		rawJSON = []byte(template)

		// NOTE(review): these unchecked type assertions assume GetClient
		// returns a *client.GeminiClient for Gemini models; a different
		// client type here would panic — confirm against GetClient.
		if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" {
			log.Debugf("Request use generative language API Key: %s", glAPIKey)
		} else {
			log.Debugf("Request cli use account: %s, project id: %s", cliClient.(*client.GeminiClient).GetEmail(), cliClient.(*client.GeminiClient).GetProjectID())
		}

		// Send the message and receive response chunks and errors via channels.
		respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, alt)
		for {
			select {
			// Handle client disconnection.
			case <-c.Request.Context().Done():
				if c.Request.Context().Err().Error() == "context canceled" {
					log.Debugf("GeminiClient disconnected: %v", c.Request.Context().Err())
					cliCancel() // Cancel the backend request.
					return
				}
			// Process incoming response chunks.
			case chunk, okStream := <-respChan:
				if !okStream {
					// Stream closed: the backend finished sending.
					cliCancel()
					return
				}
				// OAuth responses are wrapped in a "response" envelope;
				// unwrap single objects (alt == "") or rebuild the array of
				// unwrapped items (alt != "").
				if cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey() == "" {
					if alt == "" {
						responseResult := gjson.GetBytes(chunk, "response")
						if responseResult.Exists() {
							chunk = []byte(responseResult.Raw)
						}
					} else {
						chunkTemplate := "[]"
						responseResult := gjson.ParseBytes(chunk)
						if responseResult.IsArray() {
							responseResultItems := responseResult.Array()
							for i := 0; i < len(responseResultItems); i++ {
								responseResultItem := responseResultItems[i]
								if responseResultItem.Get("response").Exists() {
									chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
								}
							}
						}
						chunk = []byte(chunkTemplate)
					}
				}
				if alt == "" {
					// SSE framing.
					_, _ = c.Writer.Write([]byte("data: "))
					_, _ = c.Writer.Write(chunk)
					_, _ = c.Writer.Write([]byte("\n\n"))
				} else {
					_, _ = c.Writer.Write(chunk)
				}
				flusher.Flush()
			// Handle errors from the backend.
			case err, okError := <-errChan:
				if okError {
					if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
						log.Debugf("quota exceeded, switch client")
						continue outLoop
					} else {
						log.Debugf("error code :%d, error: %v", err.StatusCode, err.Error.Error())
						c.Status(err.StatusCode)
						_, _ = fmt.Fprint(c.Writer, err.Error.Error())
						flusher.Flush()
						cliCancel()
					}
					return
				}
			// Periodic wake-up so the select loop can re-check the other
			// channels; nothing is written to the client in this case.
			case <-time.After(500 * time.Millisecond):
			}
		}
	}
}
|
||||||
|
|
||||||
|
// handleGeminiCountTokens proxies a Gemini countTokens request to the CLI
// backend. For OAuth clients it rewraps the payload into the backend's
// {"request": ...} envelope and unwraps the "response" envelope on the way
// back; on HTTP 429 with switch-project enabled it retries with another
// client.
func (h *GeminiAPIHandlers) handleGeminiCountTokens(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "application/json")

	alt := h.GetAlt(c)
	// orgrawJSON := rawJSON
	modelResult := gjson.GetBytes(rawJSON, "model")
	modelName := modelResult.String()
	cliCtx, cliCancel := context.WithCancel(context.Background())
	var cliClient client.Client
	// Release the selected client's request mutex on exit.
	defer func() {
		if cliClient != nil {
			cliClient.GetRequestMutex().Unlock()
		}
	}()

	for {
		var errorResponse *client.ErrorMessage
		cliClient, errorResponse = h.GetClient(modelName, false)
		if errorResponse != nil {
			c.Status(errorResponse.StatusCode)
			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
			cliCancel()
			return
		}

		// NOTE(review): the unchecked *client.GeminiClient assertions below
		// assume GetClient never returns a Codex client for this route — a
		// different concrete type would panic here; confirm against GetClient.
		if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" {
			// API-key clients accept the payload as-is.
			log.Debugf("Request use generative language API Key: %s", glAPIKey)
		} else {
			log.Debugf("Request cli use account: %s, project id: %s", cliClient.(*client.GeminiClient).GetEmail(), cliClient.(*client.GeminiClient).GetProjectID())

			// OAuth/CLI backend expects a "request" envelope; accept either
			// of the two inbound shapes.
			template := `{"request":{}}`
			if gjson.GetBytes(rawJSON, "generateContentRequest").Exists() {
				template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJSON, "generateContentRequest").Raw)
				template, _ = sjson.Delete(template, "generateContentRequest")
			} else if gjson.GetBytes(rawJSON, "contents").Exists() {
				template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJSON, "contents").Raw)
				template, _ = sjson.Delete(template, "contents")
			}
			rawJSON = []byte(template)
		}

		resp, err := cliClient.SendRawTokenCount(cliCtx, rawJSON, alt)
		if err != nil {
			if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
				// Quota exhausted: retry with another client/project.
				continue
			} else {
				c.Status(err.StatusCode)
				_, _ = c.Writer.Write([]byte(err.Error.Error()))
				cliCancel()
				// log.Debugf(err.Error.Error())
				// log.Debugf(string(rawJSON))
				// log.Debugf(string(orgrawJSON))
			}
			break
		} else {
			// Unwrap the OAuth "response" envelope before replying.
			if cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey() == "" {
				responseResult := gjson.GetBytes(resp, "response")
				if responseResult.Exists() {
					resp = []byte(responseResult.Raw)
				}
			}
			_, _ = c.Writer.Write(resp)
			cliCancel()
			break
		}
	}
}
|
||||||
|
|
||||||
|
// handleGeminiGenerateContent proxies a non-streaming generateContent
// request to the Gemini CLI backend. It normalizes the body into the
// backend envelope, repairs tool responses, converts snake_case
// system_instruction to camelCase, and retries on HTTP 429 when the
// switch-project option is enabled.
func (h *GeminiAPIHandlers) handleGeminiGenerateContent(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "application/json")

	alt := h.GetAlt(c)

	modelResult := gjson.GetBytes(rawJSON, "model")
	modelName := modelResult.String()
	cliCtx, cliCancel := context.WithCancel(context.Background())
	var cliClient client.Client
	// Release the selected client's request mutex on exit.
	defer func() {
		if cliClient != nil {
			cliClient.GetRequestMutex().Unlock()
		}
	}()

	for {
		var errorResponse *client.ErrorMessage
		cliClient, errorResponse = h.GetClient(modelName)
		if errorResponse != nil {
			c.Status(errorResponse.StatusCode)
			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
			cliCancel()
			return
		}

		// Wrap the body in the backend envelope unless it already has the
		// "request.contents" shape; hoist "model" out of the request.
		template := ""
		parsed := gjson.Parse(string(rawJSON))
		contents := parsed.Get("request.contents")
		if contents.Exists() {
			template = string(rawJSON)
		} else {
			template = `{"project":"","request":{},"model":""}`
			template, _ = sjson.SetRaw(template, "request", string(rawJSON))
			template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
			template, _ = sjson.Delete(template, "request.model")
		}

		// Repair malformed tool responses; a failure is a server error.
		template, errFixCLIToolResponse := translatorGeminiToGeminiCli.FixCLIToolResponse(template)
		if errFixCLIToolResponse != nil {
			c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
				Error: handlers.ErrorDetail{
					Message: errFixCLIToolResponse.Error(),
					Type:    "server_error",
				},
			})
			cliCancel()
			return
		}

		// The backend expects camelCase "systemInstruction".
		systemInstructionResult := gjson.Get(template, "request.system_instruction")
		if systemInstructionResult.Exists() {
			template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
			template, _ = sjson.Delete(template, "request.system_instruction")
		}
		rawJSON = []byte(template)

		// NOTE(review): unchecked *client.GeminiClient assertions — assumes
		// GetClient returns a Gemini client for this route; confirm.
		if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" {
			log.Debugf("Request use generative language API Key: %s", glAPIKey)
		} else {
			log.Debugf("Request cli use account: %s, project id: %s", cliClient.(*client.GeminiClient).GetEmail(), cliClient.(*client.GeminiClient).GetProjectID())
		}
		resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, alt)
		if err != nil {
			if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
				// Quota exhausted: retry with another client/project.
				continue
			} else {
				c.Status(err.StatusCode)
				_, _ = c.Writer.Write([]byte(err.Error.Error()))
				cliCancel()
			}
			break
		} else {
			// Unwrap the OAuth "response" envelope before replying.
			if cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey() == "" {
				responseResult := gjson.GetBytes(resp, "response")
				if responseResult.Exists() {
					resp = []byte(responseResult.Raw)
				}
			}
			_, _ = c.Writer.Write(resp)
			cliCancel()
			break
		}
	}
}
|
||||||
|
|
||||||
|
// handleCodexStreamGenerateContent proxies a streaming generateContent
// request for GPT models through the Codex backend. The Gemini-format body
// is translated to a Codex request; the backend's SSE "data: ..." frames
// are translated back to Gemini-format chunks and re-emitted as SSE. On
// HTTP 429 with switch-project enabled it restarts with another client.
func (h *GeminiAPIHandlers) handleCodexStreamGenerateContent(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Access-Control-Allow-Origin", "*")

	// Get the http.Flusher interface to manually flush the response.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	// Prepare the request for the backend client.
	newRequestJSON := translatorGeminiToCodex.ConvertGeminiRequestToCodex(rawJSON)
	// log.Debugf("Request: %s", newRequestJSON)

	modelName := gjson.GetBytes(rawJSON, "model")

	cliCtx, cliCancel := context.WithCancel(context.Background())
	var cliClient client.Client
	defer func() {
		// Ensure the client's mutex is unlocked on function exit.
		if cliClient != nil {
			cliClient.GetRequestMutex().Unlock()
		}
	}()

outLoop:
	for {
		var errorResponse *client.ErrorMessage
		cliClient, errorResponse = h.GetClient(modelName.String())
		if errorResponse != nil {
			c.Status(errorResponse.StatusCode)
			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
			flusher.Flush()
			cliCancel()
			return
		}

		log.Debugf("Request codex use account: %s", cliClient.GetEmail())

		// Send the message and receive response chunks and errors via channels.
		respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "")
		// Mutable translator state carried across chunks of one stream.
		params := &translatorGeminiToCodex.ConvertCodexResponseToGeminiParams{
			Model:             modelName.String(),
			CreatedAt:         0,
			ResponseID:        "",
			LastStorageOutput: "",
		}
		for {
			select {
			// Handle client disconnection.
			case <-c.Request.Context().Done():
				if c.Request.Context().Err().Error() == "context canceled" {
					log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err())
					cliCancel() // Cancel the backend request.
					return
				}
			// Process incoming response chunks.
			case chunk, okStream := <-respChan:
				if !okStream {
					// Stream closed: the backend finished sending.
					cliCancel()
					return
				}

				// Only SSE data frames carry payloads; strip the 6-byte
				// "data: " prefix and translate typed events to Gemini form.
				if bytes.HasPrefix(chunk, []byte("data: ")) {
					jsonData := chunk[6:]
					data := gjson.ParseBytes(jsonData)
					typeResult := data.Get("type")
					if typeResult.String() != "" {
						outputs := translatorGeminiToCodex.ConvertCodexResponseToGemini(jsonData, params)
						if len(outputs) > 0 {
							// One Codex event may expand to several Gemini
							// chunks; emit each as its own SSE frame.
							for i := 0; i < len(outputs); i++ {
								_, _ = c.Writer.Write([]byte("data: "))
								_, _ = c.Writer.Write([]byte(outputs[i]))
								_, _ = c.Writer.Write([]byte("\n\n"))
							}
						}
					}
					// log.Debugf(string(jsonData))
				}
				flusher.Flush()
			// Handle errors from the backend.
			case err, okError := <-errChan:
				if okError {
					if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
						continue outLoop
					} else {
						c.Status(err.StatusCode)
						_, _ = fmt.Fprint(c.Writer, err.Error.Error())
						flusher.Flush()
						cliCancel()
					}
					return
				}
			// Periodic wake-up so the select loop can re-check the other
			// channels; nothing is written to the client in this case.
			case <-time.After(500 * time.Millisecond):
			}
		}
	}
}
|
||||||
|
|
||||||
|
func (h *GeminiAPIHandlers) handleCodexGenerateContent(c *gin.Context, rawJSON []byte) {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
|
// Prepare the request for the backend client.
|
||||||
|
newRequestJSON := translatorGeminiToCodex.ConvertGeminiRequestToCodex(rawJSON)
|
||||||
|
// log.Debugf("Request: %s", newRequestJSON)
|
||||||
|
|
||||||
|
modelName := gjson.GetBytes(rawJSON, "model")
|
||||||
|
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient client.Client
|
||||||
|
defer func() {
|
||||||
|
// Ensure the client's mutex is unlocked on function exit.
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.GetRequestMutex().Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
outLoop:
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.GetClient(modelName.String())
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Request codex use account: %s", cliClient.GetEmail())
|
||||||
|
|
||||||
|
// Send the message and receive response chunks and errors via channels.
|
||||||
|
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "")
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// Handle client disconnection.
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
if c.Request.Context().Err().Error() == "context canceled" {
|
||||||
|
log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err())
|
||||||
|
cliCancel() // Cancel the backend request.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Process incoming response chunks.
|
||||||
|
case chunk, okStream := <-respChan:
|
||||||
|
if !okStream {
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if bytes.HasPrefix(chunk, []byte("data: ")) {
|
||||||
|
jsonData := chunk[6:]
|
||||||
|
data := gjson.ParseBytes(jsonData)
|
||||||
|
typeResult := data.Get("type")
|
||||||
|
if typeResult.String() != "" {
|
||||||
|
var geminiStr string
|
||||||
|
geminiStr = translatorGeminiToCodex.ConvertCodexResponseToGeminiNonStream(jsonData, modelName.String())
|
||||||
|
if geminiStr != "" {
|
||||||
|
_, _ = c.Writer.Write([]byte(geminiStr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Handle errors from the backend.
|
||||||
|
case err, okError := <-errChan:
|
||||||
|
if okError {
|
||||||
|
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue outLoop
|
||||||
|
} else {
|
||||||
|
c.Status(err.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Send a keep-alive signal to the client.
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -5,52 +5,78 @@ package handlers
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"sync"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ErrorResponse represents a standard error response format for the API.
|
// ErrorResponse represents a standard error response format for the API.
|
||||||
// It contains a single ErrorDetail field.
|
// It contains a single ErrorDetail field.
|
||||||
type ErrorResponse struct {
|
type ErrorResponse struct {
|
||||||
|
// Error contains detailed information about the error that occurred.
|
||||||
Error ErrorDetail `json:"error"`
|
Error ErrorDetail `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrorDetail provides specific information about an error that occurred.
|
// ErrorDetail provides specific information about an error that occurred.
|
||||||
// It includes a human-readable message, an error type, and an optional error code.
|
// It includes a human-readable message, an error type, and an optional error code.
|
||||||
type ErrorDetail struct {
|
type ErrorDetail struct {
|
||||||
// A human-readable message providing more details about the error.
|
// Message is a human-readable message providing more details about the error.
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
// The type of error that occurred (e.g., "invalid_request_error").
|
|
||||||
|
// Type is the category of error that occurred (e.g., "invalid_request_error").
|
||||||
Type string `json:"type"`
|
Type string `json:"type"`
|
||||||
// A short code identifying the error, if applicable.
|
|
||||||
|
// Code is a short code identifying the error, if applicable.
|
||||||
Code string `json:"code,omitempty"`
|
Code string `json:"code,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// APIHandlers contains the handlers for API endpoints.
|
// APIHandlers contains the handlers for API endpoints.
|
||||||
// It holds a pool of clients to interact with the backend service.
|
// It holds a pool of clients to interact with the backend service and manages
|
||||||
|
// load balancing, client selection, and configuration.
|
||||||
type APIHandlers struct {
|
type APIHandlers struct {
|
||||||
CliClients []*client.Client
|
// CliClients is the pool of available AI service clients.
|
||||||
Cfg *config.Config
|
CliClients []client.Client
|
||||||
Mutex *sync.Mutex
|
|
||||||
LastUsedClientIndex int
|
// Cfg holds the current application configuration.
|
||||||
|
Cfg *config.Config
|
||||||
|
|
||||||
|
// Mutex ensures thread-safe access to shared resources.
|
||||||
|
Mutex *sync.Mutex
|
||||||
|
|
||||||
|
// LastUsedClientIndex tracks the last used client index for each provider
|
||||||
|
// to implement round-robin load balancing.
|
||||||
|
LastUsedClientIndex map[string]int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAPIHandlers creates a new API handlers instance.
|
// NewAPIHandlers creates a new API handlers instance.
|
||||||
// It takes a slice of clients and a debug flag as input.
|
// It takes a slice of clients and configuration as input.
|
||||||
func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandlers {
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cliClients: A slice of AI service clients
|
||||||
|
// - cfg: The application configuration
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *APIHandlers: A new API handlers instance
|
||||||
|
func NewAPIHandlers(cliClients []client.Client, cfg *config.Config) *APIHandlers {
|
||||||
return &APIHandlers{
|
return &APIHandlers{
|
||||||
CliClients: cliClients,
|
CliClients: cliClients,
|
||||||
Cfg: cfg,
|
Cfg: cfg,
|
||||||
Mutex: &sync.Mutex{},
|
Mutex: &sync.Mutex{},
|
||||||
LastUsedClientIndex: 0,
|
LastUsedClientIndex: make(map[string]int),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateClients updates the handlers' client list and configuration
|
// UpdateClients updates the handlers' client list and configuration.
|
||||||
func (h *APIHandlers) UpdateClients(clients []*client.Client, cfg *config.Config) {
|
// This method is called when the configuration or authentication tokens change.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - clients: The new slice of AI service clients
|
||||||
|
// - cfg: The new application configuration
|
||||||
|
func (h *APIHandlers) UpdateClients(clients []client.Client, cfg *config.Config) {
|
||||||
h.CliClients = clients
|
h.CliClients = clients
|
||||||
h.Cfg = cfg
|
h.Cfg = cfg
|
||||||
}
|
}
|
||||||
@@ -58,30 +84,63 @@ func (h *APIHandlers) UpdateClients(clients []*client.Client, cfg *config.Config
|
|||||||
// GetClient returns an available client from the pool using round-robin load balancing.
|
// GetClient returns an available client from the pool using round-robin load balancing.
|
||||||
// It checks for quota limits and tries to find an unlocked client for immediate use.
|
// It checks for quota limits and tries to find an unlocked client for immediate use.
|
||||||
// The modelName parameter is used to check quota status for specific models.
|
// The modelName parameter is used to check quota status for specific models.
|
||||||
func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) {
|
//
|
||||||
if len(h.CliClients) == 0 {
|
// Parameters:
|
||||||
|
// - modelName: The name of the model to be used
|
||||||
|
// - isGenerateContent: Optional parameter to indicate if this is for content generation
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - client.Client: An available client for the requested model
|
||||||
|
// - *client.ErrorMessage: An error message if no client is available
|
||||||
|
func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (client.Client, *client.ErrorMessage) {
|
||||||
|
provider := util.GetProviderName(modelName)
|
||||||
|
clients := make([]client.Client, 0)
|
||||||
|
if provider == "gemini" {
|
||||||
|
for i := 0; i < len(h.CliClients); i++ {
|
||||||
|
if cli, ok := h.CliClients[i].(*client.GeminiClient); ok {
|
||||||
|
clients = append(clients, cli)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if provider == "gpt" {
|
||||||
|
for i := 0; i < len(h.CliClients); i++ {
|
||||||
|
if cli, ok := h.CliClients[i].(*client.CodexClient); ok {
|
||||||
|
clients = append(clients, cli)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, hasKey := h.LastUsedClientIndex[provider]; !hasKey {
|
||||||
|
h.LastUsedClientIndex[provider] = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(clients) == 0 {
|
||||||
return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")}
|
return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")}
|
||||||
}
|
}
|
||||||
|
|
||||||
var cliClient *client.Client
|
var cliClient client.Client
|
||||||
|
|
||||||
// Lock the mutex to update the last used client index
|
// Lock the mutex to update the last used client index
|
||||||
h.Mutex.Lock()
|
h.Mutex.Lock()
|
||||||
startIndex := h.LastUsedClientIndex
|
startIndex := h.LastUsedClientIndex[provider]
|
||||||
if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 {
|
if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 {
|
||||||
currentIndex := (startIndex + 1) % len(h.CliClients)
|
currentIndex := (startIndex + 1) % len(clients)
|
||||||
h.LastUsedClientIndex = currentIndex
|
h.LastUsedClientIndex[provider] = currentIndex
|
||||||
}
|
}
|
||||||
h.Mutex.Unlock()
|
h.Mutex.Unlock()
|
||||||
|
|
||||||
// Reorder the client to start from the last used index
|
// Reorder the client to start from the last used index
|
||||||
reorderedClients := make([]*client.Client, 0)
|
reorderedClients := make([]client.Client, 0)
|
||||||
for i := 0; i < len(h.CliClients); i++ {
|
for i := 0; i < len(clients); i++ {
|
||||||
cliClient = h.CliClients[(startIndex+1+i)%len(h.CliClients)]
|
cliClient = clients[(startIndex+1+i)%len(clients)]
|
||||||
if cliClient.IsModelQuotaExceeded(modelName) {
|
if cliClient.IsModelQuotaExceeded(modelName) {
|
||||||
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
|
if provider == "gemini" {
|
||||||
|
log.Debugf("Gemini Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.(*client.GeminiClient).GetProjectID())
|
||||||
|
} else if provider == "gpt" {
|
||||||
|
log.Debugf("Codex Model %s is quota exceeded for account %s", modelName, cliClient.GetEmail())
|
||||||
|
}
|
||||||
cliClient = nil
|
cliClient = nil
|
||||||
continue
|
continue
|
||||||
|
|
||||||
}
|
}
|
||||||
reorderedClients = append(reorderedClients, cliClient)
|
reorderedClients = append(reorderedClients, cliClient)
|
||||||
}
|
}
|
||||||
@@ -93,14 +152,14 @@ func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (*c
|
|||||||
locked := false
|
locked := false
|
||||||
for i := 0; i < len(reorderedClients); i++ {
|
for i := 0; i < len(reorderedClients); i++ {
|
||||||
cliClient = reorderedClients[i]
|
cliClient = reorderedClients[i]
|
||||||
if cliClient.RequestMutex.TryLock() {
|
if cliClient.GetRequestMutex().TryLock() {
|
||||||
locked = true
|
locked = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !locked {
|
if !locked {
|
||||||
cliClient = h.CliClients[0]
|
cliClient = clients[0]
|
||||||
cliClient.RequestMutex.Lock()
|
cliClient.GetRequestMutex().Lock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return cliClient, nil
|
return cliClient, nil
|
||||||
@@ -108,6 +167,12 @@ func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (*c
|
|||||||
|
|
||||||
// GetAlt extracts the 'alt' parameter from the request query string.
|
// GetAlt extracts the 'alt' parameter from the request query string.
|
||||||
// It checks both 'alt' and '$alt' parameters and returns the appropriate value.
|
// It checks both 'alt' and '$alt' parameters and returns the appropriate value.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - c: The Gin context containing the HTTP request
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - string: The alt parameter value, or empty string if it's "sse"
|
||||||
func (h *APIHandlers) GetAlt(c *gin.Context) string {
|
func (h *APIHandlers) GetAlt(c *gin.Context) string {
|
||||||
var alt string
|
var alt string
|
||||||
var hasAlt bool
|
var hasAlt bool
|
||||||
|
|||||||
@@ -1,264 +0,0 @@
|
|||||||
// Package openai provides HTTP handlers for OpenAI API endpoints.
|
|
||||||
// This package implements the OpenAI-compatible API interface, including model listing
|
|
||||||
// and chat completion functionality. It supports both streaming and non-streaming responses,
|
|
||||||
// and manages a pool of clients to interact with backend services.
|
|
||||||
// The handlers translate OpenAI API requests to the appropriate backend format and
|
|
||||||
// convert responses back to OpenAI-compatible format.
|
|
||||||
package openai
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/translator/openai"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
"net/http"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OpenAIAPIHandlers contains the handlers for OpenAI API endpoints.
|
|
||||||
// It holds a pool of clients to interact with the backend service.
|
|
||||||
type OpenAIAPIHandlers struct {
|
|
||||||
*handlers.APIHandlers
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewOpenAIAPIHandlers creates a new OpenAI API handlers instance.
|
|
||||||
// It takes an APIHandlers instance as input and returns an OpenAIAPIHandlers.
|
|
||||||
func NewOpenAIAPIHandlers(apiHandlers *handlers.APIHandlers) *OpenAIAPIHandlers {
|
|
||||||
return &OpenAIAPIHandlers{
|
|
||||||
APIHandlers: apiHandlers,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Models handles the /v1/models endpoint.
|
|
||||||
// It returns a hardcoded list of available AI models.
|
|
||||||
func (h *OpenAIAPIHandlers) Models(c *gin.Context) {
|
|
||||||
c.JSON(http.StatusOK, gin.H{
|
|
||||||
"data": []map[string]any{
|
|
||||||
{
|
|
||||||
"id": "gemini-2.5-pro",
|
|
||||||
"object": "model",
|
|
||||||
"version": "2.5",
|
|
||||||
"name": "Gemini 2.5 Pro",
|
|
||||||
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
|
||||||
"context_length": 1048576,
|
|
||||||
"max_completion_tokens": 65536,
|
|
||||||
"supported_parameters": []string{
|
|
||||||
"tools",
|
|
||||||
"temperature",
|
|
||||||
"top_p",
|
|
||||||
"top_k",
|
|
||||||
},
|
|
||||||
"temperature": 1,
|
|
||||||
"topP": 0.95,
|
|
||||||
"topK": 64,
|
|
||||||
"maxTemperature": 2,
|
|
||||||
"thinking": true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "gemini-2.5-flash",
|
|
||||||
"object": "model",
|
|
||||||
"version": "001",
|
|
||||||
"name": "Gemini 2.5 Flash",
|
|
||||||
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
|
||||||
"context_length": 1048576,
|
|
||||||
"max_completion_tokens": 65536,
|
|
||||||
"supported_parameters": []string{
|
|
||||||
"tools",
|
|
||||||
"temperature",
|
|
||||||
"top_p",
|
|
||||||
"top_k",
|
|
||||||
},
|
|
||||||
"temperature": 1,
|
|
||||||
"topP": 0.95,
|
|
||||||
"topK": 64,
|
|
||||||
"maxTemperature": 2,
|
|
||||||
"thinking": true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ChatCompletions handles the /v1/chat/completions endpoint.
|
|
||||||
// It determines whether the request is for a streaming or non-streaming response
|
|
||||||
// and calls the appropriate handler.
|
|
||||||
func (h *OpenAIAPIHandlers) ChatCompletions(c *gin.Context) {
|
|
||||||
rawJSON, err := c.GetRawData()
|
|
||||||
// If data retrieval fails, return a 400 Bad Request error.
|
|
||||||
if err != nil {
|
|
||||||
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
|
||||||
Error: handlers.ErrorDetail{
|
|
||||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
|
||||||
Type: "invalid_request_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if the client requested a streaming response.
|
|
||||||
streamResult := gjson.GetBytes(rawJSON, "stream")
|
|
||||||
if streamResult.Type == gjson.True {
|
|
||||||
h.handleStreamingResponse(c, rawJSON)
|
|
||||||
} else {
|
|
||||||
h.handleNonStreamingResponse(c, rawJSON)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleNonStreamingResponse handles non-streaming chat completion responses.
|
|
||||||
// It selects a client from the pool, sends the request, and aggregates the response
|
|
||||||
// before sending it back to the client.
|
|
||||||
func (h *OpenAIAPIHandlers) handleNonStreamingResponse(c *gin.Context, rawJSON []byte) {
|
|
||||||
c.Header("Content-Type", "application/json")
|
|
||||||
|
|
||||||
modelName, systemInstruction, contents, tools := openai.PrepareRequest(rawJSON)
|
|
||||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
|
||||||
var cliClient *client.Client
|
|
||||||
defer func() {
|
|
||||||
if cliClient != nil {
|
|
||||||
cliClient.RequestMutex.Unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
for {
|
|
||||||
var errorResponse *client.ErrorMessage
|
|
||||||
cliClient, errorResponse = h.GetClient(modelName)
|
|
||||||
if errorResponse != nil {
|
|
||||||
c.Status(errorResponse.StatusCode)
|
|
||||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
|
||||||
cliCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
isGlAPIKey := false
|
|
||||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
|
||||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
|
||||||
isGlAPIKey = true
|
|
||||||
} else {
|
|
||||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := cliClient.SendMessage(cliCtx, rawJSON, modelName, systemInstruction, contents, tools)
|
|
||||||
if err != nil {
|
|
||||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
|
||||||
continue
|
|
||||||
} else {
|
|
||||||
c.Status(err.StatusCode)
|
|
||||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
|
||||||
cliCancel()
|
|
||||||
}
|
|
||||||
break
|
|
||||||
} else {
|
|
||||||
openAIFormat := openai.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey)
|
|
||||||
if openAIFormat != "" {
|
|
||||||
_, _ = c.Writer.Write([]byte(openAIFormat))
|
|
||||||
}
|
|
||||||
cliCancel()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleStreamingResponse handles streaming responses
|
|
||||||
func (h *OpenAIAPIHandlers) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
|
|
||||||
c.Header("Content-Type", "text/event-stream")
|
|
||||||
c.Header("Cache-Control", "no-cache")
|
|
||||||
c.Header("Connection", "keep-alive")
|
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
|
||||||
|
|
||||||
// Get the http.Flusher interface to manually flush the response.
|
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
|
||||||
if !ok {
|
|
||||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
|
||||||
Error: handlers.ErrorDetail{
|
|
||||||
Message: "Streaming not supported",
|
|
||||||
Type: "server_error",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare the request for the backend client.
|
|
||||||
modelName, systemInstruction, contents, tools := openai.PrepareRequest(rawJSON)
|
|
||||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
|
||||||
var cliClient *client.Client
|
|
||||||
defer func() {
|
|
||||||
// Ensure the client's mutex is unlocked on function exit.
|
|
||||||
if cliClient != nil {
|
|
||||||
cliClient.RequestMutex.Unlock()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
outLoop:
|
|
||||||
for {
|
|
||||||
var errorResponse *client.ErrorMessage
|
|
||||||
cliClient, errorResponse = h.GetClient(modelName)
|
|
||||||
if errorResponse != nil {
|
|
||||||
c.Status(errorResponse.StatusCode)
|
|
||||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
|
||||||
flusher.Flush()
|
|
||||||
cliCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
isGlAPIKey := false
|
|
||||||
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
|
||||||
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
|
||||||
isGlAPIKey = true
|
|
||||||
} else {
|
|
||||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
|
||||||
}
|
|
||||||
// Send the message and receive response chunks and errors via channels.
|
|
||||||
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools)
|
|
||||||
hasFirstResponse := false
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
// Handle client disconnection.
|
|
||||||
case <-c.Request.Context().Done():
|
|
||||||
if c.Request.Context().Err().Error() == "context canceled" {
|
|
||||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
|
||||||
cliCancel() // Cancel the backend request.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Process incoming response chunks.
|
|
||||||
case chunk, okStream := <-respChan:
|
|
||||||
if !okStream {
|
|
||||||
// Stream is closed, send the final [DONE] message.
|
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
|
||||||
flusher.Flush()
|
|
||||||
cliCancel()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Convert the chunk to OpenAI format and send it to the client.
|
|
||||||
hasFirstResponse = true
|
|
||||||
openAIFormat := openai.ConvertCliToOpenAI(chunk, time.Now().Unix(), isGlAPIKey)
|
|
||||||
if openAIFormat != "" {
|
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
// Handle errors from the backend.
|
|
||||||
case err, okError := <-errChan:
|
|
||||||
if okError {
|
|
||||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
|
||||||
continue outLoop
|
|
||||||
} else {
|
|
||||||
c.Status(err.StatusCode)
|
|
||||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
|
||||||
flusher.Flush()
|
|
||||||
cliCancel()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Send a keep-alive signal to the client.
|
|
||||||
case <-time.After(500 * time.Millisecond):
|
|
||||||
if hasFirstResponse {
|
|
||||||
_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
|
|
||||||
flusher.Flush()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
506
internal/api/handlers/openai/openai_handlers.go
Normal file
506
internal/api/handlers/openai/openai_handlers.go
Normal file
@@ -0,0 +1,506 @@
|
|||||||
|
// Package openai provides HTTP handlers for OpenAI API endpoints.
|
||||||
|
// This package implements the OpenAI-compatible API interface, including model listing
|
||||||
|
// and chat completion functionality. It supports both streaming and non-streaming responses,
|
||||||
|
// and manages a pool of clients to interact with backend services.
|
||||||
|
// The handlers translate OpenAI API requests to the appropriate backend format and
|
||||||
|
// convert responses back to OpenAI-compatible format.
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
translatorOpenAIToCodex "github.com/luispater/CLIProxyAPI/internal/translator/codex/openai"
|
||||||
|
translatorOpenAIToGeminiCli "github.com/luispater/CLIProxyAPI/internal/translator/gemini-cli/openai"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenAIAPIHandlers contains the handlers for OpenAI API endpoints.
|
||||||
|
// It holds a pool of clients to interact with the backend service.
|
||||||
|
type OpenAIAPIHandlers struct {
|
||||||
|
*handlers.APIHandlers
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewOpenAIAPIHandlers creates a new OpenAI API handlers instance.
|
||||||
|
// It takes an APIHandlers instance as input and returns an OpenAIAPIHandlers.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - apiHandlers: The base API handlers instance
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *OpenAIAPIHandlers: A new OpenAI API handlers instance
|
||||||
|
func NewOpenAIAPIHandlers(apiHandlers *handlers.APIHandlers) *OpenAIAPIHandlers {
|
||||||
|
return &OpenAIAPIHandlers{
|
||||||
|
APIHandlers: apiHandlers,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Models handles the /v1/models endpoint.
|
||||||
|
// It returns a hardcoded list of available AI models with their capabilities
|
||||||
|
// and specifications in OpenAI-compatible format.
|
||||||
|
func (h *OpenAIAPIHandlers) Models(c *gin.Context) {
|
||||||
|
c.JSON(http.StatusOK, gin.H{
|
||||||
|
"data": []map[string]any{
|
||||||
|
{
|
||||||
|
"id": "gemini-2.5-pro",
|
||||||
|
"object": "model",
|
||||||
|
"version": "2.5",
|
||||||
|
"name": "Gemini 2.5 Pro",
|
||||||
|
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
|
||||||
|
"context_length": 1_048_576,
|
||||||
|
"max_completion_tokens": 65_536,
|
||||||
|
"supported_parameters": []string{
|
||||||
|
"tools",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"top_k",
|
||||||
|
},
|
||||||
|
"temperature": 1,
|
||||||
|
"topP": 0.95,
|
||||||
|
"topK": 64,
|
||||||
|
"maxTemperature": 2,
|
||||||
|
"thinking": true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gemini-2.5-flash",
|
||||||
|
"object": "model",
|
||||||
|
"version": "001",
|
||||||
|
"name": "Gemini 2.5 Flash",
|
||||||
|
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
|
||||||
|
"context_length": 1_048_576,
|
||||||
|
"max_completion_tokens": 65_536,
|
||||||
|
"supported_parameters": []string{
|
||||||
|
"tools",
|
||||||
|
"temperature",
|
||||||
|
"top_p",
|
||||||
|
"top_k",
|
||||||
|
},
|
||||||
|
"temperature": 1,
|
||||||
|
"topP": 0.95,
|
||||||
|
"topK": 64,
|
||||||
|
"maxTemperature": 2,
|
||||||
|
"thinking": true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gpt-5",
|
||||||
|
"object": "model",
|
||||||
|
"version": "gpt-5-2025-08-07",
|
||||||
|
"name": "GPT 5",
|
||||||
|
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
|
||||||
|
"context_length": 400_000,
|
||||||
|
"max_completion_tokens": 128_000,
|
||||||
|
"supported_parameters": []string{
|
||||||
|
"tools",
|
||||||
|
},
|
||||||
|
"temperature": 1,
|
||||||
|
"topP": 0.95,
|
||||||
|
"topK": 64,
|
||||||
|
"maxTemperature": 2,
|
||||||
|
"thinking": true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// ChatCompletions handles the /v1/chat/completions endpoint.
|
||||||
|
// It determines whether the request is for a streaming or non-streaming response
|
||||||
|
// and calls the appropriate handler based on the model provider.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - c: The Gin context containing the HTTP request and response
|
||||||
|
func (h *OpenAIAPIHandlers) ChatCompletions(c *gin.Context) {
|
||||||
|
rawJSON, err := c.GetRawData()
|
||||||
|
// If data retrieval fails, return a 400 Bad Request error.
|
||||||
|
if err != nil {
|
||||||
|
c.JSON(http.StatusBadRequest, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the client requested a streaming response.
|
||||||
|
streamResult := gjson.GetBytes(rawJSON, "stream")
|
||||||
|
modelName := gjson.GetBytes(rawJSON, "model")
|
||||||
|
provider := util.GetProviderName(modelName.String())
|
||||||
|
if provider == "gemini" {
|
||||||
|
if streamResult.Type == gjson.True {
|
||||||
|
h.handleGeminiStreamingResponse(c, rawJSON)
|
||||||
|
} else {
|
||||||
|
h.handleGeminiNonStreamingResponse(c, rawJSON)
|
||||||
|
}
|
||||||
|
} else if provider == "gpt" {
|
||||||
|
if streamResult.Type == gjson.True {
|
||||||
|
h.handleCodexStreamingResponse(c, rawJSON)
|
||||||
|
} else {
|
||||||
|
h.handleCodexNonStreamingResponse(c, rawJSON)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleGeminiNonStreamingResponse handles non-streaming chat completion responses
|
||||||
|
// for Gemini models. It selects a client from the pool, sends the request, and
|
||||||
|
// aggregates the response before sending it back to the client in OpenAI format.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - c: The Gin context containing the HTTP request and response
|
||||||
|
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
|
||||||
|
func (h *OpenAIAPIHandlers) handleGeminiNonStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||||
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
|
modelName, systemInstruction, contents, tools := translatorOpenAIToGeminiCli.ConvertOpenAIChatRequestToCli(rawJSON)
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient client.Client
|
||||||
|
defer func() {
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.GetRequestMutex().Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.GetClient(modelName)
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
isGlAPIKey := false
|
||||||
|
if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" {
|
||||||
|
log.Debugf("Request use generative language API Key: %s", glAPIKey)
|
||||||
|
isGlAPIKey = true
|
||||||
|
} else {
|
||||||
|
log.Debugf("Request cli use account: %s, project id: %s", cliClient.(*client.GeminiClient).GetEmail(), cliClient.(*client.GeminiClient).GetProjectID())
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := cliClient.SendMessage(cliCtx, rawJSON, modelName, systemInstruction, contents, tools)
|
||||||
|
if err != nil {
|
||||||
|
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue
|
||||||
|
} else {
|
||||||
|
c.Status(err.StatusCode)
|
||||||
|
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
openAIFormat := translatorOpenAIToGeminiCli.ConvertCliResponseToOpenAIChatNonStream(resp, time.Now().Unix(), isGlAPIKey)
|
||||||
|
if openAIFormat != "" {
|
||||||
|
_, _ = c.Writer.Write([]byte(openAIFormat))
|
||||||
|
}
|
||||||
|
cliCancel()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleGeminiStreamingResponse handles streaming responses for Gemini models.
// It establishes a streaming connection with the backend service and forwards
// the response chunks to the client in real-time using Server-Sent Events.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
//   - rawJSON: The raw JSON bytes of the OpenAI-compatible request
func (h *OpenAIAPIHandlers) handleGeminiStreamingResponse(c *gin.Context, rawJSON []byte) {
	// Standard SSE headers.
	c.Header("Content-Type", "text/event-stream")
	c.Header("Cache-Control", "no-cache")
	c.Header("Connection", "keep-alive")
	c.Header("Access-Control-Allow-Origin", "*")

	// Get the http.Flusher interface to manually flush the response.
	flusher, ok := c.Writer.(http.Flusher)
	if !ok {
		c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
			Error: handlers.ErrorDetail{
				Message: "Streaming not supported",
				Type:    "server_error",
			},
		})
		return
	}

	// Prepare the request for the backend client.
	modelName, systemInstruction, contents, tools := translatorOpenAIToGeminiCli.ConvertOpenAIChatRequestToCli(rawJSON)
	cliCtx, cliCancel := context.WithCancel(context.Background())
	var cliClient client.Client
	defer func() {
		// Ensure the client's mutex is unlocked on function exit.
		if cliClient != nil {
			cliClient.GetRequestMutex().Unlock()
		}
	}()

	// outLoop restarts with a fresh client on quota-exceeded (429) errors.
outLoop:
	for {
		var errorResponse *client.ErrorMessage
		cliClient, errorResponse = h.GetClient(modelName)
		if errorResponse != nil {
			c.Status(errorResponse.StatusCode)
			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
			flusher.Flush()
			cliCancel()
			return
		}

		// A Generative Language API key changes how chunks are translated.
		isGlAPIKey := false
		if glAPIKey := cliClient.(*client.GeminiClient).GetGenerativeLanguageAPIKey(); glAPIKey != "" {
			log.Debugf("Request use generative language API Key: %s", glAPIKey)
			isGlAPIKey = true
		} else {
			log.Debugf("Request cli use account: %s, project id: %s", cliClient.GetEmail(), cliClient.(*client.GeminiClient).GetProjectID())
		}
		// Send the message and receive response chunks and errors via channels.
		respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJSON, modelName, systemInstruction, contents, tools)
		hasFirstResponse := false
		for {
			select {
			// Handle client disconnection.
			case <-c.Request.Context().Done():
				// NOTE(review): only "context canceled" exits; any other Done
				// error (e.g. deadline exceeded) re-enters the select with a
				// closed Done channel and may busy-loop — TODO confirm.
				if c.Request.Context().Err().Error() == "context canceled" {
					log.Debugf("GeminiClient disconnected: %v", c.Request.Context().Err())
					cliCancel() // Cancel the backend request.
					return
				}
			// Process incoming response chunks.
			case chunk, okStream := <-respChan:
				if !okStream {
					// Stream is closed, send the final [DONE] message.
					_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
					flusher.Flush()
					cliCancel()
					return
				}
				// Convert the chunk to OpenAI format and send it to the client.
				hasFirstResponse = true
				openAIFormat := translatorOpenAIToGeminiCli.ConvertCliResponseToOpenAIChat(chunk, time.Now().Unix(), isGlAPIKey)
				if openAIFormat != "" {
					_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
					flusher.Flush()
				}
			// Handle errors from the backend.
			case err, okError := <-errChan:
				// NOTE(review): if errChan is closed (okError == false) the
				// loop keeps selecting on it — presumably respChan closes at
				// the same time and triggers the [DONE] path first; verify.
				if okError {
					if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
						// Quota exceeded: restart with another project/client.
						continue outLoop
					} else {
						c.Status(err.StatusCode)
						_, _ = fmt.Fprint(c.Writer, err.Error.Error())
						flusher.Flush()
						cliCancel()
					}
					return
				}
			// Send a keep-alive signal to the client.
			case <-time.After(500 * time.Millisecond):
				// Only emit the SSE comment once real data has started flowing.
				if hasFirstResponse {
					_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
					flusher.Flush()
				}
			}
		}
	}
}
|
||||||
|
|
||||||
|
// handleCodexNonStreamingResponse handles non-streaming chat completion responses
// for OpenAI models. It selects a client from the pool, sends the request, and
// aggregates the response before sending it back to the client in OpenAI format.
//
// The backend is always consumed as a stream; only the terminal
// "response.completed" event is translated and written to the client.
//
// Parameters:
//   - c: The Gin context containing the HTTP request and response
//   - rawJSON: The raw JSON bytes of the OpenAI-compatible request
func (h *OpenAIAPIHandlers) handleCodexNonStreamingResponse(c *gin.Context, rawJSON []byte) {
	c.Header("Content-Type", "application/json")

	// Translate the OpenAI chat request into the Codex request shape.
	newRequestJSON := translatorOpenAIToCodex.ConvertOpenAIChatRequestToCodex(rawJSON)
	modelName := gjson.GetBytes(rawJSON, "model")
	cliCtx, cliCancel := context.WithCancel(context.Background())
	var cliClient client.Client
	defer func() {
		// Release the request mutex of the selected client on exit.
		if cliClient != nil {
			cliClient.GetRequestMutex().Unlock()
		}
	}()

	// outLoop restarts with a fresh client on quota-exceeded (429) errors.
outLoop:
	for {
		var errorResponse *client.ErrorMessage
		cliClient, errorResponse = h.GetClient(modelName.String())
		if errorResponse != nil {
			c.Status(errorResponse.StatusCode)
			_, _ = c.Writer.Write([]byte(errorResponse.Error.Error()))
			cliCancel()
			return
		}

		log.Debugf("Request codex use account: %s", cliClient.GetEmail())

		// Send the message and receive response chunks and errors via channels.
		respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "")
		for {
			select {
			// Handle client disconnection.
			case <-c.Request.Context().Done():
				if c.Request.Context().Err().Error() == "context canceled" {
					log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err())
					cliCancel() // Cancel the backend request.
					return
				}
			// Process incoming response chunks.
			case chunk, okStream := <-respChan:
				if !okStream {
					// NOTE(review): if the stream closes without a
					// "response.completed" event, no body is written — the
					// client receives an empty 200. Confirm this is intended.
					cliCancel()
					return
				}
				// Only SSE "data: " lines carry event payloads.
				if bytes.HasPrefix(chunk, []byte("data: ")) {
					jsonData := chunk[6:]
					data := gjson.ParseBytes(jsonData)
					typeResult := data.Get("type")
					if typeResult.String() == "response.completed" {
						// Terminal event: translate the full response payload.
						responseResult := data.Get("response")
						openaiStr := translatorOpenAIToCodex.ConvertCodexResponseToOpenAIChatNonStream(responseResult.Raw, time.Now().Unix())
						_, _ = c.Writer.Write([]byte(openaiStr))
					}
				}
			// Handle errors from the backend.
			case err, okError := <-errChan:
				if okError {
					if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
						// Quota exceeded: restart with another account.
						continue outLoop
					} else {
						c.Status(err.StatusCode)
						_, _ = c.Writer.Write([]byte(err.Error.Error()))
						cliCancel()
					}
					return
				}
			// Periodic wake-up; nothing to emit for a non-streaming response.
			case <-time.After(500 * time.Millisecond):
			}
		}
	}
}
|
||||||
|
|
||||||
|
// handleCodexStreamingResponse handles streaming responses for OpenAI models.
|
||||||
|
// It establishes a streaming connection with the backend service and forwards
|
||||||
|
// the response chunks to the client in real-time using Server-Sent Events.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - c: The Gin context containing the HTTP request and response
|
||||||
|
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
|
||||||
|
func (h *OpenAIAPIHandlers) handleCodexStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||||
|
c.Header("Content-Type", "text/event-stream")
|
||||||
|
c.Header("Cache-Control", "no-cache")
|
||||||
|
c.Header("Connection", "keep-alive")
|
||||||
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
|
||||||
|
// Get the http.Flusher interface to manually flush the response.
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||||
|
Error: handlers.ErrorDetail{
|
||||||
|
Message: "Streaming not supported",
|
||||||
|
Type: "server_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Prepare the request for the backend client.
|
||||||
|
newRequestJSON := translatorOpenAIToCodex.ConvertOpenAIChatRequestToCodex(rawJSON)
|
||||||
|
// log.Debugf("Request: %s", newRequestJSON)
|
||||||
|
|
||||||
|
modelName := gjson.GetBytes(rawJSON, "model")
|
||||||
|
|
||||||
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
|
var cliClient client.Client
|
||||||
|
defer func() {
|
||||||
|
// Ensure the client's mutex is unlocked on function exit.
|
||||||
|
if cliClient != nil {
|
||||||
|
cliClient.GetRequestMutex().Unlock()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
outLoop:
|
||||||
|
for {
|
||||||
|
var errorResponse *client.ErrorMessage
|
||||||
|
cliClient, errorResponse = h.GetClient(modelName.String())
|
||||||
|
if errorResponse != nil {
|
||||||
|
c.Status(errorResponse.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Request codex use account: %s", cliClient.GetEmail())
|
||||||
|
|
||||||
|
// Send the message and receive response chunks and errors via channels.
|
||||||
|
var params *translatorOpenAIToCodex.ConvertCliToOpenAIParams
|
||||||
|
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "")
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
// Handle client disconnection.
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
if c.Request.Context().Err().Error() == "context canceled" {
|
||||||
|
log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err())
|
||||||
|
cliCancel() // Cancel the backend request.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Process incoming response chunks.
|
||||||
|
case chunk, okStream := <-respChan:
|
||||||
|
if !okStream {
|
||||||
|
_, _ = c.Writer.Write([]byte("[done]\n\n"))
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// log.Debugf("Response: %s\n", string(chunk))
|
||||||
|
// Convert the chunk to OpenAI format and send it to the client.
|
||||||
|
if bytes.HasPrefix(chunk, []byte("data: ")) {
|
||||||
|
jsonData := chunk[6:]
|
||||||
|
data := gjson.ParseBytes(jsonData)
|
||||||
|
typeResult := data.Get("type")
|
||||||
|
if typeResult.String() != "" {
|
||||||
|
var openaiStr string
|
||||||
|
params, openaiStr = translatorOpenAIToCodex.ConvertCodexResponseToOpenAIChat(jsonData, params)
|
||||||
|
if openaiStr != "" {
|
||||||
|
_, _ = c.Writer.Write([]byte("data: "))
|
||||||
|
_, _ = c.Writer.Write([]byte(openaiStr))
|
||||||
|
_, _ = c.Writer.Write([]byte("\n\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// log.Debugf(string(jsonData))
|
||||||
|
}
|
||||||
|
flusher.Flush()
|
||||||
|
// Handle errors from the backend.
|
||||||
|
case err, okError := <-errChan:
|
||||||
|
if okError {
|
||||||
|
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||||
|
continue outLoop
|
||||||
|
} else {
|
||||||
|
c.Status(err.StatusCode)
|
||||||
|
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||||
|
flusher.Flush()
|
||||||
|
cliCancel()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Send a keep-alive signal to the client.
|
||||||
|
case <-time.After(500 * time.Millisecond):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
88
internal/api/middleware/request_logging.go
Normal file
88
internal/api/middleware/request_logging.go
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
// Package middleware provides HTTP middleware components for the CLI Proxy API server.
|
||||||
|
// This file contains the request logging middleware that captures comprehensive
|
||||||
|
// request and response data when enabled through configuration.
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestLoggingMiddleware creates a Gin middleware function that logs HTTP requests and responses
|
||||||
|
// when enabled through the provided logger. The middleware has zero overhead when logging is disabled.
|
||||||
|
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// Early return if logging is disabled (zero overhead)
|
||||||
|
if !logger.IsEnabled() {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture request information
|
||||||
|
requestInfo, err := captureRequestInfo(c)
|
||||||
|
if err != nil {
|
||||||
|
// Log error but continue processing
|
||||||
|
// In a real implementation, you might want to use a proper logger here
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create response writer wrapper
|
||||||
|
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
||||||
|
c.Writer = wrapper
|
||||||
|
|
||||||
|
// Process the request
|
||||||
|
c.Next()
|
||||||
|
|
||||||
|
// Finalize logging after request processing
|
||||||
|
if err := wrapper.Finalize(); err != nil {
|
||||||
|
// Log error but don't interrupt the response
|
||||||
|
// In a real implementation, you might want to use a proper logger here
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// captureRequestInfo extracts and captures request information for logging.
|
||||||
|
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||||
|
// Capture URL
|
||||||
|
url := c.Request.URL.String()
|
||||||
|
if c.Request.URL.Path != "" {
|
||||||
|
url = c.Request.URL.Path
|
||||||
|
if c.Request.URL.RawQuery != "" {
|
||||||
|
url += "?" + c.Request.URL.RawQuery
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture method
|
||||||
|
method := c.Request.Method
|
||||||
|
|
||||||
|
// Capture headers
|
||||||
|
headers := make(map[string][]string)
|
||||||
|
for key, values := range c.Request.Header {
|
||||||
|
headers[key] = values
|
||||||
|
}
|
||||||
|
|
||||||
|
// Capture request body
|
||||||
|
var body []byte
|
||||||
|
if c.Request.Body != nil {
|
||||||
|
// Read the body
|
||||||
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restore the body for the actual request processing
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
|
||||||
|
body = bodyBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RequestInfo{
|
||||||
|
URL: url,
|
||||||
|
Method: method,
|
||||||
|
Headers: headers,
|
||||||
|
Body: body,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
208
internal/api/middleware/response_writer.go
Normal file
208
internal/api/middleware/response_writer.go
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
// Package middleware provides HTTP middleware components for the CLI Proxy API server.
|
||||||
|
// This includes request logging middleware and response writer wrappers that capture
|
||||||
|
// request and response data for logging purposes while maintaining zero-latency performance.
|
||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/logging"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestInfo holds information about the current request for logging purposes.
type RequestInfo struct {
	// URL is the request path plus raw query (or the full URL string when the
	// parsed path is empty).
	URL string
	// Method is the HTTP method, e.g. "GET" or "POST".
	Method string
	// Headers is a snapshot of the incoming request headers.
	Headers map[string][]string
	// Body is a buffered copy of the request body; nil when there was no body.
	Body []byte
}
|
||||||
|
|
||||||
|
// ResponseWriterWrapper wraps gin.ResponseWriter to capture response data for logging.
// It maintains zero-latency performance by prioritizing client response over logging operations.
type ResponseWriterWrapper struct {
	gin.ResponseWriter
	// body buffers the complete response for non-streaming requests.
	body *bytes.Buffer
	// isStreaming is set in WriteHeader based on Content-Type / request body.
	isStreaming bool
	// streamWriter receives chunks asynchronously for streaming responses.
	streamWriter logging.StreamingLogWriter
	// chunkChannel feeds response chunks to the async logging goroutine.
	chunkChannel chan []byte
	// logger is the request logger; all capture is skipped when it is disabled.
	logger logging.RequestLogger
	// requestInfo is the snapshot of the request being logged.
	requestInfo *RequestInfo
	// statusCode is the status passed to WriteHeader (0 until written).
	statusCode int
	// headers holds response headers captured at WriteHeader time.
	headers map[string][]string
}
|
||||||
|
|
||||||
|
// NewResponseWriterWrapper creates a new response writer wrapper.
|
||||||
|
func NewResponseWriterWrapper(w gin.ResponseWriter, logger logging.RequestLogger, requestInfo *RequestInfo) *ResponseWriterWrapper {
|
||||||
|
return &ResponseWriterWrapper{
|
||||||
|
ResponseWriter: w,
|
||||||
|
body: &bytes.Buffer{},
|
||||||
|
logger: logger,
|
||||||
|
requestInfo: requestInfo,
|
||||||
|
headers: make(map[string][]string),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write intercepts response data while maintaining normal Gin functionality.
|
||||||
|
// CRITICAL: This method prioritizes client response (zero-latency) over logging operations.
|
||||||
|
func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
|
||||||
|
// CRITICAL: Write to client first (zero latency)
|
||||||
|
n, err := w.ResponseWriter.Write(data)
|
||||||
|
|
||||||
|
// THEN: Handle logging based on response type
|
||||||
|
if w.isStreaming {
|
||||||
|
// For streaming responses: Send to async logging channel (non-blocking)
|
||||||
|
if w.chunkChannel != nil {
|
||||||
|
select {
|
||||||
|
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
||||||
|
default: // Channel full, skip logging to avoid blocking
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// For non-streaming responses: Buffer complete response
|
||||||
|
w.body.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteHeader captures the status code and detects streaming responses.
// For streaming responses it also spins up the async chunk-logging pipeline
// before delegating to the underlying writer.
func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
	w.statusCode = statusCode

	// Capture response headers as they stand at header-write time.
	for key, values := range w.ResponseWriter.Header() {
		w.headers[key] = values
	}

	// Detect streaming based on Content-Type (and the request body flag).
	contentType := w.ResponseWriter.Header().Get("Content-Type")
	w.isStreaming = w.detectStreaming(contentType)

	// If streaming, initialize streaming log writer
	if w.isStreaming && w.logger.IsEnabled() {
		streamWriter, err := w.logger.LogStreamingRequest(
			w.requestInfo.URL,
			w.requestInfo.Method,
			w.requestInfo.Headers,
			w.requestInfo.Body,
		)
		// A failed logger setup silently disables chunk logging; the
		// response itself is unaffected.
		if err == nil {
			w.streamWriter = streamWriter
			w.chunkChannel = make(chan []byte, 100) // Buffered channel for async writes

			// Start async chunk processor
			go w.processStreamingChunks()

			// Write status immediately
			_ = streamWriter.WriteStatus(statusCode, w.headers)
		}
	}

	// Call original WriteHeader
	w.ResponseWriter.WriteHeader(statusCode)
}
|
||||||
|
|
||||||
|
// detectStreaming determines if the response is streaming based on Content-Type and request analysis.
|
||||||
|
func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
||||||
|
// Check Content-Type for Server-Sent Events
|
||||||
|
if strings.Contains(contentType, "text/event-stream") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check request body for streaming indicators
|
||||||
|
if w.requestInfo.Body != nil {
|
||||||
|
bodyStr := string(w.requestInfo.Body)
|
||||||
|
if strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// processStreamingChunks handles async processing of streaming chunks.
|
||||||
|
func (w *ResponseWriterWrapper) processStreamingChunks() {
|
||||||
|
if w.streamWriter == nil || w.chunkChannel == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for chunk := range w.chunkChannel {
|
||||||
|
w.streamWriter.WriteChunkAsync(chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finalize completes the logging process for the response. For streaming
// responses it closes the chunk channel and the streaming writer; for
// non-streaming responses it emits one complete log entry.
//
// Returns any error from the underlying log writer; nil when logging is
// disabled.
func (w *ResponseWriterWrapper) Finalize() error {
	if !w.logger.IsEnabled() {
		return nil
	}

	if w.isStreaming {
		// Close streaming channel and writer.
		// NOTE(review): close() lets the drain goroutine finish its range
		// loop, but Close() below may run before all buffered chunks have
		// been passed to WriteChunkAsync — confirm the writer tolerates this.
		if w.chunkChannel != nil {
			close(w.chunkChannel)
			w.chunkChannel = nil
		}

		if w.streamWriter != nil {
			return w.streamWriter.Close()
		}
	} else {
		// Capture final status code and headers if not already captured
		finalStatusCode := w.statusCode
		if finalStatusCode == 0 {
			// Get status from underlying ResponseWriter if available
			if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
				finalStatusCode = statusWriter.Status()
			} else {
				finalStatusCode = 200 // Default
			}
		}

		// Capture final headers
		finalHeaders := make(map[string][]string)
		for key, values := range w.ResponseWriter.Header() {
			finalHeaders[key] = values
		}
		// Merge with any headers we captured earlier; the WriteHeader-time
		// snapshot wins on conflicts.
		for key, values := range w.headers {
			finalHeaders[key] = values
		}

		// Log complete non-streaming response
		return w.logger.LogRequest(
			w.requestInfo.URL,
			w.requestInfo.Method,
			w.requestInfo.Headers,
			w.requestInfo.Body,
			finalStatusCode,
			finalHeaders,
			w.body.Bytes(),
		)
	}

	return nil
}
|
||||||
|
|
||||||
|
// Status returns the HTTP status code of the response.
|
||||||
|
func (w *ResponseWriterWrapper) Status() int {
|
||||||
|
if w.statusCode == 0 {
|
||||||
|
return 200 // Default status code
|
||||||
|
}
|
||||||
|
return w.statusCode
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns the size of the response body.
|
||||||
|
func (w *ResponseWriterWrapper) Size() int {
|
||||||
|
if w.isStreaming {
|
||||||
|
return -1 // Unknown size for streaming responses
|
||||||
|
}
|
||||||
|
return w.body.Len()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Written returns whether the response has been written.
|
||||||
|
func (w *ResponseWriterWrapper) Written() bool {
|
||||||
|
return w.statusCode != 0
|
||||||
|
}
|
||||||
@@ -8,31 +8,48 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
"github.com/luispater/CLIProxyAPI/internal/api/handlers"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers/claude"
|
"github.com/luispater/CLIProxyAPI/internal/api/handlers/claude"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini"
|
"github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini/cli"
|
"github.com/luispater/CLIProxyAPI/internal/api/handlers/gemini/cli"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/handlers/openai"
|
"github.com/luispater/CLIProxyAPI/internal/api/handlers/openai"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/api/middleware"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/logging"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server represents the main API server.
|
// Server represents the main API server.
|
||||||
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
engine *gin.Engine
|
// engine is the Gin web framework engine instance.
|
||||||
server *http.Server
|
engine *gin.Engine
|
||||||
|
|
||||||
|
// server is the underlying HTTP server.
|
||||||
|
server *http.Server
|
||||||
|
|
||||||
|
// handlers contains the API handlers for processing requests.
|
||||||
handlers *handlers.APIHandlers
|
handlers *handlers.APIHandlers
|
||||||
cfg *config.Config
|
|
||||||
|
// cfg holds the current server configuration.
|
||||||
|
cfg *config.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates and initializes a new API server instance.
|
// NewServer creates and initializes a new API server instance.
|
||||||
// It sets up the Gin engine, middleware, routes, and handlers.
|
// It sets up the Gin engine, middleware, routes, and handlers.
|
||||||
func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The server configuration
|
||||||
|
// - cliClients: A slice of AI service clients
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *Server: A new server instance
|
||||||
|
func NewServer(cfg *config.Config, cliClients []client.Client) *Server {
|
||||||
// Set gin mode
|
// Set gin mode
|
||||||
if !cfg.Debug {
|
if !cfg.Debug {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
@@ -44,6 +61,11 @@ func NewServer(cfg *config.Config, cliClients []*client.Client) *Server {
|
|||||||
// Add middleware
|
// Add middleware
|
||||||
engine.Use(gin.Logger())
|
engine.Use(gin.Logger())
|
||||||
engine.Use(gin.Recovery())
|
engine.Use(gin.Recovery())
|
||||||
|
|
||||||
|
// Add request logging middleware (positioned after recovery, before auth)
|
||||||
|
requestLogger := logging.NewFileRequestLogger(cfg.RequestLog, "logs")
|
||||||
|
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
|
||||||
|
|
||||||
engine.Use(corsMiddleware())
|
engine.Use(corsMiddleware())
|
||||||
|
|
||||||
// Create server instance
|
// Create server instance
|
||||||
@@ -103,11 +125,13 @@ func (s *Server) setupRoutes() {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
|
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start begins listening for and serving HTTP requests.
|
// Start begins listening for and serving HTTP requests.
|
||||||
// It's a blocking call and will only return on an unrecoverable error.
|
// It's a blocking call and will only return on an unrecoverable error.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: An error if the server fails to start
|
||||||
func (s *Server) Start() error {
|
func (s *Server) Start() error {
|
||||||
log.Debugf("Starting API server on %s", s.server.Addr)
|
log.Debugf("Starting API server on %s", s.server.Addr)
|
||||||
|
|
||||||
@@ -121,6 +145,12 @@ func (s *Server) Start() error {
|
|||||||
|
|
||||||
// Stop gracefully shuts down the API server without interrupting any
|
// Stop gracefully shuts down the API server without interrupting any
|
||||||
// active connections.
|
// active connections.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - ctx: The context for graceful shutdown
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: An error if the server fails to stop
|
||||||
func (s *Server) Stop(ctx context.Context) error {
|
func (s *Server) Stop(ctx context.Context) error {
|
||||||
log.Debug("Stopping API server...")
|
log.Debug("Stopping API server...")
|
||||||
|
|
||||||
@@ -135,6 +165,9 @@ func (s *Server) Stop(ctx context.Context) error {
|
|||||||
|
|
||||||
// corsMiddleware returns a Gin middleware handler that adds CORS headers
|
// corsMiddleware returns a Gin middleware handler that adds CORS headers
|
||||||
// to every response, allowing cross-origin requests.
|
// to every response, allowing cross-origin requests.
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - gin.HandlerFunc: The CORS middleware handler
|
||||||
func corsMiddleware() gin.HandlerFunc {
|
func corsMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
@@ -150,8 +183,13 @@ func corsMiddleware() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateClients updates the server's client list and configuration
|
// UpdateClients updates the server's client list and configuration.
|
||||||
func (s *Server) UpdateClients(clients []*client.Client, cfg *config.Config) {
|
// This method is called when the configuration or authentication tokens change.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - clients: The new slice of AI service clients
|
||||||
|
// - cfg: The new application configuration
|
||||||
|
func (s *Server) UpdateClients(clients []client.Client, cfg *config.Config) {
|
||||||
s.cfg = cfg
|
s.cfg = cfg
|
||||||
s.handlers.UpdateClients(clients, cfg)
|
s.handlers.UpdateClients(clients, cfg)
|
||||||
log.Infof("server clients and configuration updated: %d clients", len(clients))
|
log.Infof("server clients and configuration updated: %d clients", len(clients))
|
||||||
@@ -159,6 +197,12 @@ func (s *Server) UpdateClients(clients []*client.Client, cfg *config.Config) {
|
|||||||
|
|
||||||
// AuthMiddleware returns a Gin middleware handler that authenticates requests
|
// AuthMiddleware returns a Gin middleware handler that authenticates requests
|
||||||
// using API keys. If no API keys are configured, it allows all requests.
|
// using API keys. If no API keys are configured, it allows all requests.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The server configuration containing API keys
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - gin.HandlerFunc: The authentication middleware handler
|
||||||
func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
|
func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
if len(cfg.APIKeys) == 0 {
|
if len(cfg.APIKeys) == 0 {
|
||||||
|
|||||||
155
internal/auth/codex/errors.go
Normal file
155
internal/auth/codex/errors.go
Normal file
@@ -0,0 +1,155 @@
|
|||||||
|
package codex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OAuthError describes an error returned by an OAuth authorization or token
// endpoint, mirroring the standard OAuth 2.0 error-response fields.
type OAuthError struct {
	Code        string `json:"error"`
	Description string `json:"error_description,omitempty"`
	URI         string `json:"error_uri,omitempty"`
	StatusCode  int    `json:"-"`
}

// Error renders the error as a human-readable string, including the
// description when the server supplied one.
func (e *OAuthError) Error() string {
	if e.Description == "" {
		return fmt.Sprintf("OAuth error: %s", e.Code)
	}
	return fmt.Sprintf("OAuth error %s: %s", e.Code, e.Description)
}

// NewOAuthError builds an *OAuthError from an OAuth error code, an optional
// description, and the HTTP status code of the failed response.
func NewOAuthError(code, description string, statusCode int) *OAuthError {
	oauthErr := &OAuthError{Code: code, StatusCode: statusCode}
	oauthErr.Description = description
	return oauthErr
}
|
||||||
|
|
||||||
|
// AuthenticationError represents authentication-related errors
|
||||||
|
type AuthenticationError struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
Code int `json:"code"`
|
||||||
|
Cause error `json:"-"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AuthenticationError) Error() string {
|
||||||
|
if e.Cause != nil {
|
||||||
|
return fmt.Sprintf("%s: %s (caused by: %v)", e.Type, e.Message, e.Cause)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s: %s", e.Type, e.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Common authentication error types. These sentinels carry only the error
// category, message, and status code; wrap one with NewAuthenticationError
// to attach a concrete underlying cause. Do not mutate them.
var (
	// ErrTokenExpired: the stored access token is past its expiry.
	ErrTokenExpired = &AuthenticationError{
		Type:    "token_expired",
		Message: "Access token has expired",
		Code:    http.StatusUnauthorized,
	}

	// ErrInvalidState: the OAuth callback returned a state value that does
	// not match the one sent with the authorization request.
	ErrInvalidState = &AuthenticationError{
		Type:    "invalid_state",
		Message: "OAuth state parameter is invalid",
		Code:    http.StatusBadRequest,
	}

	// ErrCodeExchangeFailed: the authorization-code → token exchange failed.
	ErrCodeExchangeFailed = &AuthenticationError{
		Type:    "code_exchange_failed",
		Message: "Failed to exchange authorization code for tokens",
		Code:    http.StatusBadRequest,
	}

	// ErrServerStartFailed: the local OAuth callback HTTP server could not start.
	ErrServerStartFailed = &AuthenticationError{
		Type:    "server_start_failed",
		Message: "Failed to start OAuth callback server",
		Code:    http.StatusInternalServerError,
	}

	// ErrPortInUse: the fixed local callback port is already bound.
	ErrPortInUse = &AuthenticationError{
		Type:    "port_in_use",
		Message: "OAuth callback port is already in use",
		Code:    13, // Special exit code for port-in-use
	}

	// ErrCallbackTimeout: the browser never hit the callback before the deadline.
	ErrCallbackTimeout = &AuthenticationError{
		Type:    "callback_timeout",
		Message: "Timeout waiting for OAuth callback",
		Code:    http.StatusRequestTimeout,
	}

	// ErrBrowserOpenFailed: the login URL could not be opened automatically.
	ErrBrowserOpenFailed = &AuthenticationError{
		Type:    "browser_open_failed",
		Message: "Failed to open browser for authentication",
		Code:    http.StatusInternalServerError,
	}
)
|
||||||
|
|
||||||
|
// NewAuthenticationError creates a new authentication error with a cause
|
||||||
|
func NewAuthenticationError(baseErr *AuthenticationError, cause error) *AuthenticationError {
|
||||||
|
return &AuthenticationError{
|
||||||
|
Type: baseErr.Type,
|
||||||
|
Message: baseErr.Message,
|
||||||
|
Code: baseErr.Code,
|
||||||
|
Cause: cause,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAuthenticationError checks if an error is an authentication error
|
||||||
|
func IsAuthenticationError(err error) bool {
|
||||||
|
var authenticationError *AuthenticationError
|
||||||
|
ok := errors.As(err, &authenticationError)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsOAuthError checks if an error is an OAuth error
|
||||||
|
func IsOAuthError(err error) bool {
|
||||||
|
var oAuthError *OAuthError
|
||||||
|
ok := errors.As(err, &oAuthError)
|
||||||
|
return ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserFriendlyMessage returns a user-friendly error message
|
||||||
|
func GetUserFriendlyMessage(err error) string {
|
||||||
|
switch {
|
||||||
|
case IsAuthenticationError(err):
|
||||||
|
var authErr *AuthenticationError
|
||||||
|
errors.As(err, &authErr)
|
||||||
|
switch authErr.Type {
|
||||||
|
case "token_expired":
|
||||||
|
return "Your authentication has expired. Please log in again."
|
||||||
|
case "token_invalid":
|
||||||
|
return "Your authentication is invalid. Please log in again."
|
||||||
|
case "authentication_required":
|
||||||
|
return "Please log in to continue."
|
||||||
|
case "port_in_use":
|
||||||
|
return "The required port is already in use. Please close any applications using port 3000 and try again."
|
||||||
|
case "callback_timeout":
|
||||||
|
return "Authentication timed out. Please try again."
|
||||||
|
case "browser_open_failed":
|
||||||
|
return "Could not open your browser automatically. Please copy and paste the URL manually."
|
||||||
|
default:
|
||||||
|
return "Authentication failed. Please try again."
|
||||||
|
}
|
||||||
|
case IsOAuthError(err):
|
||||||
|
var oauthErr *OAuthError
|
||||||
|
errors.As(err, &oauthErr)
|
||||||
|
switch oauthErr.Code {
|
||||||
|
case "access_denied":
|
||||||
|
return "Authentication was cancelled or denied."
|
||||||
|
case "invalid_request":
|
||||||
|
return "Invalid authentication request. Please try again."
|
||||||
|
case "server_error":
|
||||||
|
return "Authentication server error. Please try again later."
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("Authentication failed: %s", oauthErr.Description)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return "An unexpected error occurred. Please try again."
|
||||||
|
}
|
||||||
|
}
|
||||||
210
internal/auth/codex/html_templates.go
Normal file
210
internal/auth/codex/html_templates.go
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
package codex

// LoginSuccessHtml is the HTML page served by the local OAuth callback server
// after a successful login (see oauth_server.go). generateSuccessHTML
// substitutes two placeholders at render time:
//   - {{PLATFORM_URL}}: link target for the "Open Platform" button
//   - {{SETUP_NOTICE}}: optional extra-setup banner (SetupNoticeHtml or "")
// The page auto-closes after a 10-second countdown, on Escape, or via the
// Close Window button.
const LoginSuccessHtml = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Authentication Successful - Codex</title>
<link rel="icon" type="image/svg+xml" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24' fill='%2310b981'%3E%3Cpath d='M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z'/%3E%3C/svg%3E">
<style>
* {
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Oxygen, Ubuntu, Cantarell, sans-serif;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
margin: 0;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
padding: 1rem;
}
.container {
text-align: center;
background: white;
padding: 2.5rem;
border-radius: 12px;
box-shadow: 0 10px 25px rgba(0,0,0,0.1);
max-width: 480px;
width: 100%;
animation: slideIn 0.3s ease-out;
}
@keyframes slideIn {
from {
opacity: 0;
transform: translateY(-20px);
}
to {
opacity: 1;
transform: translateY(0);
}
}
.success-icon {
width: 64px;
height: 64px;
margin: 0 auto 1.5rem;
background: #10b981;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
color: white;
font-size: 2rem;
font-weight: bold;
}
h1 {
color: #1f2937;
margin-bottom: 1rem;
font-size: 1.75rem;
font-weight: 600;
}
.subtitle {
color: #6b7280;
margin-bottom: 1.5rem;
font-size: 1rem;
line-height: 1.5;
}
.setup-notice {
background: #fef3c7;
border: 1px solid #f59e0b;
border-radius: 6px;
padding: 1rem;
margin: 1rem 0;
}
.setup-notice h3 {
color: #92400e;
margin: 0 0 0.5rem 0;
font-size: 1rem;
}
.setup-notice p {
color: #92400e;
margin: 0;
font-size: 0.875rem;
}
.setup-notice a {
color: #1d4ed8;
text-decoration: none;
}
.setup-notice a:hover {
text-decoration: underline;
}
.actions {
display: flex;
gap: 1rem;
justify-content: center;
flex-wrap: wrap;
margin-top: 2rem;
}
.button {
padding: 0.75rem 1.5rem;
border-radius: 8px;
font-size: 0.875rem;
font-weight: 500;
text-decoration: none;
transition: all 0.2s;
cursor: pointer;
border: none;
display: inline-flex;
align-items: center;
gap: 0.5rem;
}
.button-primary {
background: #3b82f6;
color: white;
}
.button-primary:hover {
background: #2563eb;
transform: translateY(-1px);
}
.button-secondary {
background: #f3f4f6;
color: #374151;
border: 1px solid #d1d5db;
}
.button-secondary:hover {
background: #e5e7eb;
}
.countdown {
color: #9ca3af;
font-size: 0.75rem;
margin-top: 1rem;
}
.footer {
margin-top: 2rem;
padding-top: 1.5rem;
border-top: 1px solid #e5e7eb;
color: #9ca3af;
font-size: 0.75rem;
}
.footer a {
color: #3b82f6;
text-decoration: none;
}
.footer a:hover {
text-decoration: underline;
}
</style>
</head>
<body>
<div class="container">
<div class="success-icon">✓</div>
<h1>Authentication Successful!</h1>
<p class="subtitle">You have successfully authenticated with Codex. You can now close this window and return to your terminal to continue.</p>

{{SETUP_NOTICE}}

<div class="actions">
<button class="button button-primary" onclick="window.close()">
<span>Close Window</span>
</button>
<a href="{{PLATFORM_URL}}" target="_blank" class="button button-secondary">
<span>Open Platform</span>
<span>↗</span>
</a>
</div>

<div class="countdown">
This window will close automatically in <span id="countdown">10</span> seconds
</div>

<div class="footer">
<p>Powered by <a href="https://chatgpt.com" target="_blank">ChatGPT</a></p>
</div>
</div>

<script>
let countdown = 10;
const countdownElement = document.getElementById('countdown');

const timer = setInterval(() => {
countdown--;
countdownElement.textContent = countdown;

if (countdown <= 0) {
clearInterval(timer);
window.close();
}
}, 1000);

// Close window when user presses Escape
document.addEventListener('keydown', (e) => {
if (e.key === 'Escape') {
window.close();
}
});

// Focus the close button for keyboard accessibility
document.querySelector('.button-primary').focus();
</script>
</body>
</html>`

// SetupNoticeHtml is the optional banner injected into LoginSuccessHtml's
// {{SETUP_NOTICE}} placeholder when additional account configuration is
// required; its own {{PLATFORM_URL}} placeholder is substituted as well.
const SetupNoticeHtml = `
<div class="setup-notice">
<h3>Additional Setup Required</h3>
<p>To complete your setup, please visit the <a href="{{PLATFORM_URL}}" target="_blank">Codex</a> to configure your account.</p>
</div>`
|
||||||
89
internal/auth/codex/jwt_parser.go
Normal file
89
internal/auth/codex/jwt_parser.go
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
package codex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// JWTClaims represents the claims section (payload) of an OpenAI ID token.
// Field names map to the JSON claim keys via their struct tags; the token is
// decoded without signature verification (see ParseJWTToken).
type JWTClaims struct {
	AtHash        string   `json:"at_hash"`
	Aud           []string `json:"aud"`
	AuthProvider  string   `json:"auth_provider"`
	AuthTime      int      `json:"auth_time"`
	Email         string   `json:"email"`
	EmailVerified bool     `json:"email_verified"`
	Exp           int      `json:"exp"`
	// CodexAuthInfo carries OpenAI's custom claim keyed by the
	// "https://api.openai.com/auth" namespace URI.
	CodexAuthInfo CodexAuthInfo `json:"https://api.openai.com/auth"`
	Iat           int           `json:"iat"`
	Iss           string        `json:"iss"`
	Jti           string        `json:"jti"`
	Rat           int           `json:"rat"`
	Sid           string        `json:"sid"`
	Sub           string        `json:"sub"`
}

// Organizations describes one organization membership entry inside the
// custom auth claim.
type Organizations struct {
	ID        string `json:"id"`
	IsDefault bool   `json:"is_default"`
	Role      string `json:"role"`
	Title     string `json:"title"`
}

// CodexAuthInfo is the payload of the "https://api.openai.com/auth" claim,
// containing ChatGPT account and subscription metadata.
// NOTE(review): the Active Start/Until fields are declared `any`; presumably
// they can be null or a timestamp depending on plan — confirm against real
// tokens before tightening the types.
type CodexAuthInfo struct {
	ChatgptAccountID               string          `json:"chatgpt_account_id"`
	ChatgptPlanType                string          `json:"chatgpt_plan_type"`
	ChatgptSubscriptionActiveStart any             `json:"chatgpt_subscription_active_start"`
	ChatgptSubscriptionActiveUntil any             `json:"chatgpt_subscription_active_until"`
	ChatgptSubscriptionLastChecked time.Time       `json:"chatgpt_subscription_last_checked"`
	ChatgptUserID                  string          `json:"chatgpt_user_id"`
	Groups                         []any           `json:"groups"`
	Organizations                  []Organizations `json:"organizations"`
	UserID                         string          `json:"user_id"`
}
|
||||||
|
|
||||||
|
// ParseJWTToken parses a JWT token and extracts the claims without verification
|
||||||
|
// This is used for extracting user information from ID tokens
|
||||||
|
func ParseJWTToken(token string) (*JWTClaims, error) {
|
||||||
|
parts := strings.Split(token, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return nil, fmt.Errorf("invalid JWT token format: expected 3 parts, got %d", len(parts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode the claims (payload) part
|
||||||
|
claimsData, err := base64URLDecode(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decode JWT claims: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims JWTClaims
|
||||||
|
if err = json.Unmarshal(claimsData, &claims); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal JWT claims: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// base64URLDecode decodes a base64url string (RFC 4648 §5), accepting both
// padded and unpadded input — JWT segments are conventionally unpadded.
func base64URLDecode(data string) ([]byte, error) {
	trimmed := strings.TrimRight(data, "=")
	if len(trimmed)%4 == 1 {
		// A base64 quantum never leaves exactly one trailing character; the
		// old switch silently fell through on this case and surfaced a less
		// clear decoder error instead.
		return nil, fmt.Errorf("invalid base64url input length %d", len(trimmed))
	}
	// RawURLEncoding handles unpadded input directly, so no manual padding
	// arithmetic is required.
	return base64.RawURLEncoding.DecodeString(trimmed)
}
|
||||||
|
|
||||||
|
// GetUserEmail extracts the user email from JWT claims.
func (c *JWTClaims) GetUserEmail() string {
	return c.Email
}

// GetAccountID returns the ChatGPT account identifier carried in the
// "https://api.openai.com/auth" custom claim. Note this is NOT the JWT
// "sub" (subject) claim, despite what the original comment suggested.
func (c *JWTClaims) GetAccountID() string {
	return c.CodexAuthInfo.ChatgptAccountID
}
|
||||||
244
internal/auth/codex/oauth_server.go
Normal file
244
internal/auth/codex/oauth_server.go
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
package codex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OAuthServer handles the local HTTP server for OAuth callbacks.
type OAuthServer struct {
	server     *http.Server      // active HTTP server; nil when stopped
	port       int               // TCP port the callback listener binds to
	resultChan chan *OAuthResult // buffered (cap 1) hand-off of the callback outcome
	errorChan  chan error        // buffered (cap 1) hand-off of fatal server errors
	mu         sync.Mutex        // guards server and running
	running    bool              // whether the server is currently serving
}

// OAuthResult contains the result of the OAuth callback. On success Code and
// State are set; on failure Error holds the provider's error code (or one of
// the internal markers "no_code" / "no_state").
type OAuthResult struct {
	Code  string
	State string
	Error string
}
|
||||||
|
|
||||||
|
// NewOAuthServer creates a new OAuth callback server
|
||||||
|
func NewOAuthServer(port int) *OAuthServer {
|
||||||
|
return &OAuthServer{
|
||||||
|
port: port,
|
||||||
|
resultChan: make(chan *OAuthResult, 1),
|
||||||
|
errorChan: make(chan error, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the OAuth callback server
|
||||||
|
func (s *OAuthServer) Start(ctx context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if s.running {
|
||||||
|
return fmt.Errorf("server is already running")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if port is available
|
||||||
|
if !s.isPortAvailable() {
|
||||||
|
return fmt.Errorf("port %d is already in use", s.port)
|
||||||
|
}
|
||||||
|
|
||||||
|
mux := http.NewServeMux()
|
||||||
|
mux.HandleFunc("/auth/callback", s.handleCallback)
|
||||||
|
mux.HandleFunc("/success", s.handleSuccess)
|
||||||
|
|
||||||
|
s.server = &http.Server{
|
||||||
|
Addr: fmt.Sprintf(":%d", s.port),
|
||||||
|
Handler: mux,
|
||||||
|
ReadTimeout: 10 * time.Second,
|
||||||
|
WriteTimeout: 10 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
s.running = true
|
||||||
|
|
||||||
|
// Start server in goroutine
|
||||||
|
go func() {
|
||||||
|
if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
s.errorChan <- fmt.Errorf("server failed to start: %w", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Give server a moment to start
|
||||||
|
time.Sleep(100 * time.Millisecond)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully stops the OAuth callback server
|
||||||
|
func (s *OAuthServer) Stop(ctx context.Context) error {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
|
||||||
|
if !s.running || s.server == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Stopping OAuth callback server")
|
||||||
|
|
||||||
|
// Create a context with timeout for shutdown
|
||||||
|
shutdownCtx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
err := s.server.Shutdown(shutdownCtx)
|
||||||
|
s.running = false
|
||||||
|
s.server = nil
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitForCallback waits for the OAuth callback with a timeout
|
||||||
|
func (s *OAuthServer) WaitForCallback(timeout time.Duration) (*OAuthResult, error) {
|
||||||
|
select {
|
||||||
|
case result := <-s.resultChan:
|
||||||
|
return result, nil
|
||||||
|
case err := <-s.errorChan:
|
||||||
|
return nil, err
|
||||||
|
case <-time.After(timeout):
|
||||||
|
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleCallback handles the OAuth callback endpoint
|
||||||
|
func (s *OAuthServer) handleCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log.Debug("Received OAuth callback")
|
||||||
|
|
||||||
|
// Validate request method
|
||||||
|
if r.Method != http.MethodGet {
|
||||||
|
http.Error(w, "Method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract parameters
|
||||||
|
query := r.URL.Query()
|
||||||
|
code := query.Get("code")
|
||||||
|
state := query.Get("state")
|
||||||
|
errorParam := query.Get("error")
|
||||||
|
|
||||||
|
// Validate required parameters
|
||||||
|
if errorParam != "" {
|
||||||
|
log.Errorf("OAuth error received: %s", errorParam)
|
||||||
|
result := &OAuthResult{
|
||||||
|
Error: errorParam,
|
||||||
|
}
|
||||||
|
s.sendResult(result)
|
||||||
|
http.Error(w, fmt.Sprintf("OAuth error: %s", errorParam), http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if code == "" {
|
||||||
|
log.Error("No authorization code received")
|
||||||
|
result := &OAuthResult{
|
||||||
|
Error: "no_code",
|
||||||
|
}
|
||||||
|
s.sendResult(result)
|
||||||
|
http.Error(w, "No authorization code received", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if state == "" {
|
||||||
|
log.Error("No state parameter received")
|
||||||
|
result := &OAuthResult{
|
||||||
|
Error: "no_state",
|
||||||
|
}
|
||||||
|
s.sendResult(result)
|
||||||
|
http.Error(w, "No state parameter received", http.StatusBadRequest)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send successful result
|
||||||
|
result := &OAuthResult{
|
||||||
|
Code: code,
|
||||||
|
State: state,
|
||||||
|
}
|
||||||
|
s.sendResult(result)
|
||||||
|
|
||||||
|
// Redirect to success page
|
||||||
|
http.Redirect(w, r, "/success", http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSuccess handles the success page endpoint
|
||||||
|
func (s *OAuthServer) handleSuccess(w http.ResponseWriter, r *http.Request) {
|
||||||
|
log.Debug("Serving success page")
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
// Parse query parameters for customization
|
||||||
|
query := r.URL.Query()
|
||||||
|
setupRequired := query.Get("setup_required") == "true"
|
||||||
|
platformURL := query.Get("platform_url")
|
||||||
|
if platformURL == "" {
|
||||||
|
platformURL = "https://platform.openai.com"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate success page HTML with dynamic content
|
||||||
|
successHTML := s.generateSuccessHTML(setupRequired, platformURL)
|
||||||
|
|
||||||
|
_, err := w.Write([]byte(successHTML))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to write success page: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateSuccessHTML creates the HTML content for the success page
|
||||||
|
func (s *OAuthServer) generateSuccessHTML(setupRequired bool, platformURL string) string {
|
||||||
|
html := LoginSuccessHtml
|
||||||
|
|
||||||
|
// Replace platform URL placeholder
|
||||||
|
html = strings.Replace(html, "{{PLATFORM_URL}}", platformURL, -1)
|
||||||
|
|
||||||
|
// Add setup notice if required
|
||||||
|
if setupRequired {
|
||||||
|
setupNotice := strings.Replace(SetupNoticeHtml, "{{PLATFORM_URL}}", platformURL, -1)
|
||||||
|
html = strings.Replace(html, "{{SETUP_NOTICE}}", setupNotice, 1)
|
||||||
|
} else {
|
||||||
|
html = strings.Replace(html, "{{SETUP_NOTICE}}", "", 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
return html
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendResult sends the OAuth result to the waiting channel
|
||||||
|
func (s *OAuthServer) sendResult(result *OAuthResult) {
|
||||||
|
select {
|
||||||
|
case s.resultChan <- result:
|
||||||
|
log.Debug("OAuth result sent to channel")
|
||||||
|
default:
|
||||||
|
log.Warn("OAuth result channel is full, result dropped")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPortAvailable checks if the specified port is available
|
||||||
|
func (s *OAuthServer) isPortAvailable() bool {
|
||||||
|
addr := fmt.Sprintf(":%d", s.port)
|
||||||
|
listener, err := net.Listen("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = listener.Close()
|
||||||
|
}()
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsRunning returns whether the server is currently running
|
||||||
|
func (s *OAuthServer) IsRunning() bool {
|
||||||
|
s.mu.Lock()
|
||||||
|
defer s.mu.Unlock()
|
||||||
|
return s.running
|
||||||
|
}
|
||||||
36
internal/auth/codex/openai.go
Normal file
36
internal/auth/codex/openai.go
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
package codex

// PKCECodes holds PKCE verification codes for the OAuth2 PKCE flow
// (RFC 7636).
type PKCECodes struct {
	// CodeVerifier is the cryptographically random string used to correlate
	// the authorization request to the token request.
	CodeVerifier string `json:"code_verifier"`
	// CodeChallenge is the SHA256 hash of the code verifier, base64url-encoded.
	CodeChallenge string `json:"code_challenge"`
}

// CodexTokenData holds OAuth token information from OpenAI.
type CodexTokenData struct {
	// IDToken is the JWT ID token containing user claims.
	IDToken string `json:"id_token"`
	// AccessToken is the OAuth2 access token for API access.
	AccessToken string `json:"access_token"`
	// RefreshToken is used to obtain new access tokens.
	RefreshToken string `json:"refresh_token"`
	// AccountID is the ChatGPT account identifier extracted from the ID token.
	AccountID string `json:"account_id"`
	// Email is the OpenAI account email.
	Email string `json:"email"`
	// Expire is the RFC3339 timestamp at which AccessToken expires (set from
	// the token endpoint's expires_in).
	// NOTE(review): the JSON tag is "expired", not "expire" — presumably this
	// matches the persisted auth-file format; confirm before renaming.
	Expire string `json:"expired"`
}

// CodexAuthBundle aggregates authentication data after OAuth flow completion.
type CodexAuthBundle struct {
	// APIKey is the OpenAI API key obtained from token exchange.
	APIKey string `json:"api_key"`
	// TokenData contains the OAuth tokens from the authentication flow.
	TokenData CodexTokenData `json:"token_data"`
	// LastRefresh is the RFC3339 timestamp of the last token refresh.
	LastRefresh string `json:"last_refresh"`
}
|
||||||
269
internal/auth/codex/openai_auth.go
Normal file
269
internal/auth/codex/openai_auth.go
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
package codex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// openaiAuthURL is the OAuth 2.0 authorization endpoint.
	openaiAuthURL = "https://auth.openai.com/oauth/authorize"
	// openaiTokenURL is the OAuth 2.0 token endpoint, used for both the
	// authorization-code exchange and refresh-token grants.
	openaiTokenURL = "https://auth.openai.com/oauth/token"
	// openaiClientID is the public OAuth client ID used by the Codex CLI flow.
	openaiClientID = "app_EMoamEEZ73f0CkXaXp7hrann"
	// redirectURI must match the /auth/callback route served by the local
	// OAuth server; presumably port 1455 is what is registered for this
	// client — confirm before changing.
	redirectURI = "http://localhost:1455/auth/callback"
)
|
||||||
|
|
||||||
|
// CodexAuth handles the OpenAI OAuth2 (PKCE) authentication flow.
type CodexAuth struct {
	httpClient *http.Client // client used for all token-endpoint requests; may route through a configured proxy
}

// NewCodexAuth creates a new OpenAI authentication service whose HTTP client
// honors the proxy settings from cfg (applied via util.SetProxy).
func NewCodexAuth(cfg *config.Config) *CodexAuth {
	return &CodexAuth{
		httpClient: util.SetProxy(cfg, &http.Client{}),
	}
}
|
||||||
|
|
||||||
|
// GenerateAuthURL creates the OAuth authorization URL with PKCE
|
||||||
|
func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string, error) {
|
||||||
|
if pkceCodes == nil {
|
||||||
|
return "", fmt.Errorf("PKCE codes are required")
|
||||||
|
}
|
||||||
|
|
||||||
|
params := url.Values{
|
||||||
|
"client_id": {openaiClientID},
|
||||||
|
"response_type": {"code"},
|
||||||
|
"redirect_uri": {redirectURI},
|
||||||
|
"scope": {"openid email profile offline_access"},
|
||||||
|
"state": {state},
|
||||||
|
"code_challenge": {pkceCodes.CodeChallenge},
|
||||||
|
"code_challenge_method": {"S256"},
|
||||||
|
"prompt": {"login"},
|
||||||
|
"id_token_add_organizations": {"true"},
|
||||||
|
"codex_cli_simplified_flow": {"true"},
|
||||||
|
}
|
||||||
|
|
||||||
|
authURL := fmt.Sprintf("%s?%s", openaiAuthURL, params.Encode())
|
||||||
|
return authURL, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExchangeCodeForTokens exchanges an authorization code for access/refresh/ID
// tokens at the OpenAI token endpoint, then derives the account ID and email
// from the ID token's claims. The returned bundle does not populate APIKey.
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
	if pkceCodes == nil {
		return nil, fmt.Errorf("PKCE codes are required for token exchange")
	}

	// Standard authorization_code grant with the PKCE verifier (RFC 7636 §4.5).
	data := url.Values{
		"grant_type":    {"authorization_code"},
		"client_id":     {openaiClientID},
		"code":          {code},
		"redirect_uri":  {redirectURI},
		"code_verifier": {pkceCodes.CodeVerifier},
	}

	req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
	if err != nil {
		return nil, fmt.Errorf("failed to create token request: %w", err)
	}

	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
	req.Header.Set("Accept", "application/json")

	resp, err := o.httpClient.Do(req)
	if err != nil {
		return nil, fmt.Errorf("token exchange request failed: %w", err)
	}
	defer func() {
		_ = resp.Body.Close()
	}()

	// Read the body before checking the status so error responses can be
	// included in the returned error message.
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return nil, fmt.Errorf("failed to read token response: %w", err)
	}
	// log.Debugf("Token response: %s", string(body))

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(body))
	}

	// Parse the token endpoint's JSON response.
	var tokenResp struct {
		AccessToken  string `json:"access_token"`
		RefreshToken string `json:"refresh_token"`
		IDToken      string `json:"id_token"`
		TokenType    string `json:"token_type"`
		ExpiresIn    int    `json:"expires_in"`
	}

	if err = json.Unmarshal(body, &tokenResp); err != nil {
		return nil, fmt.Errorf("failed to parse token response: %w", err)
	}

	// Extract account ID and email from the ID token. A parse failure is
	// deliberately non-fatal: the tokens are still usable without the claims.
	claims, err := ParseJWTToken(tokenResp.IDToken)
	if err != nil {
		log.Warnf("Failed to parse ID token: %v", err)
	}

	accountID := ""
	email := ""
	if claims != nil {
		accountID = claims.GetAccountID()
		email = claims.GetUserEmail()
	}

	// Assemble the token data; Expire is an absolute RFC3339 timestamp
	// computed from the relative expires_in value.
	tokenData := CodexTokenData{
		IDToken:      tokenResp.IDToken,
		AccessToken:  tokenResp.AccessToken,
		RefreshToken: tokenResp.RefreshToken,
		AccountID:    accountID,
		Email:        email,
		Expire:       time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
	}

	bundle := &CodexAuthBundle{
		TokenData:   tokenData,
		LastRefresh: time.Now().Format(time.RFC3339),
	}

	return bundle, nil
}
|
||||||
|
|
||||||
|
// RefreshTokens refreshes the access token using the refresh token
|
||||||
|
func (o *CodexAuth) RefreshTokens(ctx context.Context, refreshToken string) (*CodexTokenData, error) {
|
||||||
|
if refreshToken == "" {
|
||||||
|
return nil, fmt.Errorf("refresh token is required")
|
||||||
|
}
|
||||||
|
|
||||||
|
data := url.Values{
|
||||||
|
"client_id": {openaiClientID},
|
||||||
|
"grant_type": {"refresh_token"},
|
||||||
|
"refresh_token": {refreshToken},
|
||||||
|
"scope": {"openid profile email"},
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", openaiTokenURL, strings.NewReader(data.Encode()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create refresh request: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||||
|
req.Header.Set("Accept", "application/json")
|
||||||
|
|
||||||
|
resp, err := o.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("token refresh request failed: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(resp.Body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to read refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode != http.StatusOK {
|
||||||
|
return nil, fmt.Errorf("token refresh failed with status %d: %s", resp.StatusCode, string(body))
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenResp struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = json.Unmarshal(body, &tokenResp); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to parse refresh response: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract account ID from ID token
|
||||||
|
claims, err := ParseJWTToken(tokenResp.IDToken)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to parse refreshed ID token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
accountID := ""
|
||||||
|
email := ""
|
||||||
|
if claims != nil {
|
||||||
|
accountID = claims.GetAccountID()
|
||||||
|
email = claims.Email
|
||||||
|
}
|
||||||
|
|
||||||
|
return &CodexTokenData{
|
||||||
|
IDToken: tokenResp.IDToken,
|
||||||
|
AccessToken: tokenResp.AccessToken,
|
||||||
|
RefreshToken: tokenResp.RefreshToken,
|
||||||
|
AccountID: accountID,
|
||||||
|
Email: email,
|
||||||
|
Expire: time.Now().Add(time.Duration(tokenResp.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTokenStorage creates a new CodexTokenStorage from auth bundle and user info
|
||||||
|
func (o *CodexAuth) CreateTokenStorage(bundle *CodexAuthBundle) *CodexTokenStorage {
|
||||||
|
storage := &CodexTokenStorage{
|
||||||
|
IDToken: bundle.TokenData.IDToken,
|
||||||
|
AccessToken: bundle.TokenData.AccessToken,
|
||||||
|
RefreshToken: bundle.TokenData.RefreshToken,
|
||||||
|
AccountID: bundle.TokenData.AccountID,
|
||||||
|
LastRefresh: bundle.LastRefresh,
|
||||||
|
Email: bundle.TokenData.Email,
|
||||||
|
Expire: bundle.TokenData.Expire,
|
||||||
|
}
|
||||||
|
|
||||||
|
return storage
|
||||||
|
}
|
||||||
|
|
||||||
|
// RefreshTokensWithRetry refreshes tokens with automatic retry logic
|
||||||
|
func (o *CodexAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*CodexTokenData, error) {
|
||||||
|
var lastErr error
|
||||||
|
|
||||||
|
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||||
|
if attempt > 0 {
|
||||||
|
// Wait before retry
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-time.After(time.Duration(attempt) * time.Second):
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenData, err := o.RefreshTokens(ctx, refreshToken)
|
||||||
|
if err == nil {
|
||||||
|
return tokenData, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
lastErr = err
|
||||||
|
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateTokenStorage updates an existing token storage with new token data
|
||||||
|
func (o *CodexAuth) UpdateTokenStorage(storage *CodexTokenStorage, tokenData *CodexTokenData) {
|
||||||
|
storage.IDToken = tokenData.IDToken
|
||||||
|
storage.AccessToken = tokenData.AccessToken
|
||||||
|
storage.RefreshToken = tokenData.RefreshToken
|
||||||
|
storage.AccountID = tokenData.AccountID
|
||||||
|
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||||
|
storage.Email = tokenData.Email
|
||||||
|
storage.Expire = tokenData.Expire
|
||||||
|
}
|
||||||
47
internal/auth/codex/pkce.go
Normal file
47
internal/auth/codex/pkce.go
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
package codex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GeneratePKCECodes generates a PKCE code verifier and challenge pair
|
||||||
|
// following RFC 7636 specifications for OAuth 2.0 PKCE extension
|
||||||
|
func GeneratePKCECodes() (*PKCECodes, error) {
|
||||||
|
// Generate code verifier: 43-128 characters, URL-safe
|
||||||
|
codeVerifier, err := generateCodeVerifier()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to generate code verifier: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate code challenge using S256 method
|
||||||
|
codeChallenge := generateCodeChallenge(codeVerifier)
|
||||||
|
|
||||||
|
return &PKCECodes{
|
||||||
|
CodeVerifier: codeVerifier,
|
||||||
|
CodeChallenge: codeChallenge,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateCodeVerifier returns a 128-character URL-safe random string:
// 96 cryptographically random bytes, base64url-encoded without padding,
// which is the maximum verifier length permitted by RFC 7636.
func generateCodeVerifier() (string, error) {
	buf := make([]byte, 96)
	if _, err := rand.Read(buf); err != nil {
		return "", fmt.Errorf("failed to generate random bytes: %w", err)
	}

	// RawURLEncoding is URL-safe base64 without padding.
	return base64.RawURLEncoding.EncodeToString(buf), nil
}
|
||||||
|
|
||||||
|
// generateCodeChallenge derives the S256 code challenge from a verifier:
// base64url(SHA-256(verifier)) without padding, per RFC 7636 section 4.2.
func generateCodeChallenge(codeVerifier string) string {
	sum := sha256.Sum256([]byte(codeVerifier))
	return base64.RawURLEncoding.EncodeToString(sum[:])
}
|
||||||
51
internal/auth/codex/token.go
Normal file
51
internal/auth/codex/token.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package codex
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CodexTokenStorage extends the existing GeminiTokenStorage for OpenAI-specific data.
// It maintains compatibility with the existing auth system while adding OpenAI-specific fields.
type CodexTokenStorage struct {
	// IDToken is the JWT ID token containing user claims
	IDToken string `json:"id_token"`
	// AccessToken is the OAuth2 access token for API access
	AccessToken string `json:"access_token"`
	// RefreshToken is used to obtain new access tokens
	RefreshToken string `json:"refresh_token"`
	// AccountID is the OpenAI account identifier
	AccountID string `json:"account_id"`
	// LastRefresh is the RFC 3339 timestamp of the last token refresh
	LastRefresh string `json:"last_refresh"`
	// Email is the OpenAI account email
	Email string `json:"email"`
	// Type discriminates token-storage record kinds; SaveTokenToFile sets it to "codex"
	Type string `json:"type"`
	// Expire is the token expiry timestamp (RFC 3339). Note the JSON key is "expired".
	Expire string `json:"expired"`
}
|
||||||
|
|
||||||
|
// SaveTokenToFile serializes the token storage to a JSON file.
|
||||||
|
func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||||
|
ts.Type = "codex"
|
||||||
|
if err := os.MkdirAll(path.Dir(authFilePath), 0700); err != nil {
|
||||||
|
return fmt.Errorf("failed to create directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Create(authFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create token file: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = f.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||||
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
// Package auth provides OAuth2 authentication functionality for Google Cloud APIs.
|
// Package auth provides OAuth2 authentication functionality for Google Cloud APIs.
|
||||||
// It handles the complete OAuth2 flow including token storage, web-based authentication,
|
// It handles the complete OAuth2 flow including token storage, web-based authentication,
|
||||||
// proxy support, and automatic token refresh. The package supports both SOCKS5 and HTTP/HTTPS proxies.
|
// proxy support, and automatic token refresh. The package supports both SOCKS5 and HTTP/HTTPS proxies.
|
||||||
package auth
|
package gemini
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
@@ -14,9 +14,10 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/auth/codex"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/browser"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/skratchdot/open-golang/open"
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"golang.org/x/net/proxy"
|
"golang.org/x/net/proxy"
|
||||||
|
|
||||||
@@ -25,22 +26,29 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
oauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
geminiOauthClientID = "681255809395-oo8ft2oprdrnp9e3aqf6av3hmdib135j.apps.googleusercontent.com"
|
||||||
oauthClientSecret = "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"
|
geminiOauthClientSecret = "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"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
oauthScopes = []string{
|
geminiOauthScopes = []string{
|
||||||
"https://www.googleapis.com/auth/cloud-platform",
|
"https://www.googleapis.com/auth/cloud-platform",
|
||||||
"https://www.googleapis.com/auth/userinfo.email",
|
"https://www.googleapis.com/auth/userinfo.email",
|
||||||
"https://www.googleapis.com/auth/userinfo.profile",
|
"https://www.googleapis.com/auth/userinfo.profile",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type GeminiAuth struct {
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewGeminiAuth() *GeminiAuth {
|
||||||
|
return &GeminiAuth{}
|
||||||
|
}
|
||||||
|
|
||||||
// GetAuthenticatedClient configures and returns an HTTP client ready for making authenticated API calls.
|
// GetAuthenticatedClient configures and returns an HTTP client ready for making authenticated API calls.
|
||||||
// It manages the entire OAuth2 flow, including handling proxies, loading existing tokens,
|
// It manages the entire OAuth2 flow, including handling proxies, loading existing tokens,
|
||||||
// initiating a new web-based OAuth flow if necessary, and refreshing tokens.
|
// initiating a new web-based OAuth flow if necessary, and refreshing tokens.
|
||||||
func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) {
|
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, noBrowser ...bool) (*http.Client, error) {
|
||||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@@ -72,10 +80,10 @@ func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.C
|
|||||||
|
|
||||||
// Configure the OAuth2 client.
|
// Configure the OAuth2 client.
|
||||||
conf := &oauth2.Config{
|
conf := &oauth2.Config{
|
||||||
ClientID: oauthClientID,
|
ClientID: geminiOauthClientID,
|
||||||
ClientSecret: oauthClientSecret,
|
ClientSecret: geminiOauthClientSecret,
|
||||||
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
|
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
|
||||||
Scopes: oauthScopes,
|
Scopes: geminiOauthScopes,
|
||||||
Endpoint: google.Endpoint,
|
Endpoint: google.Endpoint,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -84,12 +92,12 @@ func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.C
|
|||||||
// If no token is found in storage, initiate the web-based OAuth flow.
|
// If no token is found in storage, initiate the web-based OAuth flow.
|
||||||
if ts.Token == nil {
|
if ts.Token == nil {
|
||||||
log.Info("Could not load token from file, starting OAuth flow.")
|
log.Info("Could not load token from file, starting OAuth flow.")
|
||||||
token, err = getTokenFromWeb(ctx, conf)
|
token, err = g.getTokenFromWeb(ctx, conf, noBrowser...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get token from web: %w", err)
|
return nil, fmt.Errorf("failed to get token from web: %w", err)
|
||||||
}
|
}
|
||||||
// After getting a new token, create a new token storage object with user info.
|
// After getting a new token, create a new token storage object with user info.
|
||||||
newTs, errCreateTokenStorage := createTokenStorage(ctx, conf, token, ts.ProjectID)
|
newTs, errCreateTokenStorage := g.createTokenStorage(ctx, conf, token, ts.ProjectID)
|
||||||
if errCreateTokenStorage != nil {
|
if errCreateTokenStorage != nil {
|
||||||
log.Errorf("Warning: failed to create token storage: %v", errCreateTokenStorage)
|
log.Errorf("Warning: failed to create token storage: %v", errCreateTokenStorage)
|
||||||
return nil, errCreateTokenStorage
|
return nil, errCreateTokenStorage
|
||||||
@@ -107,9 +115,9 @@ func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.C
|
|||||||
return conf.Client(ctx, token), nil
|
return conf.Client(ctx, token), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createTokenStorage creates a new TokenStorage object. It fetches the user's email
|
// createTokenStorage creates a new GeminiTokenStorage object. It fetches the user's email
|
||||||
// using the provided token and populates the storage structure.
|
// using the provided token and populates the storage structure.
|
||||||
func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*TokenStorage, error) {
|
func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*GeminiTokenStorage, error) {
|
||||||
httpClient := config.Client(ctx, token)
|
httpClient := config.Client(ctx, token)
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -148,12 +156,12 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth
|
|||||||
}
|
}
|
||||||
|
|
||||||
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
ifToken["token_uri"] = "https://oauth2.googleapis.com/token"
|
||||||
ifToken["client_id"] = oauthClientID
|
ifToken["client_id"] = geminiOauthClientID
|
||||||
ifToken["client_secret"] = oauthClientSecret
|
ifToken["client_secret"] = geminiOauthClientSecret
|
||||||
ifToken["scopes"] = oauthScopes
|
ifToken["scopes"] = geminiOauthScopes
|
||||||
ifToken["universe_domain"] = "googleapis.com"
|
ifToken["universe_domain"] = "googleapis.com"
|
||||||
|
|
||||||
ts := TokenStorage{
|
ts := GeminiTokenStorage{
|
||||||
Token: ifToken,
|
Token: ifToken,
|
||||||
ProjectID: projectID,
|
ProjectID: projectID,
|
||||||
Email: emailResult.String(),
|
Email: emailResult.String(),
|
||||||
@@ -166,7 +174,7 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth
|
|||||||
// It starts a local HTTP server to listen for the callback from Google's auth server,
|
// It starts a local HTTP server to listen for the callback from Google's auth server,
|
||||||
// opens the user's browser to the authorization URL, and exchanges the received
|
// opens the user's browser to the authorization URL, and exchanges the received
|
||||||
// authorization code for an access token.
|
// authorization code for an access token.
|
||||||
func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) {
|
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, noBrowser ...bool) (*oauth2.Token, error) {
|
||||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||||
codeChan := make(chan string)
|
codeChan := make(chan string)
|
||||||
errChan := make(chan error)
|
errChan := make(chan error)
|
||||||
@@ -201,27 +209,46 @@ func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token,
|
|||||||
|
|
||||||
// Open the authorization URL in the user's browser.
|
// Open the authorization URL in the user's browser.
|
||||||
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||||
log.Debugf("CLI login required.\nAttempting to open authentication page in your browser.\nIf it does not open, please navigate to this URL:\n\n%s\n", authURL)
|
|
||||||
|
|
||||||
var err error
|
if len(noBrowser) == 1 && !noBrowser[0] {
|
||||||
err = open.Run(authURL)
|
log.Info("Opening browser for authentication...")
|
||||||
if err != nil {
|
|
||||||
log.Errorf("Failed to open browser: %v. Please open the URL manually.", err)
|
// Check if browser is available
|
||||||
|
if !browser.IsAvailable() {
|
||||||
|
log.Warn("No browser available on this system")
|
||||||
|
log.Infof("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||||
|
} else {
|
||||||
|
if err := browser.OpenURL(authURL); err != nil {
|
||||||
|
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
|
||||||
|
log.Warn(codex.GetUserFriendlyMessage(authErr))
|
||||||
|
log.Infof("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||||
|
|
||||||
|
// Log platform info for debugging
|
||||||
|
platformInfo := browser.GetPlatformInfo()
|
||||||
|
log.Debugf("Browser platform info: %+v", platformInfo)
|
||||||
|
} else {
|
||||||
|
log.Debug("Browser opened successfully")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Infof("Please open this URL in your browser:\n\n%s\n", authURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Info("Waiting for authentication callback...")
|
||||||
|
|
||||||
// Wait for the authorization code or an error.
|
// Wait for the authorization code or an error.
|
||||||
var authCode string
|
var authCode string
|
||||||
select {
|
select {
|
||||||
case code := <-codeChan:
|
case code := <-codeChan:
|
||||||
authCode = code
|
authCode = code
|
||||||
case err = <-errChan:
|
case err := <-errChan:
|
||||||
return nil, err
|
return nil, err
|
||||||
case <-time.After(5 * time.Minute): // Timeout
|
case <-time.After(5 * time.Minute): // Timeout
|
||||||
return nil, fmt.Errorf("oauth flow timed out")
|
return nil, fmt.Errorf("oauth flow timed out")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Shutdown the server.
|
// Shutdown the server.
|
||||||
if err = server.Shutdown(ctx); err != nil {
|
if err := server.Shutdown(ctx); err != nil {
|
||||||
log.Errorf("Failed to shut down server: %v", err)
|
log.Errorf("Failed to shut down server: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
64
internal/auth/gemini/gemini_token.go
Normal file
64
internal/auth/gemini/gemini_token.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
// Package gemini provides authentication and token management functionality
|
||||||
|
// for Google's Gemini AI services. It handles OAuth2 token storage, serialization,
|
||||||
|
// and retrieval for maintaining authenticated sessions with the Gemini API.
|
||||||
|
package gemini
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GeminiTokenStorage defines the structure for storing OAuth2 token information,
// along with associated user and project details. This data is typically
// serialized to a JSON file for persistence.
type GeminiTokenStorage struct {
	// Token holds the raw OAuth2 token data, including access and refresh tokens.
	Token any `json:"token"`

	// ProjectID is the Google Cloud Project ID associated with this token.
	ProjectID string `json:"project_id"`

	// Email is the email address of the authenticated user.
	Email string `json:"email"`

	// Auto indicates if the project ID was automatically selected.
	Auto bool `json:"auto"`

	// Checked indicates if the associated Cloud AI API has been verified as enabled.
	Checked bool `json:"checked"`

	// Type discriminates token-storage record kinds; SaveTokenToFile sets it to "gemini".
	Type string `json:"type"`
}
|
||||||
|
|
||||||
|
// SaveTokenToFile serializes the token storage to a JSON file.
|
||||||
|
// This method creates the necessary directory structure and writes the token
|
||||||
|
// data in JSON format to the specified file path. It ensures the file is
|
||||||
|
// properly closed after writing.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - authFilePath: The full path where the token file should be saved
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - error: An error if the operation fails, nil otherwise
|
||||||
|
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||||
|
ts.Type = "gemini"
|
||||||
|
if err := os.MkdirAll(path.Dir(authFilePath), 0700); err != nil {
|
||||||
|
return fmt.Errorf("failed to create directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := os.Create(authFilePath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to create token file: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = f.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||||
|
return fmt.Errorf("failed to write token to file: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,17 +1,5 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
// TokenStorage defines the structure for storing OAuth2 token information,
|
type TokenStorage interface {
|
||||||
// along with associated user and project details. This data is typically
|
SaveTokenToFile(authFilePath string) error
|
||||||
// serialized to a JSON file for persistence.
|
|
||||||
type TokenStorage struct {
|
|
||||||
// Token holds the raw OAuth2 token data, including access and refresh tokens.
|
|
||||||
Token any `json:"token"`
|
|
||||||
// ProjectID is the Google Cloud Project ID associated with this token.
|
|
||||||
ProjectID string `json:"project_id"`
|
|
||||||
// Email is the email address of the authenticated user.
|
|
||||||
Email string `json:"email"`
|
|
||||||
// Auto indicates if the project ID was automatically selected.
|
|
||||||
Auto bool `json:"auto"`
|
|
||||||
// Checked indicates if the associated Cloud AI API has been verified as enabled.
|
|
||||||
Checked bool `json:"checked"`
|
|
||||||
}
|
}
|
||||||
|
|||||||
121
internal/browser/browser.go
Normal file
121
internal/browser/browser.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package browser
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"os/exec"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/skratchdot/open-golang/open"
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenURL opens a URL in the default browser
|
||||||
|
func OpenURL(url string) error {
|
||||||
|
log.Debugf("Attempting to open URL in browser: %s", url)
|
||||||
|
|
||||||
|
// Try using the open-golang library first
|
||||||
|
err := open.Run(url)
|
||||||
|
if err == nil {
|
||||||
|
log.Debug("Successfully opened URL using open-golang library")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("open-golang failed: %v, trying platform-specific commands", err)
|
||||||
|
|
||||||
|
// Fallback to platform-specific commands
|
||||||
|
return openURLPlatformSpecific(url)
|
||||||
|
}
|
||||||
|
|
||||||
|
// openURLPlatformSpecific opens URL using platform-specific commands
|
||||||
|
func openURLPlatformSpecific(url string) error {
|
||||||
|
var cmd *exec.Cmd
|
||||||
|
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin": // macOS
|
||||||
|
cmd = exec.Command("open", url)
|
||||||
|
case "windows":
|
||||||
|
cmd = exec.Command("rundll32", "url.dll,FileProtocolHandler", url)
|
||||||
|
case "linux":
|
||||||
|
// Try common Linux browsers in order of preference
|
||||||
|
browsers := []string{"xdg-open", "x-www-browser", "www-browser", "firefox", "chromium", "google-chrome"}
|
||||||
|
for _, browser := range browsers {
|
||||||
|
if _, err := exec.LookPath(browser); err == nil {
|
||||||
|
cmd = exec.Command(browser, url)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if cmd == nil {
|
||||||
|
return fmt.Errorf("no suitable browser found on Linux system")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unsupported operating system: %s", runtime.GOOS)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Running command: %s %v", cmd.Path, cmd.Args[1:])
|
||||||
|
err := cmd.Start()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to start browser command: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Successfully opened URL using platform-specific command")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsAvailable reports whether this system has a usable mechanism for
// opening URLs in a browser. The check is based purely on the presence of
// the platform launcher commands on PATH; it deliberately does NOT launch
// anything — the previous probe of open.Run("about:blank") actually popped
// up a browser window as a side effect of merely checking availability.
func IsAvailable() bool {
	switch runtime.GOOS {
	case "darwin":
		_, err := exec.LookPath("open")
		return err == nil
	case "windows":
		_, err := exec.LookPath("rundll32")
		return err == nil
	case "linux":
		browsers := []string{"xdg-open", "x-www-browser", "www-browser", "firefox", "chromium", "google-chrome"}
		for _, browser := range browsers {
			if _, err := exec.LookPath(browser); err == nil {
				return true
			}
		}
		return false
	default:
		return false
	}
}
|
||||||
|
|
||||||
|
// GetPlatformInfo returns information about the current platform's browser support
|
||||||
|
func GetPlatformInfo() map[string]interface{} {
|
||||||
|
info := map[string]interface{}{
|
||||||
|
"os": runtime.GOOS,
|
||||||
|
"arch": runtime.GOARCH,
|
||||||
|
"available": IsAvailable(),
|
||||||
|
}
|
||||||
|
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "darwin":
|
||||||
|
info["default_command"] = "open"
|
||||||
|
case "windows":
|
||||||
|
info["default_command"] = "rundll32"
|
||||||
|
case "linux":
|
||||||
|
browsers := []string{"xdg-open", "x-www-browser", "www-browser", "firefox", "chromium", "google-chrome"}
|
||||||
|
availableBrowsers := []string{}
|
||||||
|
for _, browser := range browsers {
|
||||||
|
if _, err := exec.LookPath(browser); err == nil {
|
||||||
|
availableBrowsers = append(availableBrowsers, browser)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
info["available_browsers"] = availableBrowsers
|
||||||
|
if len(availableBrowsers) > 0 {
|
||||||
|
info["default_command"] = availableBrowsers[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return info
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
159
internal/client/client_models.go
Normal file
159
internal/client/client_models.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
// Package client defines the data structures used across all AI API clients.
|
||||||
|
// These structures represent the common data models for requests, responses,
|
||||||
|
// and configuration parameters used when communicating with various AI services.
|
||||||
|
package client
|
||||||
|
|
||||||
|
import "time"
|
||||||
|
|
||||||
|
// ErrorMessage encapsulates an error with an associated HTTP status code.
// It is the uniform error type returned by client API methods so callers can
// propagate the upstream HTTP status alongside the Go error.
type ErrorMessage struct {
	// StatusCode is the HTTP status code returned by the API
	// (500 is used for client-side failures such as marshaling errors).
	StatusCode int

	// Error is the underlying error that occurred.
	Error error
}

// GCPProject represents the response structure for a Google Cloud project list request.
// This structure is used when fetching available projects for a Google Cloud account.
type GCPProject struct {
	// Projects is a list of Google Cloud projects accessible by the user.
	Projects []GCPProjectProjects `json:"projects"`
}

// GCPProjectLabels defines the labels associated with a GCP project.
// These labels can contain metadata about the project's purpose or configuration.
type GCPProjectLabels struct {
	// GenerativeLanguage indicates if the project has generative language APIs enabled.
	GenerativeLanguage string `json:"generative-language"`
}

// GCPProjectProjects contains details about a single Google Cloud project.
// This includes identifying information, metadata, and configuration details.
type GCPProjectProjects struct {
	// ProjectNumber is the unique numeric identifier for the project.
	ProjectNumber string `json:"projectNumber"`

	// ProjectID is the unique string identifier for the project.
	ProjectID string `json:"projectId"`

	// LifecycleState indicates the current state of the project (e.g., "ACTIVE").
	LifecycleState string `json:"lifecycleState"`

	// Name is the human-readable name of the project.
	Name string `json:"name"`

	// Labels contains metadata labels associated with the project.
	Labels GCPProjectLabels `json:"labels"`

	// CreateTime is the timestamp when the project was created.
	CreateTime time.Time `json:"createTime"`
}
|
||||||
|
|
||||||
|
// Content represents a single message in a conversation, with a role and parts.
// This structure models a message exchange between a user and an AI model.
type Content struct {
	// Role indicates who sent the message ("user", "model", or "tool").
	Role string `json:"role"`

	// Parts is a collection of content parts that make up the message.
	Parts []Part `json:"parts"`
}

// Part represents a distinct piece of content within a message.
// Exactly one of the fields is expected to be populated: text, inline data
// (such as an image), a function call, or a function response.
type Part struct {
	// Text contains plain text content.
	Text string `json:"text,omitempty"`

	// InlineData contains base64-encoded data with its MIME type (e.g., images).
	InlineData *InlineData `json:"inlineData,omitempty"`

	// FunctionCall represents a tool call requested by the model.
	FunctionCall *FunctionCall `json:"functionCall,omitempty"`

	// FunctionResponse represents the result of a tool execution.
	FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
}

// InlineData represents base64-encoded data with its MIME type.
// This is typically used for embedding images or other binary data in requests.
// NOTE(review): the JSON tags here are snake_case while sibling types use
// camelCase — confirm the upstream API expects this exact casing.
type InlineData struct {
	// MimeType specifies the media type of the embedded data (e.g., "image/png").
	MimeType string `json:"mime_type,omitempty"`

	// Data contains the base64-encoded binary data.
	Data string `json:"data,omitempty"`
}

// FunctionCall represents a tool call requested by the model.
// It includes the function name and its arguments that the model wants to execute.
type FunctionCall struct {
	// Name is the identifier of the function to be called.
	Name string `json:"name"`

	// Args contains the arguments to pass to the function.
	Args map[string]interface{} `json:"args"`
}

// FunctionResponse represents the result of a tool execution.
// This is sent back to the model after a tool call has been processed.
type FunctionResponse struct {
	// Name is the identifier of the function that was called.
	Name string `json:"name"`

	// Response contains the result data from the function execution.
	Response map[string]interface{} `json:"response"`
}

// GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint.
// This structure defines all the parameters needed for generating content from an AI model.
type GenerateContentRequest struct {
	// SystemInstruction provides system-level instructions that guide the model's behavior.
	SystemInstruction *Content `json:"systemInstruction,omitempty"`

	// Contents is the conversation history between the user and the model.
	Contents []Content `json:"contents"`

	// Tools defines the available tools/functions that the model can call.
	Tools []ToolDeclaration `json:"tools,omitempty"`

	// GenerationConfig contains parameters that control the model's generation
	// behavior (embedded so its fields marshal under "generationConfig").
	GenerationConfig `json:"generationConfig"`
}

// GenerationConfig defines parameters that control the model's generation behavior.
// These parameters affect the creativity, randomness, and reasoning of the model's responses.
type GenerationConfig struct {
	// ThinkingConfig specifies configuration for the model's "thinking" process.
	// NOTE(review): `omitempty` has no effect on a non-pointer struct field, so
	// this is always serialized — confirm whether a pointer type was intended.
	ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"`

	// Temperature controls the randomness of the model's responses.
	// Values closer to 0 make responses more deterministic, while values closer to 1 increase randomness.
	Temperature float64 `json:"temperature,omitempty"`

	// TopP controls nucleus sampling, which affects the diversity of responses.
	// It limits the model to consider only the top P% of probability mass.
	TopP float64 `json:"topP,omitempty"`

	// TopK limits the model to consider only the top K most likely tokens.
	// This can help control the quality and diversity of generated text.
	TopK float64 `json:"topK,omitempty"`
}

// GenerationConfigThinkingConfig specifies configuration for the model's "thinking" process.
// This controls whether the model should output its reasoning process along with the final answer.
type GenerationConfigThinkingConfig struct {
	// IncludeThoughts determines whether the model should output its reasoning process.
	// When enabled, the model will include its step-by-step thinking in the response.
	IncludeThoughts bool `json:"include_thoughts,omitempty"`
}

// ToolDeclaration defines the structure for declaring tools (like functions)
// that the model can call during content generation.
type ToolDeclaration struct {
	// FunctionDeclarations is a list of available functions that the model can call.
	FunctionDeclarations []interface{} `json:"functionDeclarations"`
}
|
||||||
258
internal/client/codex_client.go
Normal file
258
internal/client/codex_client.go
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"path/filepath"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/auth"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/auth/codex"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// chatGPTEndpoint is the base URL of the ChatGPT backend API used for
	// Codex (GPT model) requests.
	chatGPTEndpoint = "https://chatgpt.com/backend-api"
)
|
||||||
|
|
||||||
|
// CodexClient implements the Client interface for the OpenAI Codex
// (ChatGPT backend) API.
type CodexClient struct {
	ClientBase
	// codexAuth performs the OAuth token refresh flow for the Codex account.
	codexAuth *codex.CodexAuth
}
|
||||||
|
|
||||||
|
// NewCodexClient creates a new OpenAI client instance
|
||||||
|
func NewCodexClient(cfg *config.Config, ts *codex.CodexTokenStorage) (*CodexClient, error) {
|
||||||
|
httpClient := util.SetProxy(cfg, &http.Client{})
|
||||||
|
client := &CodexClient{
|
||||||
|
ClientBase: ClientBase{
|
||||||
|
RequestMutex: &sync.Mutex{},
|
||||||
|
httpClient: httpClient,
|
||||||
|
cfg: cfg,
|
||||||
|
modelQuotaExceeded: make(map[string]*time.Time),
|
||||||
|
tokenStorage: ts,
|
||||||
|
},
|
||||||
|
codexAuth: codex.NewCodexAuth(cfg),
|
||||||
|
}
|
||||||
|
|
||||||
|
return client, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserAgent returns the user agent string for OpenAI API requests
|
||||||
|
func (c *CodexClient) GetUserAgent() string {
|
||||||
|
return "codex-cli"
|
||||||
|
}
|
||||||
|
|
||||||
|
// TokenStorage returns the token storage backing this client's authentication.
func (c *CodexClient) TokenStorage() auth.TokenStorage {
	return c.tokenStorage
}
|
||||||
|
|
||||||
|
// SendMessage sends a message to OpenAI API (non-streaming)
|
||||||
|
func (c *CodexClient) SendMessage(_ context.Context, _ []byte, _ string, _ *Content, _ []Content, _ []ToolDeclaration) ([]byte, *ErrorMessage) {
|
||||||
|
// For now, return an error as OpenAI integration is not fully implemented
|
||||||
|
return nil, &ErrorMessage{
|
||||||
|
StatusCode: http.StatusNotImplemented,
|
||||||
|
Error: fmt.Errorf("codex message sending not yet implemented"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessageStream sends a streaming message to OpenAI API
|
||||||
|
func (c *CodexClient) SendMessageStream(_ context.Context, _ []byte, _ string, _ *Content, _ []Content, _ []ToolDeclaration, _ ...bool) (<-chan []byte, <-chan *ErrorMessage) {
|
||||||
|
errChan := make(chan *ErrorMessage, 1)
|
||||||
|
errChan <- &ErrorMessage{
|
||||||
|
StatusCode: http.StatusNotImplemented,
|
||||||
|
Error: fmt.Errorf("codex streaming not yet implemented"),
|
||||||
|
}
|
||||||
|
close(errChan)
|
||||||
|
|
||||||
|
return nil, errChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendRawMessage sends a raw message to OpenAI API
|
||||||
|
func (c *CodexClient) SendRawMessage(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) {
|
||||||
|
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||||
|
model := modelResult.String()
|
||||||
|
modelName := model
|
||||||
|
|
||||||
|
respBody, err := c.APIRequest(ctx, "/codex/responses", rawJSON, alt, false)
|
||||||
|
if err != nil {
|
||||||
|
if err.StatusCode == 429 {
|
||||||
|
now := time.Now()
|
||||||
|
c.modelQuotaExceeded[modelName] = &now
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
delete(c.modelQuotaExceeded, modelName)
|
||||||
|
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||||
|
if errReadAll != nil {
|
||||||
|
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||||
|
}
|
||||||
|
return bodyBytes, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendRawMessageStream sends a raw streaming message to OpenAI API
|
||||||
|
func (c *CodexClient) SendRawMessageStream(ctx context.Context, rawJSON []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) {
|
||||||
|
errChan := make(chan *ErrorMessage)
|
||||||
|
dataChan := make(chan []byte)
|
||||||
|
go func() {
|
||||||
|
defer close(errChan)
|
||||||
|
defer close(dataChan)
|
||||||
|
|
||||||
|
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||||
|
model := modelResult.String()
|
||||||
|
modelName := model
|
||||||
|
var stream io.ReadCloser
|
||||||
|
for {
|
||||||
|
var err *ErrorMessage
|
||||||
|
stream, err = c.APIRequest(ctx, "/codex/responses", rawJSON, alt, true)
|
||||||
|
if err != nil {
|
||||||
|
if err.StatusCode == 429 {
|
||||||
|
now := time.Now()
|
||||||
|
c.modelQuotaExceeded[modelName] = &now
|
||||||
|
}
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(c.modelQuotaExceeded, modelName)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
scanner := bufio.NewScanner(stream)
|
||||||
|
buffer := make([]byte, 10240*1024)
|
||||||
|
scanner.Buffer(buffer, 10240*1024)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
dataChan <- line
|
||||||
|
}
|
||||||
|
|
||||||
|
if errScanner := scanner.Err(); errScanner != nil {
|
||||||
|
errChan <- &ErrorMessage{500, errScanner}
|
||||||
|
_ = stream.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = stream.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return dataChan, errChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendRawTokenCount sends a token count request to OpenAI API
|
||||||
|
func (c *CodexClient) SendRawTokenCount(_ context.Context, _ []byte, _ string) ([]byte, *ErrorMessage) {
|
||||||
|
return nil, &ErrorMessage{
|
||||||
|
StatusCode: http.StatusNotImplemented,
|
||||||
|
Error: fmt.Errorf("codex token counting not yet implemented"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveTokenToFile persists the token storage to disk in the configured auth
// directory, under the name "codex-<email>.json".
func (c *CodexClient) SaveTokenToFile() error {
	fileName := filepath.Join(c.cfg.AuthDir, fmt.Sprintf("codex-%s.json", c.tokenStorage.(*codex.CodexTokenStorage).Email))
	return c.tokenStorage.SaveTokenToFile(fileName)
}
|
||||||
|
|
||||||
|
// RefreshTokens refreshes the access tokens if needed
|
||||||
|
func (c *CodexClient) RefreshTokens(ctx context.Context) error {
|
||||||
|
if c.tokenStorage == nil || c.tokenStorage.(*codex.CodexTokenStorage).RefreshToken == "" {
|
||||||
|
return fmt.Errorf("no refresh token available")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Refresh tokens using the auth service
|
||||||
|
newTokenData, err := c.codexAuth.RefreshTokensWithRetry(ctx, c.tokenStorage.(*codex.CodexTokenStorage).RefreshToken, 3)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to refresh tokens: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update token storage
|
||||||
|
c.codexAuth.UpdateTokenStorage(c.tokenStorage.(*codex.CodexTokenStorage), newTokenData)
|
||||||
|
|
||||||
|
// Save updated tokens
|
||||||
|
if err = c.SaveTokenToFile(); err != nil {
|
||||||
|
log.Warnf("Failed to save refreshed tokens: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("codex tokens refreshed successfully")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// APIRequest handles making requests to the CLI API endpoints.
|
||||||
|
func (c *CodexClient) APIRequest(ctx context.Context, endpoint string, body interface{}, _ string, _ bool) (io.ReadCloser, *ErrorMessage) {
|
||||||
|
var jsonBody []byte
|
||||||
|
var err error
|
||||||
|
if byteBody, ok := body.([]byte); ok {
|
||||||
|
jsonBody = byteBody
|
||||||
|
} else {
|
||||||
|
jsonBody, err = json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
url := fmt.Sprintf("%s/%s", chatGPTEndpoint, endpoint)
|
||||||
|
|
||||||
|
// log.Debug(string(jsonBody))
|
||||||
|
// log.Debug(url)
|
||||||
|
reqBody := bytes.NewBuffer(jsonBody)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %v", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
sessionID := uuid.New().String()
|
||||||
|
// Set headers
|
||||||
|
req.Header.Set("Version", "0.21.0")
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
req.Header.Set("Openai-Beta", "responses=experimental")
|
||||||
|
req.Header.Set("Session_id", sessionID)
|
||||||
|
req.Header.Set("Accept", "text/event-stream")
|
||||||
|
req.Header.Set("Chatgpt-Account-Id", c.tokenStorage.(*codex.CodexTokenStorage).AccountID)
|
||||||
|
req.Header.Set("Originator", "codex_cli_rs")
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.tokenStorage.(*codex.CodexTokenStorage).AccessToken))
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %v", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
defer func() {
|
||||||
|
if err = resp.Body.Close(); err != nil {
|
||||||
|
log.Printf("warn: failed to close response body: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
// log.Debug(string(jsonBody))
|
||||||
|
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.Body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetEmail returns the email address associated with the client's token storage.
func (c *CodexClient) GetEmail() string {
	return c.tokenStorage.(*codex.CodexTokenStorage).Email
}
|
||||||
|
|
||||||
|
// IsModelQuotaExceeded returns true if the specified model has exceeded its quota
|
||||||
|
// and no fallback options are available.
|
||||||
|
func (c *CodexClient) IsModelQuotaExceeded(model string) bool {
|
||||||
|
if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
|
||||||
|
duration := time.Now().Sub(*lastExceededTime)
|
||||||
|
if duration > 30*time.Minute {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
942
internal/client/gemini_client.go
Normal file
942
internal/client/gemini_client.go
Normal file
@@ -0,0 +1,942 @@
|
|||||||
|
// Package client provides HTTP client functionality for interacting with Google Cloud AI APIs.
|
||||||
|
// It handles OAuth2 authentication, token management, request/response processing,
|
||||||
|
// streaming communication, quota management, and automatic model fallback.
|
||||||
|
// The package supports both direct API key authentication and OAuth2 flows.
|
||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
geminiAuth "github.com/luispater/CLIProxyAPI/internal/auth/gemini"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	// codeAssistEndpoint is the base URL of the Cloud Code Assist API.
	codeAssistEndpoint = "https://cloudcode-pa.googleapis.com"
	// apiVersion is the Code Assist API version path segment.
	apiVersion = "v1internal"

	// glEndPoint is the base URL of the public Generative Language API.
	glEndPoint = "https://generativelanguage.googleapis.com"
	// glAPIVersion is the Generative Language API version path segment.
	glAPIVersion = "v1beta"
)
|
||||||
|
|
||||||
|
var (
	// previewModels maps each stable Gemini model name to its associated
	// preview model names.
	// NOTE(review): not referenced within this view — confirm it is used
	// elsewhere in the file (e.g. for model fallback).
	previewModels = map[string][]string{
		"gemini-2.5-pro":   {"gemini-2.5-pro-preview-05-06", "gemini-2.5-pro-preview-06-05"},
		"gemini-2.5-flash": {"gemini-2.5-flash-preview-04-17", "gemini-2.5-flash-preview-05-20"},
	}
)
|
||||||
|
|
||||||
|
// GeminiClient is the main client for interacting with the CLI API.
type GeminiClient struct {
	ClientBase
	// glAPIKey, when non-empty, switches the client to the public
	// Generative Language API instead of the OAuth Code Assist endpoint.
	glAPIKey string
}
|
||||||
|
|
||||||
|
// NewGeminiClient creates a new CLI API client.
|
||||||
|
func NewGeminiClient(httpClient *http.Client, ts *geminiAuth.GeminiTokenStorage, cfg *config.Config, glAPIKey ...string) *GeminiClient {
|
||||||
|
var glKey string
|
||||||
|
if len(glAPIKey) > 0 {
|
||||||
|
glKey = glAPIKey[0]
|
||||||
|
}
|
||||||
|
return &GeminiClient{
|
||||||
|
ClientBase: ClientBase{
|
||||||
|
RequestMutex: &sync.Mutex{},
|
||||||
|
httpClient: httpClient,
|
||||||
|
cfg: cfg,
|
||||||
|
tokenStorage: ts,
|
||||||
|
modelQuotaExceeded: make(map[string]*time.Time),
|
||||||
|
},
|
||||||
|
glAPIKey: glKey,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetProjectID updates the project ID for the client's token storage.
func (c *GeminiClient) SetProjectID(projectID string) {
	c.tokenStorage.(*geminiAuth.GeminiTokenStorage).ProjectID = projectID
}

// SetIsAuto configures whether the client should operate in automatic mode.
func (c *GeminiClient) SetIsAuto(auto bool) {
	c.tokenStorage.(*geminiAuth.GeminiTokenStorage).Auto = auto
}

// SetIsChecked sets the checked status for the client's token storage.
func (c *GeminiClient) SetIsChecked(checked bool) {
	c.tokenStorage.(*geminiAuth.GeminiTokenStorage).Checked = checked
}

// IsChecked returns whether the client's token storage has been checked.
func (c *GeminiClient) IsChecked() bool {
	return c.tokenStorage.(*geminiAuth.GeminiTokenStorage).Checked
}

// IsAuto returns whether the client is operating in automatic mode.
func (c *GeminiClient) IsAuto() bool {
	return c.tokenStorage.(*geminiAuth.GeminiTokenStorage).Auto
}

// GetEmail returns the email address associated with the client's token storage.
func (c *GeminiClient) GetEmail() string {
	return c.tokenStorage.(*geminiAuth.GeminiTokenStorage).Email
}

// GetProjectID returns the Google Cloud project ID from the client's token
// storage. It returns "" when a Generative Language API key is in use or when
// no token storage is present.
// NOTE(review): this accessor uses a comma-ok type assertion while the
// setters above assert unconditionally (and would panic on an unexpected
// storage type) — confirm whether the asymmetry is intentional.
func (c *GeminiClient) GetProjectID() string {
	if c.glAPIKey == "" && c.tokenStorage != nil {
		if ts, ok := c.tokenStorage.(*geminiAuth.GeminiTokenStorage); ok {
			return ts.ProjectID
		}
	}
	return ""
}

// GetGenerativeLanguageAPIKey returns the generative language API key if configured.
func (c *GeminiClient) GetGenerativeLanguageAPIKey() string {
	return c.glAPIKey
}
|
||||||
|
|
||||||
|
// SetupUser performs the initial user onboarding and setup.
// It records the account email, calls loadCodeAssist to discover the default
// tier and companion project, then calls onboardUser and polls the resulting
// long-running operation until it completes, storing the project ID in the
// token storage.
// NOTE(review): if the LRO reports done but its response lacks a
// cloudaicompanionProject, this loop re-issues onboardUser indefinitely —
// confirm whether an upper bound or error path should be enforced.
func (c *GeminiClient) SetupUser(ctx context.Context, email, projectID string) error {
	c.tokenStorage.(*geminiAuth.GeminiTokenStorage).Email = email
	log.Info("Performing user onboarding...")

	// 1. LoadCodeAssist
	loadAssistReqBody := map[string]interface{}{
		"metadata": c.getClientMetadata(),
	}
	if projectID != "" {
		loadAssistReqBody["cloudaicompanionProject"] = projectID
	}

	var loadAssistResp map[string]interface{}
	err := c.makeAPIRequest(ctx, "loadCodeAssist", "POST", loadAssistReqBody, &loadAssistResp)
	if err != nil {
		return fmt.Errorf("failed to load code assist: %w", err)
	}

	// 2. OnboardUser
	// Pick the tier marked isDefault in the loadCodeAssist response;
	// fall back to "legacy-tier" when none is advertised.
	var onboardTierID = "legacy-tier"
	if tiers, ok := loadAssistResp["allowedTiers"].([]interface{}); ok {
		for _, t := range tiers {
			if tier, tierOk := t.(map[string]interface{}); tierOk {
				if isDefault, isDefaultOk := tier["isDefault"].(bool); isDefaultOk && isDefault {
					if id, idOk := tier["id"].(string); idOk {
						onboardTierID = id
						break
					}
				}
			}
		}
	}

	// Prefer the companion project reported by the server over the caller's.
	onboardProjectID := projectID
	if p, ok := loadAssistResp["cloudaicompanionProject"].(string); ok && p != "" {
		onboardProjectID = p
	}

	onboardReqBody := map[string]interface{}{
		"tierId":   onboardTierID,
		"metadata": c.getClientMetadata(),
	}
	if onboardProjectID != "" {
		onboardReqBody["cloudaicompanionProject"] = onboardProjectID
	} else {
		return fmt.Errorf("failed to start user onboarding, need define a project id")
	}

	for {
		var lroResp map[string]interface{}
		err = c.makeAPIRequest(ctx, "onboardUser", "POST", onboardReqBody, &lroResp)
		if err != nil {
			return fmt.Errorf("failed to start user onboarding: %w", err)
		}

		// 3. Poll Long-Running Operation (LRO)
		done, doneOk := lroResp["done"].(bool)
		if doneOk && done {
			if project, projectOk := lroResp["response"].(map[string]interface{})["cloudaicompanionProject"].(map[string]interface{}); projectOk {
				// A caller-supplied project ID takes precedence over the
				// server-assigned one.
				if projectID != "" {
					c.tokenStorage.(*geminiAuth.GeminiTokenStorage).ProjectID = projectID
				} else {
					c.tokenStorage.(*geminiAuth.GeminiTokenStorage).ProjectID = project["id"].(string)
				}
				log.Infof("Onboarding complete. Using Project ID: %s", c.tokenStorage.(*geminiAuth.GeminiTokenStorage).ProjectID)
				return nil
			}
		} else {
			log.Println("Onboarding in progress, waiting 5 seconds...")
			time.Sleep(5 * time.Second)
		}
	}
}
|
||||||
|
|
||||||
|
// makeAPIRequest handles making requests to the CLI API endpoints.
// It marshals body (when non-nil) as JSON, attaches an OAuth bearer token and
// client metadata headers, and decodes the JSON response into result (when
// non-nil). Endpoints beginning with "operations/" are addressed without the
// API version segment.
func (c *GeminiClient) makeAPIRequest(ctx context.Context, endpoint, method string, body interface{}, result interface{}) error {
	var reqBody io.Reader
	if body != nil {
		jsonBody, err := json.Marshal(body)
		if err != nil {
			return fmt.Errorf("failed to marshal request body: %w", err)
		}
		reqBody = bytes.NewBuffer(jsonBody)
	}

	url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint)
	if strings.HasPrefix(endpoint, "operations/") {
		url = fmt.Sprintf("%s/%s", codeAssistEndpoint, endpoint)
	}

	req, err := http.NewRequestWithContext(ctx, method, url, reqBody)
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}

	// Pull the current access token from the OAuth2 transport; the token
	// source refreshes it automatically when expired.
	token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
	if err != nil {
		return fmt.Errorf("failed to get token: %w", err)
	}

	// Set headers
	metadataStr := c.getClientMetadataString()
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("User-Agent", c.GetUserAgent())
	req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0")
	req.Header.Set("Client-Metadata", metadataStr)
	req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))

	resp, err := c.httpClient.Do(req)
	if err != nil {
		return fmt.Errorf("failed to execute request: %w", err)
	}
	defer func() {
		if err = resp.Body.Close(); err != nil {
			log.Printf("warn: failed to close response body: %v", err)
		}
	}()

	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		bodyBytes, _ := io.ReadAll(resp.Body)
		return fmt.Errorf("api request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
	}

	if result != nil {
		if err = json.NewDecoder(resp.Body).Decode(result); err != nil {
			return fmt.Errorf("failed to decode response body: %w", err)
		}
	}

	return nil
}
|
||||||
|
|
||||||
|
// APIRequest handles making requests to the CLI API endpoints.
|
||||||
|
func (c *GeminiClient) APIRequest(ctx context.Context, endpoint string, body interface{}, alt string, stream bool) (io.ReadCloser, *ErrorMessage) {
|
||||||
|
var jsonBody []byte
|
||||||
|
var err error
|
||||||
|
if byteBody, ok := body.([]byte); ok {
|
||||||
|
jsonBody = byteBody
|
||||||
|
} else {
|
||||||
|
jsonBody, err = json.Marshal(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err)}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var url string
|
||||||
|
if c.glAPIKey == "" {
|
||||||
|
// Add alt=sse for streaming
|
||||||
|
url = fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint)
|
||||||
|
if alt == "" && stream {
|
||||||
|
url = url + "?alt=sse"
|
||||||
|
} else {
|
||||||
|
if alt != "" {
|
||||||
|
url = url + fmt.Sprintf("?$alt=%s", alt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if endpoint == "countTokens" {
|
||||||
|
modelResult := gjson.GetBytes(jsonBody, "model")
|
||||||
|
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint)
|
||||||
|
} else {
|
||||||
|
modelResult := gjson.GetBytes(jsonBody, "model")
|
||||||
|
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glAPIVersion, modelResult.String(), endpoint)
|
||||||
|
if alt == "" && stream {
|
||||||
|
url = url + "?alt=sse"
|
||||||
|
} else {
|
||||||
|
if alt != "" {
|
||||||
|
url = url + fmt.Sprintf("?$alt=%s", alt)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
|
||||||
|
systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction")
|
||||||
|
if systemInstructionResult.Exists() {
|
||||||
|
jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw))
|
||||||
|
jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction")
|
||||||
|
jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// log.Debug(string(jsonBody))
|
||||||
|
// log.Debug(url)
|
||||||
|
reqBody := bytes.NewBuffer(jsonBody)
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %v", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set headers
|
||||||
|
metadataStr := c.getClientMetadataString()
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
if c.glAPIKey == "" {
|
||||||
|
token, errToken := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
|
||||||
|
if errToken != nil {
|
||||||
|
return nil, &ErrorMessage{500, fmt.Errorf("failed to get token: %v", errToken)}
|
||||||
|
}
|
||||||
|
req.Header.Set("User-Agent", c.GetUserAgent())
|
||||||
|
req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0")
|
||||||
|
req.Header.Set("Client-Metadata", metadataStr)
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||||
|
} else {
|
||||||
|
req.Header.Set("x-goog-api-key", c.glAPIKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %v", err)}
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
defer func() {
|
||||||
|
if err = resp.Body.Close(); err != nil {
|
||||||
|
log.Printf("warn: failed to close response body: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
// log.Debug(string(jsonBody))
|
||||||
|
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))}
|
||||||
|
}
|
||||||
|
|
||||||
|
return resp.Body, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessage handles a single conversational turn, including tool calls.
|
||||||
|
func (c *GeminiClient) SendMessage(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
|
||||||
|
request := GenerateContentRequest{
|
||||||
|
Contents: contents,
|
||||||
|
GenerationConfig: GenerationConfig{
|
||||||
|
ThinkingConfig: GenerationConfigThinkingConfig{
|
||||||
|
IncludeThoughts: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
request.SystemInstruction = systemInstruction
|
||||||
|
|
||||||
|
request.Tools = tools
|
||||||
|
|
||||||
|
requestBody := map[string]interface{}{
|
||||||
|
"project": c.GetProjectID(), // Assuming ProjectID is available
|
||||||
|
"request": request,
|
||||||
|
"model": model,
|
||||||
|
}
|
||||||
|
|
||||||
|
byteRequestBody, _ := json.Marshal(requestBody)
|
||||||
|
|
||||||
|
// log.Debug(string(byteRequestBody))
|
||||||
|
|
||||||
|
reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||||
|
if reasoningEffortResult.String() == "none" {
|
||||||
|
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||||
|
} else if reasoningEffortResult.String() == "auto" {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
|
} else if reasoningEffortResult.String() == "low" {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||||
|
} else if reasoningEffortResult.String() == "medium" {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||||
|
} else if reasoningEffortResult.String() == "high" {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
|
||||||
|
} else {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
temperatureResult := gjson.GetBytes(rawJSON, "temperature")
|
||||||
|
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
|
||||||
|
}
|
||||||
|
|
||||||
|
topPResult := gjson.GetBytes(rawJSON, "top_p")
|
||||||
|
if topPResult.Exists() && topPResult.Type == gjson.Number {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
|
||||||
|
}
|
||||||
|
|
||||||
|
topKResult := gjson.GetBytes(rawJSON, "top_k")
|
||||||
|
if topKResult.Exists() && topKResult.Type == gjson.Number {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
||||||
|
}
|
||||||
|
|
||||||
|
modelName := model
|
||||||
|
// log.Debug(string(byteRequestBody))
|
||||||
|
for {
|
||||||
|
if c.isModelQuotaExceeded(modelName) {
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||||
|
modelName = c.getPreviewModel(model)
|
||||||
|
if modelName != "" {
|
||||||
|
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, &ErrorMessage{
|
||||||
|
StatusCode: 429,
|
||||||
|
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, "", false)
|
||||||
|
if err != nil {
|
||||||
|
if err.StatusCode == 429 {
|
||||||
|
now := time.Now()
|
||||||
|
c.modelQuotaExceeded[modelName] = &now
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
delete(c.modelQuotaExceeded, modelName)
|
||||||
|
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||||
|
if errReadAll != nil {
|
||||||
|
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||||
|
}
|
||||||
|
return bodyBytes, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendMessageStream handles streaming conversational turns with comprehensive parameter management.
|
||||||
|
// This function implements a sophisticated streaming system that supports tool calls, reasoning modes,
|
||||||
|
// quota management, and automatic model fallback. It returns two channels for asynchronous communication:
|
||||||
|
// one for streaming response data and another for error handling.
|
||||||
|
func (c *GeminiClient) SendMessageStream(ctx context.Context, rawJSON []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration, includeThoughts ...bool) (<-chan []byte, <-chan *ErrorMessage) {
|
||||||
|
// Define the data prefix used in Server-Sent Events streaming format
|
||||||
|
dataTag := []byte("data: ")
|
||||||
|
|
||||||
|
// Create channels for asynchronous communication
|
||||||
|
// errChan: delivers error messages during streaming
|
||||||
|
// dataChan: delivers response data chunks
|
||||||
|
errChan := make(chan *ErrorMessage)
|
||||||
|
dataChan := make(chan []byte)
|
||||||
|
|
||||||
|
// Launch a goroutine to handle the streaming process asynchronously
|
||||||
|
// This allows the function to return immediately while processing continues in the background
|
||||||
|
go func() {
|
||||||
|
// Ensure channels are properly closed when the goroutine exits
|
||||||
|
defer close(errChan)
|
||||||
|
defer close(dataChan)
|
||||||
|
|
||||||
|
// Configure thinking/reasoning capabilities
|
||||||
|
// Default to including thoughts unless explicitly disabled
|
||||||
|
includeThoughtsFlag := true
|
||||||
|
if len(includeThoughts) > 0 {
|
||||||
|
includeThoughtsFlag = includeThoughts[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build the base request structure for the Gemini API
|
||||||
|
// This includes conversation contents and generation configuration
|
||||||
|
request := GenerateContentRequest{
|
||||||
|
Contents: contents,
|
||||||
|
GenerationConfig: GenerationConfig{
|
||||||
|
ThinkingConfig: GenerationConfigThinkingConfig{
|
||||||
|
IncludeThoughts: includeThoughtsFlag,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add system instructions if provided
|
||||||
|
// System instructions guide the AI's behavior and response style
|
||||||
|
request.SystemInstruction = systemInstruction
|
||||||
|
|
||||||
|
// Add available tools for function calling capabilities
|
||||||
|
// Tools allow the AI to perform actions beyond text generation
|
||||||
|
request.Tools = tools
|
||||||
|
|
||||||
|
// Construct the complete request body with project context
|
||||||
|
// The project ID is essential for proper API routing and billing
|
||||||
|
requestBody := map[string]interface{}{
|
||||||
|
"project": c.GetProjectID(), // Project ID for API routing and quota management
|
||||||
|
"request": request,
|
||||||
|
"model": model,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize the request body to JSON for API transmission
|
||||||
|
byteRequestBody, _ := json.Marshal(requestBody)
|
||||||
|
|
||||||
|
// Parse and configure reasoning effort levels from the original request
|
||||||
|
// This maps Claude-style reasoning effort parameters to Gemini's thinking budget system
|
||||||
|
reasoningEffortResult := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||||
|
if reasoningEffortResult.String() == "none" {
|
||||||
|
// Disable thinking entirely for fastest responses
|
||||||
|
byteRequestBody, _ = sjson.DeleteBytes(byteRequestBody, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 0)
|
||||||
|
} else if reasoningEffortResult.String() == "auto" {
|
||||||
|
// Let the model decide the appropriate thinking budget automatically
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
|
} else if reasoningEffortResult.String() == "low" {
|
||||||
|
// Minimal thinking for simple tasks (1KB thinking budget)
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 1024)
|
||||||
|
} else if reasoningEffortResult.String() == "medium" {
|
||||||
|
// Moderate thinking for complex tasks (8KB thinking budget)
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 8192)
|
||||||
|
} else if reasoningEffortResult.String() == "high" {
|
||||||
|
// Maximum thinking for very complex tasks (24KB thinking budget)
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", 24576)
|
||||||
|
} else {
|
||||||
|
// Default to automatic thinking budget if no specific level is provided
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure temperature parameter for response randomness control
|
||||||
|
// Temperature affects the creativity vs consistency trade-off in responses
|
||||||
|
temperatureResult := gjson.GetBytes(rawJSON, "temperature")
|
||||||
|
if temperatureResult.Exists() && temperatureResult.Type == gjson.Number {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.temperature", temperatureResult.Num)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure top-p parameter for nucleus sampling
|
||||||
|
// Controls the cumulative probability threshold for token selection
|
||||||
|
topPResult := gjson.GetBytes(rawJSON, "top_p")
|
||||||
|
if topPResult.Exists() && topPResult.Type == gjson.Number {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topP", topPResult.Num)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Configure top-k parameter for limiting token candidates
|
||||||
|
// Restricts the model to consider only the top K most likely tokens
|
||||||
|
topKResult := gjson.GetBytes(rawJSON, "top_k")
|
||||||
|
if topKResult.Exists() && topKResult.Type == gjson.Number {
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "request.generationConfig.topK", topKResult.Num)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize model name for quota management and potential fallback
|
||||||
|
modelName := model
|
||||||
|
var stream io.ReadCloser
|
||||||
|
|
||||||
|
// Quota management and model fallback loop
|
||||||
|
// This loop handles quota exceeded scenarios and automatic model switching
|
||||||
|
for {
|
||||||
|
// Check if the current model has exceeded its quota
|
||||||
|
if c.isModelQuotaExceeded(modelName) {
|
||||||
|
// Attempt to switch to a preview model if configured and using account auth
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||||
|
modelName = c.getPreviewModel(model)
|
||||||
|
if modelName != "" {
|
||||||
|
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||||
|
// Update the request body with the new model name
|
||||||
|
byteRequestBody, _ = sjson.SetBytes(byteRequestBody, "model", modelName)
|
||||||
|
continue // Retry with the preview model
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// If no fallback is available, return a quota exceeded error
|
||||||
|
errChan <- &ErrorMessage{
|
||||||
|
StatusCode: 429,
|
||||||
|
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Attempt to establish a streaming connection with the API
|
||||||
|
var err *ErrorMessage
|
||||||
|
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, "", true)
|
||||||
|
if err != nil {
|
||||||
|
// Handle quota exceeded errors by marking the model and potentially retrying
|
||||||
|
if err.StatusCode == 429 {
|
||||||
|
now := time.Now()
|
||||||
|
c.modelQuotaExceeded[modelName] = &now // Mark model as quota exceeded
|
||||||
|
// If preview model switching is enabled, retry the loop
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Forward other errors to the error channel
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// Clear any previous quota exceeded status for this model
|
||||||
|
delete(c.modelQuotaExceeded, modelName)
|
||||||
|
break // Successfully established connection, exit the retry loop
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the streaming response using a scanner
|
||||||
|
// This handles the Server-Sent Events format from the API
|
||||||
|
scanner := bufio.NewScanner(stream)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
// Filter and forward only data lines (those prefixed with "data: ")
|
||||||
|
// This extracts the actual JSON content from the SSE format
|
||||||
|
if bytes.HasPrefix(line, dataTag) {
|
||||||
|
dataChan <- line[6:] // Remove "data: " prefix and send the JSON content
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle any scanning errors that occurred during stream processing
|
||||||
|
if errScanner := scanner.Err(); errScanner != nil {
|
||||||
|
// Send a 500 Internal Server Error for scanning failures
|
||||||
|
errChan <- &ErrorMessage{500, errScanner}
|
||||||
|
_ = stream.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the stream is properly closed to prevent resource leaks
|
||||||
|
_ = stream.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Return the channels immediately for asynchronous communication
|
||||||
|
// The caller can read from these channels while the goroutine processes the request
|
||||||
|
return dataChan, errChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendRawTokenCount handles a token count.
|
||||||
|
func (c *GeminiClient) SendRawTokenCount(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) {
|
||||||
|
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||||
|
model := modelResult.String()
|
||||||
|
modelName := model
|
||||||
|
for {
|
||||||
|
if c.isModelQuotaExceeded(modelName) {
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||||
|
modelName = c.getPreviewModel(model)
|
||||||
|
if modelName != "" {
|
||||||
|
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||||
|
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, &ErrorMessage{
|
||||||
|
StatusCode: 429,
|
||||||
|
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err := c.APIRequest(ctx, "countTokens", rawJSON, alt, false)
|
||||||
|
if err != nil {
|
||||||
|
if err.StatusCode == 429 {
|
||||||
|
now := time.Now()
|
||||||
|
c.modelQuotaExceeded[modelName] = &now
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
delete(c.modelQuotaExceeded, modelName)
|
||||||
|
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||||
|
if errReadAll != nil {
|
||||||
|
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||||
|
}
|
||||||
|
return bodyBytes, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendRawMessage handles a single conversational turn, including tool calls.
|
||||||
|
func (c *GeminiClient) SendRawMessage(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) {
|
||||||
|
if c.glAPIKey == "" {
|
||||||
|
rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID())
|
||||||
|
}
|
||||||
|
|
||||||
|
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||||
|
model := modelResult.String()
|
||||||
|
modelName := model
|
||||||
|
for {
|
||||||
|
if c.isModelQuotaExceeded(modelName) {
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||||
|
modelName = c.getPreviewModel(model)
|
||||||
|
if modelName != "" {
|
||||||
|
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||||
|
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, &ErrorMessage{
|
||||||
|
StatusCode: 429,
|
||||||
|
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
respBody, err := c.APIRequest(ctx, "generateContent", rawJSON, alt, false)
|
||||||
|
if err != nil {
|
||||||
|
if err.StatusCode == 429 {
|
||||||
|
now := time.Now()
|
||||||
|
c.modelQuotaExceeded[modelName] = &now
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
delete(c.modelQuotaExceeded, modelName)
|
||||||
|
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||||
|
if errReadAll != nil {
|
||||||
|
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||||
|
}
|
||||||
|
return bodyBytes, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SendRawMessageStream handles a single conversational turn, including tool calls.
|
||||||
|
func (c *GeminiClient) SendRawMessageStream(ctx context.Context, rawJSON []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) {
|
||||||
|
dataTag := []byte("data: ")
|
||||||
|
errChan := make(chan *ErrorMessage)
|
||||||
|
dataChan := make(chan []byte)
|
||||||
|
go func() {
|
||||||
|
defer close(errChan)
|
||||||
|
defer close(dataChan)
|
||||||
|
|
||||||
|
if c.glAPIKey == "" {
|
||||||
|
rawJSON, _ = sjson.SetBytes(rawJSON, "project", c.GetProjectID())
|
||||||
|
}
|
||||||
|
|
||||||
|
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||||
|
model := modelResult.String()
|
||||||
|
modelName := model
|
||||||
|
var stream io.ReadCloser
|
||||||
|
for {
|
||||||
|
if c.isModelQuotaExceeded(modelName) {
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||||
|
modelName = c.getPreviewModel(model)
|
||||||
|
if modelName != "" {
|
||||||
|
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
|
||||||
|
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errChan <- &ErrorMessage{
|
||||||
|
StatusCode: 429,
|
||||||
|
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var err *ErrorMessage
|
||||||
|
stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJSON, alt, true)
|
||||||
|
if err != nil {
|
||||||
|
if err.StatusCode == 429 {
|
||||||
|
now := time.Now()
|
||||||
|
c.modelQuotaExceeded[modelName] = &now
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(c.modelQuotaExceeded, modelName)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if alt == "" {
|
||||||
|
scanner := bufio.NewScanner(stream)
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := scanner.Bytes()
|
||||||
|
if bytes.HasPrefix(line, dataTag) {
|
||||||
|
dataChan <- line[6:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if errScanner := scanner.Err(); errScanner != nil {
|
||||||
|
errChan <- &ErrorMessage{500, errScanner}
|
||||||
|
_ = stream.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
data, err := io.ReadAll(stream)
|
||||||
|
if err != nil {
|
||||||
|
errChan <- &ErrorMessage{500, err}
|
||||||
|
_ = stream.Close()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
dataChan <- data
|
||||||
|
}
|
||||||
|
_ = stream.Close()
|
||||||
|
|
||||||
|
}()
|
||||||
|
|
||||||
|
return dataChan, errChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// isModelQuotaExceeded checks if the specified model has exceeded its quota
|
||||||
|
// within the last 30 minutes.
|
||||||
|
func (c *GeminiClient) isModelQuotaExceeded(model string) bool {
|
||||||
|
if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
|
||||||
|
duration := time.Now().Sub(*lastExceededTime)
|
||||||
|
if duration > 30*time.Minute {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// getPreviewModel returns an available preview model for the given base model,
|
||||||
|
// or an empty string if no preview models are available or all are quota exceeded.
|
||||||
|
func (c *GeminiClient) getPreviewModel(model string) string {
|
||||||
|
if models, hasKey := previewModels[model]; hasKey {
|
||||||
|
for i := 0; i < len(models); i++ {
|
||||||
|
if !c.isModelQuotaExceeded(models[i]) {
|
||||||
|
return models[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsModelQuotaExceeded returns true if the specified model has exceeded its quota
|
||||||
|
// and no fallback options are available.
|
||||||
|
func (c *GeminiClient) IsModelQuotaExceeded(model string) bool {
|
||||||
|
if c.isModelQuotaExceeded(model) {
|
||||||
|
if c.cfg.QuotaExceeded.SwitchPreviewModel {
|
||||||
|
return c.getPreviewModel(model) == ""
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckCloudAPIIsEnabled sends a simple test request to the API to verify
|
||||||
|
// that the Cloud AI API is enabled for the user's project. It provides
|
||||||
|
// an activation URL if the API is disabled.
|
||||||
|
func (c *GeminiClient) CheckCloudAPIIsEnabled() (bool, error) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
defer func() {
|
||||||
|
c.RequestMutex.Unlock()
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
c.RequestMutex.Lock()
|
||||||
|
|
||||||
|
// A simple request to test the API endpoint.
|
||||||
|
requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.(*geminiAuth.GeminiTokenStorage).ProjectID)
|
||||||
|
|
||||||
|
stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), "", true)
|
||||||
|
if err != nil {
|
||||||
|
// If a 403 Forbidden error occurs, it likely means the API is not enabled.
|
||||||
|
if err.StatusCode == 403 {
|
||||||
|
errJSON := err.Error.Error()
|
||||||
|
// Check for a specific error code and extract the activation URL.
|
||||||
|
if gjson.Get(errJSON, "0.error.code").Int() == 403 {
|
||||||
|
activationURL := gjson.Get(errJSON, "0.error.details.0.metadata.activationUrl").String()
|
||||||
|
if activationURL != "" {
|
||||||
|
log.Warnf(
|
||||||
|
"\n\nPlease activate your account with this url:\n\n%s\n\n And execute this command again:\n%s --login --project_id %s",
|
||||||
|
activationURL,
|
||||||
|
os.Args[0],
|
||||||
|
c.tokenStorage.(*geminiAuth.GeminiTokenStorage).ProjectID,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Warnf("\n\nPlease copy this message and create an issue.\n\n%s\n\n", errJSON)
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return false, err.Error
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = stream.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// We only need to know if the request was successful, so we can drain the stream.
|
||||||
|
scanner := bufio.NewScanner(stream)
|
||||||
|
for scanner.Scan() {
|
||||||
|
// Do nothing, just consume the stream.
|
||||||
|
}
|
||||||
|
|
||||||
|
return scanner.Err() == nil, scanner.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetProjectList fetches a list of Google Cloud projects accessible by the user.
|
||||||
|
func (c *GeminiClient) GetProjectList(ctx context.Context) (*GCPProject, error) {
|
||||||
|
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("could not create project list request: %v", err)
|
||||||
|
}
|
||||||
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||||
|
|
||||||
|
resp, err := c.httpClient.Do(req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to execute project list request: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = resp.Body.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("project list request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
var project GCPProject
|
||||||
|
if err = json.NewDecoder(resp.Body).Decode(&project); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to unmarshal project list: %w", err)
|
||||||
|
}
|
||||||
|
return &project, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SaveTokenToFile serializes the client's current token storage to a JSON file.
|
||||||
|
// The filename is constructed from the user's email and project ID.
|
||||||
|
func (c *GeminiClient) SaveTokenToFile() error {
|
||||||
|
fileName := filepath.Join(c.cfg.AuthDir, fmt.Sprintf("%s-%s.json", c.tokenStorage.(*geminiAuth.GeminiTokenStorage).Email, c.tokenStorage.(*geminiAuth.GeminiTokenStorage).ProjectID))
|
||||||
|
log.Infof("Saving credentials to %s", fileName)
|
||||||
|
return c.tokenStorage.SaveTokenToFile(fileName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getClientMetadata returns a map of metadata about the client environment,
|
||||||
|
// such as IDE type, platform, and plugin version.
|
||||||
|
func (c *GeminiClient) getClientMetadata() map[string]string {
|
||||||
|
return map[string]string{
|
||||||
|
"ideType": "IDE_UNSPECIFIED",
|
||||||
|
"platform": "PLATFORM_UNSPECIFIED",
|
||||||
|
"pluginType": "GEMINI",
|
||||||
|
// "pluginVersion": pluginVersion,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// getClientMetadataString returns the client metadata as a single,
|
||||||
|
// comma-separated string, which is required for the 'GeminiClient-Metadata' header.
|
||||||
|
func (c *GeminiClient) getClientMetadataString() string {
|
||||||
|
md := c.getClientMetadata()
|
||||||
|
parts := make([]string, 0, len(md))
|
||||||
|
for k, v := range md {
|
||||||
|
parts = append(parts, fmt.Sprintf("%s=%s", k, v))
|
||||||
|
}
|
||||||
|
return strings.Join(parts, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserAgent constructs the User-Agent string for HTTP requests.
|
||||||
|
func (c *GeminiClient) GetUserAgent() string {
|
||||||
|
// return fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH)
|
||||||
|
return "google-api-nodejs-client/9.15.1"
|
||||||
|
}
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
package client

import "time"

// ErrorMessage encapsulates an error with an associated HTTP status code.
type ErrorMessage struct {
	StatusCode int   // HTTP status code to surface to the caller
	Error      error // underlying error detail
}

// GCPProject represents the response structure for a Google Cloud project list request.
type GCPProject struct {
	Projects []GCPProjectProjects `json:"projects"`
}

// GCPProjectLabels defines the labels associated with a GCP project.
type GCPProjectLabels struct {
	GenerativeLanguage string `json:"generative-language"`
}

// GCPProjectProjects contains details about a single Google Cloud project.
type GCPProjectProjects struct {
	ProjectNumber  string           `json:"projectNumber"`
	ProjectID      string           `json:"projectId"`
	LifecycleState string           `json:"lifecycleState"`
	Name           string           `json:"name"`
	Labels         GCPProjectLabels `json:"labels"`
	CreateTime     time.Time        `json:"createTime"`
}

// Content represents a single message in a conversation, with a role and parts.
type Content struct {
	Role  string `json:"role"`
	Parts []Part `json:"parts"`
}

// Part represents a distinct piece of content within a message, which can be
// text, inline data (like an image), a function call, or a function response.
// Exactly one of the optional fields is expected to be populated per part.
type Part struct {
	Text             string            `json:"text,omitempty"`
	InlineData       *InlineData       `json:"inlineData,omitempty"`
	FunctionCall     *FunctionCall     `json:"functionCall,omitempty"`
	FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
}

// InlineData represents base64-encoded data with its MIME type.
type InlineData struct {
	MimeType string `json:"mime_type,omitempty"`
	Data     string `json:"data,omitempty"` // base64-encoded payload
}

// FunctionCall represents a tool call requested by the model, including the
// function name and its arguments.
type FunctionCall struct {
	Name string                 `json:"name"`
	Args map[string]interface{} `json:"args"`
}

// FunctionResponse represents the result of a tool execution, sent back to the model.
type FunctionResponse struct {
	Name     string                 `json:"name"`
	Response map[string]interface{} `json:"response"`
}

// GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint.
type GenerateContentRequest struct {
	SystemInstruction *Content          `json:"systemInstruction,omitempty"`
	Contents          []Content         `json:"contents"`
	Tools             []ToolDeclaration `json:"tools,omitempty"`
	// GenerationConfig is embedded but serialized under its own key.
	GenerationConfig `json:"generationConfig"`
}

// GenerationConfig defines parameters that control the model's generation behavior.
type GenerationConfig struct {
	ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"`
	Temperature    float64                        `json:"temperature,omitempty"`
	TopP           float64                        `json:"topP,omitempty"`
	TopK           float64                        `json:"topK,omitempty"`
}

// GenerationConfigThinkingConfig specifies configuration for the model's "thinking" process.
type GenerationConfigThinkingConfig struct {
	// IncludeThoughts determines whether the model should output its reasoning process.
	IncludeThoughts bool `json:"include_thoughts,omitempty"`
}

// ToolDeclaration defines the structure for declaring tools (like functions)
// that the model can call.
type ToolDeclaration struct {
	FunctionDeclarations []interface{} `json:"functionDeclarations"`
}
|
|
||||||
@@ -5,27 +5,33 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
"os"
|
||||||
|
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/auth/gemini"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"os"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// DoLogin handles the entire user login and setup process.
|
// DoLogin handles the entire user login and setup process.
|
||||||
// It authenticates the user, sets up the user's project, checks API enablement,
|
// It authenticates the user, sets up the user's project, checks API enablement,
|
||||||
// and saves the token for future use.
|
// and saves the token for future use.
|
||||||
func DoLogin(cfg *config.Config, projectID string) {
|
func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
var ts auth.TokenStorage
|
var ts gemini.GeminiTokenStorage
|
||||||
if projectID != "" {
|
if projectID != "" {
|
||||||
ts.ProjectID = projectID
|
ts.ProjectID = projectID
|
||||||
}
|
}
|
||||||
|
|
||||||
// Initialize an authenticated HTTP client. This will trigger the OAuth flow if necessary.
|
// Initialize an authenticated HTTP client. This will trigger the OAuth flow if necessary.
|
||||||
clientCtx := context.Background()
|
clientCtx := context.Background()
|
||||||
log.Info("Initializing authentication...")
|
log.Info("Initializing Google authentication...")
|
||||||
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
geminiAuth := gemini.NewGeminiAuth()
|
||||||
|
httpClient, errGetClient := geminiAuth.GetAuthenticatedClient(clientCtx, &ts, cfg, options.NoBrowser)
|
||||||
if errGetClient != nil {
|
if errGetClient != nil {
|
||||||
log.Fatalf("failed to get authenticated client: %v", errGetClient)
|
log.Fatalf("failed to get authenticated client: %v", errGetClient)
|
||||||
return
|
return
|
||||||
@@ -33,7 +39,7 @@ func DoLogin(cfg *config.Config, projectID string) {
|
|||||||
log.Info("Authentication successful.")
|
log.Info("Authentication successful.")
|
||||||
|
|
||||||
// Initialize the API client.
|
// Initialize the API client.
|
||||||
cliClient := client.NewClient(httpClient, &ts, cfg)
|
cliClient := client.NewGeminiClient(httpClient, &ts, cfg)
|
||||||
|
|
||||||
// Perform the user setup process.
|
// Perform the user setup process.
|
||||||
err = cliClient.SetupUser(clientCtx, ts.Email, projectID)
|
err = cliClient.SetupUser(clientCtx, ts.Email, projectID)
|
||||||
|
|||||||
173
internal/cmd/openai_login.go
Normal file
173
internal/cmd/openai_login.go
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/auth/codex"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/browser"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LoginOptions contains options for login
|
||||||
|
type LoginOptions struct {
|
||||||
|
NoBrowser bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// DoCodexLogin handles the Codex OAuth login process
|
||||||
|
func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
|
||||||
|
if options == nil {
|
||||||
|
options = &LoginOptions{}
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
log.Info("Initializing Codex authentication...")
|
||||||
|
|
||||||
|
// Generate PKCE codes
|
||||||
|
pkceCodes, err := codex.GeneratePKCECodes()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to generate PKCE codes: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate random state parameter
|
||||||
|
state, err := generateRandomState()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to generate state parameter: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Initialize OAuth server
|
||||||
|
oauthServer := codex.NewOAuthServer(1455)
|
||||||
|
|
||||||
|
// Start OAuth callback server
|
||||||
|
if err = oauthServer.Start(ctx); err != nil {
|
||||||
|
if strings.Contains(err.Error(), "already in use") {
|
||||||
|
authErr := codex.NewAuthenticationError(codex.ErrPortInUse, err)
|
||||||
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||||
|
os.Exit(13) // Exit code 13 for port-in-use error
|
||||||
|
}
|
||||||
|
authErr := codex.NewAuthenticationError(codex.ErrServerStartFailed, err)
|
||||||
|
log.Fatalf("Failed to start OAuth callback server: %v", authErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err = oauthServer.Stop(ctx); err != nil {
|
||||||
|
log.Warnf("Failed to stop OAuth server: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Initialize Codex auth service
|
||||||
|
openaiAuth := codex.NewCodexAuth(cfg)
|
||||||
|
|
||||||
|
// Generate authorization URL
|
||||||
|
authURL, err := openaiAuth.GenerateAuthURL(state, pkceCodes)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to generate authorization URL: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Open browser or display URL
|
||||||
|
if !options.NoBrowser {
|
||||||
|
log.Info("Opening browser for authentication...")
|
||||||
|
|
||||||
|
// Check if browser is available
|
||||||
|
if !browser.IsAvailable() {
|
||||||
|
log.Warn("No browser available on this system")
|
||||||
|
log.Infof("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||||
|
} else {
|
||||||
|
if err = browser.OpenURL(authURL); err != nil {
|
||||||
|
authErr := codex.NewAuthenticationError(codex.ErrBrowserOpenFailed, err)
|
||||||
|
log.Warn(codex.GetUserFriendlyMessage(authErr))
|
||||||
|
log.Infof("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||||
|
|
||||||
|
// Log platform info for debugging
|
||||||
|
platformInfo := browser.GetPlatformInfo()
|
||||||
|
log.Debugf("Browser platform info: %+v", platformInfo)
|
||||||
|
} else {
|
||||||
|
log.Debug("Browser opened successfully")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
log.Infof("Please open this URL in your browser:\n\n%s\n", authURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Waiting for authentication callback...")
|
||||||
|
|
||||||
|
// Wait for OAuth callback
|
||||||
|
result, err := oauthServer.WaitForCallback(5 * time.Minute)
|
||||||
|
if err != nil {
|
||||||
|
if strings.Contains(err.Error(), "timeout") {
|
||||||
|
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
|
||||||
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||||
|
} else {
|
||||||
|
log.Errorf("Authentication failed: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if result.Error != "" {
|
||||||
|
oauthErr := codex.NewOAuthError(result.Error, "", http.StatusBadRequest)
|
||||||
|
log.Error(codex.GetUserFriendlyMessage(oauthErr))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate state parameter
|
||||||
|
if result.State != state {
|
||||||
|
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, result.State))
|
||||||
|
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Authorization code received, exchanging for tokens...")
|
||||||
|
|
||||||
|
// Exchange authorization code for tokens
|
||||||
|
authBundle, err := openaiAuth.ExchangeCodeForTokens(ctx, result.Code, pkceCodes)
|
||||||
|
if err != nil {
|
||||||
|
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
|
||||||
|
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||||
|
log.Debug("This may be due to network issues or invalid authorization code")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create token storage
|
||||||
|
tokenStorage := openaiAuth.CreateTokenStorage(authBundle)
|
||||||
|
|
||||||
|
// Initialize Codex client
|
||||||
|
openaiClient, err := client.NewCodexClient(cfg, tokenStorage)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to initialize Codex client: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save token storage
|
||||||
|
if err = openaiClient.SaveTokenToFile(); err != nil {
|
||||||
|
log.Fatalf("Failed to save authentication tokens: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("Authentication successful!")
|
||||||
|
if authBundle.APIKey != "" {
|
||||||
|
log.Info("API key obtained and saved")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("You can now use Codex services through this CLI")
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateRandomState generates a cryptographically secure random state parameter
|
||||||
|
func generateRandomState() (string, error) {
|
||||||
|
bytes := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(bytes); err != nil {
|
||||||
|
return "", fmt.Errorf("failed to generate random bytes: %w", err)
|
||||||
|
}
|
||||||
|
return hex.EncodeToString(bytes), nil
|
||||||
|
}
|
||||||
@@ -8,29 +8,37 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/watcher"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/api"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/auth/codex"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/auth/gemini"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/watcher"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// StartService initializes and starts the main API proxy service.
|
// StartService initializes and starts the main API proxy service.
|
||||||
// It loads all available authentication tokens, creates a pool of clients,
|
// It loads all available authentication tokens, creates a pool of clients,
|
||||||
// starts the API server, and handles graceful shutdown signals.
|
// starts the API server, and handles graceful shutdown signals.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - cfg: The application configuration
|
||||||
|
// - configPath: The path to the configuration file
|
||||||
func StartService(cfg *config.Config, configPath string) {
|
func StartService(cfg *config.Config, configPath string) {
|
||||||
// Create a pool of API clients, one for each token file found.
|
// Create a pool of API clients, one for each token file found.
|
||||||
cliClients := make([]*client.Client, 0)
|
cliClients := make([]client.Client, 0)
|
||||||
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
|
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -39,31 +47,51 @@ func StartService(cfg *config.Config, configPath string) {
|
|||||||
// Process only JSON files in the auth directory.
|
// Process only JSON files in the auth directory.
|
||||||
if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") {
|
if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") {
|
||||||
log.Debugf("Loading token from: %s", path)
|
log.Debugf("Loading token from: %s", path)
|
||||||
f, errOpen := os.Open(path)
|
data, errReadFile := os.ReadFile(path)
|
||||||
if errOpen != nil {
|
if errReadFile != nil {
|
||||||
return errOpen
|
return errReadFile
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
_ = f.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Decode the token storage file.
|
tokenType := "gemini"
|
||||||
var ts auth.TokenStorage
|
typeResult := gjson.GetBytes(data, "type")
|
||||||
if err = json.NewDecoder(f).Decode(&ts); err == nil {
|
if typeResult.Exists() {
|
||||||
// For each valid token, create an authenticated client.
|
tokenType = typeResult.String()
|
||||||
clientCtx := context.Background()
|
}
|
||||||
log.Info("Initializing authentication for token...")
|
|
||||||
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
clientCtx := context.Background()
|
||||||
if errGetClient != nil {
|
|
||||||
// Log fatal will exit, but we return the error for completeness.
|
if tokenType == "gemini" {
|
||||||
log.Fatalf("failed to get authenticated client for token %s: %v", path, errGetClient)
|
var ts gemini.GeminiTokenStorage
|
||||||
return errGetClient
|
if err = json.Unmarshal(data, &ts); err == nil {
|
||||||
|
// For each valid token, create an authenticated client.
|
||||||
|
log.Info("Initializing gemini authentication for token...")
|
||||||
|
geminiAuth := gemini.NewGeminiAuth()
|
||||||
|
httpClient, errGetClient := geminiAuth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
||||||
|
if errGetClient != nil {
|
||||||
|
// Log fatal will exit, but we return the error for completeness.
|
||||||
|
log.Fatalf("failed to get authenticated client for token %s: %v", path, errGetClient)
|
||||||
|
return errGetClient
|
||||||
|
}
|
||||||
|
log.Info("Authentication successful.")
|
||||||
|
|
||||||
|
// Add the new client to the pool.
|
||||||
|
cliClient := client.NewGeminiClient(httpClient, &ts, cfg)
|
||||||
|
cliClients = append(cliClients, cliClient)
|
||||||
|
}
|
||||||
|
} else if tokenType == "codex" {
|
||||||
|
var ts codex.CodexTokenStorage
|
||||||
|
if err = json.Unmarshal(data, &ts); err == nil {
|
||||||
|
// For each valid token, create an authenticated client.
|
||||||
|
log.Info("Initializing codex authentication for token...")
|
||||||
|
codexClient, errGetClient := client.NewCodexClient(cfg, &ts)
|
||||||
|
if errGetClient != nil {
|
||||||
|
// Log fatal will exit, but we return the error for completeness.
|
||||||
|
log.Fatalf("failed to get authenticated client for token %s: %v", path, errGetClient)
|
||||||
|
return errGetClient
|
||||||
|
}
|
||||||
|
log.Info("Authentication successful.")
|
||||||
|
cliClients = append(cliClients, codexClient)
|
||||||
}
|
}
|
||||||
log.Info("Authentication successful.")
|
|
||||||
|
|
||||||
// Add the new client to the pool.
|
|
||||||
cliClient := client.NewClient(httpClient, &ts, cfg)
|
|
||||||
cliClients = append(cliClients, cliClient)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -74,13 +102,10 @@ func StartService(cfg *config.Config, configPath string) {
|
|||||||
|
|
||||||
if len(cfg.GlAPIKey) > 0 {
|
if len(cfg.GlAPIKey) > 0 {
|
||||||
for i := 0; i < len(cfg.GlAPIKey); i++ {
|
for i := 0; i < len(cfg.GlAPIKey); i++ {
|
||||||
httpClient, errSetProxy := util.SetProxy(cfg, &http.Client{})
|
httpClient := util.SetProxy(cfg, &http.Client{})
|
||||||
if errSetProxy != nil {
|
|
||||||
log.Fatalf("set proxy failed: %v", errSetProxy)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("Initializing with Generative Language API key...")
|
log.Debug("Initializing with Generative Language API key...")
|
||||||
cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
|
cliClient := client.NewGeminiClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
|
||||||
cliClients = append(cliClients, cliClient)
|
cliClients = append(cliClients, cliClient)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -101,7 +126,7 @@ func StartService(cfg *config.Config, configPath string) {
|
|||||||
log.Info("API server started successfully")
|
log.Info("API server started successfully")
|
||||||
|
|
||||||
// Setup file watcher for config and auth directory changes
|
// Setup file watcher for config and auth directory changes
|
||||||
fileWatcher, errNewWatcher := watcher.NewWatcher(configPath, cfg.AuthDir, func(newClients []*client.Client, newCfg *config.Config) {
|
fileWatcher, errNewWatcher := watcher.NewWatcher(configPath, cfg.AuthDir, func(newClients []client.Client, newCfg *config.Config) {
|
||||||
// Update the API server with new clients and configuration
|
// Update the API server with new clients and configuration
|
||||||
apiServer.UpdateClients(newClients, newCfg)
|
apiServer.UpdateClients(newClients, newCfg)
|
||||||
})
|
})
|
||||||
@@ -132,12 +157,50 @@ func StartService(cfg *config.Config, configPath string) {
|
|||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
// Background token refresh ticker for Codex clients
|
||||||
|
ctxRefresh, cancelRefresh := context.WithCancel(context.Background())
|
||||||
|
var wgRefresh sync.WaitGroup
|
||||||
|
wgRefresh.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer wgRefresh.Done()
|
||||||
|
ticker := time.NewTicker(1 * time.Hour)
|
||||||
|
defer ticker.Stop()
|
||||||
|
checkAndRefresh := func() {
|
||||||
|
for i := 0; i < len(cliClients); i++ {
|
||||||
|
if codexCli, ok := cliClients[i].(*client.CodexClient); ok {
|
||||||
|
ts := codexCli.TokenStorage().(*codex.CodexTokenStorage)
|
||||||
|
if ts != nil && ts.Expire != "" {
|
||||||
|
if expTime, errParse := time.Parse(time.RFC3339, ts.Expire); errParse == nil {
|
||||||
|
if time.Until(expTime) <= 5*24*time.Hour {
|
||||||
|
log.Debugf("refreshing codex tokens for %s", codexCli.GetEmail())
|
||||||
|
_ = codexCli.RefreshTokens(ctxRefresh)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Initial check on start
|
||||||
|
checkAndRefresh()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctxRefresh.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
checkAndRefresh()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Main loop to wait for shutdown signal.
|
// Main loop to wait for shutdown signal.
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-sigChan:
|
case <-sigChan:
|
||||||
log.Debugf("Received shutdown signal. Cleaning up...")
|
log.Debugf("Received shutdown signal. Cleaning up...")
|
||||||
|
|
||||||
|
cancelRefresh()
|
||||||
|
wgRefresh.Wait()
|
||||||
|
|
||||||
// Create a context with a timeout for the shutdown process.
|
// Create a context with a timeout for the shutdown process.
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
_ = cancel
|
_ = cancel
|
||||||
@@ -150,8 +213,6 @@ func StartService(cfg *config.Config, configPath string) {
|
|||||||
log.Debugf("Cleanup completed. Exiting...")
|
log.Debugf("Cleanup completed. Exiting...")
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
case <-time.After(5 * time.Second):
|
case <-time.After(5 * time.Second):
|
||||||
// This case is currently empty and acts as a periodic check.
|
|
||||||
// It could be used for periodic tasks in the future.
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,26 +6,36 @@ package config
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
"os"
|
"os"
|
||||||
|
|
||||||
|
"gopkg.in/yaml.v3"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config represents the application's configuration, loaded from a YAML file.
|
// Config represents the application's configuration, loaded from a YAML file.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
// Port is the network port on which the API server will listen.
|
// Port is the network port on which the API server will listen.
|
||||||
Port int `yaml:"port"`
|
Port int `yaml:"port"`
|
||||||
|
|
||||||
// AuthDir is the directory where authentication token files are stored.
|
// AuthDir is the directory where authentication token files are stored.
|
||||||
AuthDir string `yaml:"auth-dir"`
|
AuthDir string `yaml:"auth-dir"`
|
||||||
|
|
||||||
// Debug enables or disables debug-level logging and other debug features.
|
// Debug enables or disables debug-level logging and other debug features.
|
||||||
Debug bool `yaml:"debug"`
|
Debug bool `yaml:"debug"`
|
||||||
|
|
||||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||||
ProxyURL string `yaml:"proxy-url"`
|
ProxyURL string `yaml:"proxy-url"`
|
||||||
|
|
||||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||||
APIKeys []string `yaml:"api-keys"`
|
APIKeys []string `yaml:"api-keys"`
|
||||||
|
|
||||||
// QuotaExceeded defines the behavior when a quota is exceeded.
|
// QuotaExceeded defines the behavior when a quota is exceeded.
|
||||||
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded"`
|
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded"`
|
||||||
|
|
||||||
// GlAPIKey is the API key for the generative language API.
|
// GlAPIKey is the API key for the generative language API.
|
||||||
GlAPIKey []string `yaml:"generative-language-api-key"`
|
GlAPIKey []string `yaml:"generative-language-api-key"`
|
||||||
|
|
||||||
|
// RequestLog enables or disables detailed request logging functionality.
|
||||||
|
RequestLog bool `yaml:"request-log"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// QuotaExceeded defines the behavior when API quota limits are exceeded.
|
// QuotaExceeded defines the behavior when API quota limits are exceeded.
|
||||||
@@ -33,12 +43,21 @@ type Config struct {
|
|||||||
type QuotaExceeded struct {
|
type QuotaExceeded struct {
|
||||||
// SwitchProject indicates whether to automatically switch to another project when a quota is exceeded.
|
// SwitchProject indicates whether to automatically switch to another project when a quota is exceeded.
|
||||||
SwitchProject bool `yaml:"switch-project"`
|
SwitchProject bool `yaml:"switch-project"`
|
||||||
|
|
||||||
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
|
// SwitchPreviewModel indicates whether to automatically switch to a preview model when a quota is exceeded.
|
||||||
SwitchPreviewModel bool `yaml:"switch-preview-model"`
|
SwitchPreviewModel bool `yaml:"switch-preview-model"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadConfig reads a YAML configuration file from the given path,
|
// LoadConfig reads a YAML configuration file from the given path,
|
||||||
// unmarshals it into a Config struct, and returns it.
|
// unmarshals it into a Config struct, applies environment variable overrides,
|
||||||
|
// and returns it.
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - configFile: The path to the YAML configuration file
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - *Config: The loaded configuration
|
||||||
|
// - error: An error if the configuration could not be loaded
|
||||||
func LoadConfig(configFile string) (*Config, error) {
|
func LoadConfig(configFile string) (*Config, error) {
|
||||||
// Read the entire configuration file into memory.
|
// Read the entire configuration file into memory.
|
||||||
data, err := os.ReadFile(configFile)
|
data, err := os.ReadFile(configFile)
|
||||||
|
|||||||
390
internal/logging/request_logger.go
Normal file
390
internal/logging/request_logger.go
Normal file
@@ -0,0 +1,390 @@
|
|||||||
|
// Package logging provides request logging functionality for the CLI Proxy API server.
|
||||||
|
// It handles capturing and storing detailed HTTP request and response data when enabled
|
||||||
|
// through configuration, supporting both regular and streaming responses.
|
||||||
|
package logging
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/flate"
|
||||||
|
"compress/gzip"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"regexp"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RequestLogger defines the interface for logging HTTP requests and responses.
|
||||||
|
type RequestLogger interface {
|
||||||
|
// LogRequest logs a complete non-streaming request/response cycle
|
||||||
|
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response []byte) error
|
||||||
|
|
||||||
|
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks
|
||||||
|
LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error)
|
||||||
|
|
||||||
|
// IsEnabled returns whether request logging is currently enabled
|
||||||
|
IsEnabled() bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamingLogWriter handles real-time logging of streaming response chunks.
|
||||||
|
type StreamingLogWriter interface {
|
||||||
|
// WriteChunkAsync writes a response chunk asynchronously (non-blocking)
|
||||||
|
WriteChunkAsync(chunk []byte)
|
||||||
|
|
||||||
|
// WriteStatus writes the response status and headers to the log
|
||||||
|
WriteStatus(status int, headers map[string][]string) error
|
||||||
|
|
||||||
|
// Close finalizes the log file and cleans up resources
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileRequestLogger implements RequestLogger using file-based storage.
|
||||||
|
type FileRequestLogger struct {
|
||||||
|
enabled bool
|
||||||
|
logsDir string
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFileRequestLogger creates a new file-based request logger.
|
||||||
|
func NewFileRequestLogger(enabled bool, logsDir string) *FileRequestLogger {
|
||||||
|
return &FileRequestLogger{
|
||||||
|
enabled: enabled,
|
||||||
|
logsDir: logsDir,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEnabled returns whether request logging is currently enabled.
|
||||||
|
func (l *FileRequestLogger) IsEnabled() bool {
|
||||||
|
return l.enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogRequest logs a complete non-streaming request/response cycle to a file.
|
||||||
|
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response []byte) error {
|
||||||
|
if !l.enabled {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure logs directory exists
|
||||||
|
if err := l.ensureLogsDir(); err != nil {
|
||||||
|
return fmt.Errorf("failed to create logs directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate filename
|
||||||
|
filename := l.generateFilename(url)
|
||||||
|
filePath := filepath.Join(l.logsDir, filename)
|
||||||
|
|
||||||
|
// Decompress response if needed
|
||||||
|
decompressedResponse, err := l.decompressResponse(responseHeaders, response)
|
||||||
|
if err != nil {
|
||||||
|
// If decompression fails, log the error but continue with original response
|
||||||
|
decompressedResponse = append(response, []byte(fmt.Sprintf("\n[DECOMPRESSION ERROR: %v]", err))...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create log content
|
||||||
|
content := l.formatLogContent(url, method, requestHeaders, body, decompressedResponse, statusCode, responseHeaders)
|
||||||
|
|
||||||
|
// Write to file
|
||||||
|
if err := os.WriteFile(filePath, []byte(content), 0644); err != nil {
|
||||||
|
return fmt.Errorf("failed to write log file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogStreamingRequest initiates logging for a streaming request.
|
||||||
|
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error) {
|
||||||
|
if !l.enabled {
|
||||||
|
return &NoOpStreamingLogWriter{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure logs directory exists
|
||||||
|
if err := l.ensureLogsDir(); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create logs directory: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate filename
|
||||||
|
filename := l.generateFilename(url)
|
||||||
|
filePath := filepath.Join(l.logsDir, filename)
|
||||||
|
|
||||||
|
// Create and open file
|
||||||
|
file, err := os.Create(filePath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create log file: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write initial request information
|
||||||
|
requestInfo := l.formatRequestInfo(url, method, headers, body)
|
||||||
|
if _, err := file.WriteString(requestInfo); err != nil {
|
||||||
|
_ = file.Close()
|
||||||
|
return nil, fmt.Errorf("failed to write request info: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create streaming writer
|
||||||
|
writer := &FileStreamingLogWriter{
|
||||||
|
file: file,
|
||||||
|
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
|
||||||
|
closeChan: make(chan struct{}),
|
||||||
|
errorChan: make(chan error, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start async writer goroutine
|
||||||
|
go writer.asyncWriter()
|
||||||
|
|
||||||
|
return writer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ensureLogsDir creates the logs directory if it doesn't exist.
|
||||||
|
func (l *FileRequestLogger) ensureLogsDir() error {
|
||||||
|
if _, err := os.Stat(l.logsDir); os.IsNotExist(err) {
|
||||||
|
return os.MkdirAll(l.logsDir, 0755)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateFilename creates a sanitized filename from the URL path and current timestamp.
|
||||||
|
func (l *FileRequestLogger) generateFilename(url string) string {
|
||||||
|
// Extract path from URL
|
||||||
|
path := url
|
||||||
|
if strings.Contains(url, "?") {
|
||||||
|
path = strings.Split(url, "?")[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove leading slash
|
||||||
|
if strings.HasPrefix(path, "/") {
|
||||||
|
path = path[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sanitize path for filename
|
||||||
|
sanitized := l.sanitizeForFilename(path)
|
||||||
|
|
||||||
|
// Add timestamp
|
||||||
|
timestamp := time.Now().UnixNano()
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s-%d.log", sanitized, timestamp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// sanitizeForFilename replaces characters that are not safe for filenames.
|
||||||
|
func (l *FileRequestLogger) sanitizeForFilename(path string) string {
|
||||||
|
// Replace slashes with hyphens
|
||||||
|
sanitized := strings.ReplaceAll(path, "/", "-")
|
||||||
|
|
||||||
|
// Replace colons with hyphens
|
||||||
|
sanitized = strings.ReplaceAll(sanitized, ":", "-")
|
||||||
|
|
||||||
|
// Replace other problematic characters with hyphens
|
||||||
|
reg := regexp.MustCompile(`[<>:"|?*\s]`)
|
||||||
|
sanitized = reg.ReplaceAllString(sanitized, "-")
|
||||||
|
|
||||||
|
// Remove multiple consecutive hyphens
|
||||||
|
reg = regexp.MustCompile(`-+`)
|
||||||
|
sanitized = reg.ReplaceAllString(sanitized, "-")
|
||||||
|
|
||||||
|
// Remove leading/trailing hyphens
|
||||||
|
sanitized = strings.Trim(sanitized, "-")
|
||||||
|
|
||||||
|
// Handle empty result
|
||||||
|
if sanitized == "" {
|
||||||
|
sanitized = "root"
|
||||||
|
}
|
||||||
|
|
||||||
|
return sanitized
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatLogContent creates the complete log content for non-streaming requests.
|
||||||
|
func (l *FileRequestLogger) formatLogContent(url, method string, headers map[string][]string, body []byte, response []byte, status int, responseHeaders map[string][]string) string {
|
||||||
|
var content strings.Builder
|
||||||
|
|
||||||
|
// Request info
|
||||||
|
content.WriteString(l.formatRequestInfo(url, method, headers, body))
|
||||||
|
|
||||||
|
// Response section
|
||||||
|
content.WriteString("========================================\n")
|
||||||
|
content.WriteString("=== RESPONSE ===\n")
|
||||||
|
content.WriteString(fmt.Sprintf("Status: %d\n", status))
|
||||||
|
|
||||||
|
if responseHeaders != nil {
|
||||||
|
for key, values := range responseHeaders {
|
||||||
|
for _, value := range values {
|
||||||
|
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
content.WriteString("\n")
|
||||||
|
content.Write(response)
|
||||||
|
content.WriteString("\n")
|
||||||
|
|
||||||
|
return content.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// decompressResponse decompresses response data based on Content-Encoding header.
|
||||||
|
func (l *FileRequestLogger) decompressResponse(responseHeaders map[string][]string, response []byte) ([]byte, error) {
|
||||||
|
if responseHeaders == nil || len(response) == 0 {
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check Content-Encoding header
|
||||||
|
var contentEncoding string
|
||||||
|
for key, values := range responseHeaders {
|
||||||
|
if strings.ToLower(key) == "content-encoding" && len(values) > 0 {
|
||||||
|
contentEncoding = strings.ToLower(values[0])
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch contentEncoding {
|
||||||
|
case "gzip":
|
||||||
|
return l.decompressGzip(response)
|
||||||
|
case "deflate":
|
||||||
|
return l.decompressDeflate(response)
|
||||||
|
default:
|
||||||
|
// No compression or unsupported compression
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// decompressGzip decompresses gzip-encoded data.
|
||||||
|
func (l *FileRequestLogger) decompressGzip(data []byte) ([]byte, error) {
|
||||||
|
reader, err := gzip.NewReader(bytes.NewReader(data))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create gzip reader: %w", err)
|
||||||
|
}
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
decompressed, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decompress gzip data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return decompressed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// decompressDeflate decompresses deflate-encoded data.
|
||||||
|
func (l *FileRequestLogger) decompressDeflate(data []byte) ([]byte, error) {
|
||||||
|
reader := flate.NewReader(bytes.NewReader(data))
|
||||||
|
defer reader.Close()
|
||||||
|
|
||||||
|
decompressed, err := io.ReadAll(reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to decompress deflate data: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return decompressed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatRequestInfo creates the request information section of the log.
|
||||||
|
func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[string][]string, body []byte) string {
|
||||||
|
var content strings.Builder
|
||||||
|
|
||||||
|
content.WriteString("=== REQUEST INFO ===\n")
|
||||||
|
content.WriteString(fmt.Sprintf("URL: %s\n", url))
|
||||||
|
content.WriteString(fmt.Sprintf("Method: %s\n", method))
|
||||||
|
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
|
||||||
|
content.WriteString("\n")
|
||||||
|
|
||||||
|
content.WriteString("=== HEADERS ===\n")
|
||||||
|
for key, values := range headers {
|
||||||
|
for _, value := range values {
|
||||||
|
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
content.WriteString("\n")
|
||||||
|
|
||||||
|
content.WriteString("=== REQUEST BODY ===\n")
|
||||||
|
content.Write(body)
|
||||||
|
content.WriteString("\n\n")
|
||||||
|
|
||||||
|
return content.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs.
// Chunks are queued on chunkChan and written by a background goroutine
// (asyncWriter); closeChan signals that the goroutine has drained the queue
// and exited.
type FileStreamingLogWriter struct {
	file          *os.File      // destination log file; nil disables writing
	chunkChan     chan []byte   // queue of response chunks awaiting the async writer
	closeChan     chan struct{} // closed by asyncWriter when it finishes draining chunkChan
	errorChan     chan error    // NOTE(review): not read or written anywhere in the visible code — confirm whether it is still needed
	statusWritten bool          // guards WriteStatus so the response header section is emitted at most once
}
|
||||||
|
|
||||||
|
// WriteChunkAsync writes a response chunk asynchronously (non-blocking).
|
||||||
|
func (w *FileStreamingLogWriter) WriteChunkAsync(chunk []byte) {
|
||||||
|
if w.chunkChan == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Make a copy of the chunk to avoid data races
|
||||||
|
chunkCopy := make([]byte, len(chunk))
|
||||||
|
copy(chunkCopy, chunk)
|
||||||
|
|
||||||
|
// Non-blocking send
|
||||||
|
select {
|
||||||
|
case w.chunkChan <- chunkCopy:
|
||||||
|
default:
|
||||||
|
// Channel is full, skip this chunk to avoid blocking
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteStatus writes the response status and headers to the log.
|
||||||
|
func (w *FileStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error {
|
||||||
|
if w.file == nil || w.statusWritten {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var content strings.Builder
|
||||||
|
content.WriteString("========================================\n")
|
||||||
|
content.WriteString("=== RESPONSE ===\n")
|
||||||
|
content.WriteString(fmt.Sprintf("Status: %d\n", status))
|
||||||
|
|
||||||
|
for key, values := range headers {
|
||||||
|
for _, value := range values {
|
||||||
|
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
content.WriteString("\n")
|
||||||
|
|
||||||
|
_, err := w.file.WriteString(content.String())
|
||||||
|
if err == nil {
|
||||||
|
w.statusWritten = true
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close finalizes the log file and cleans up resources.
|
||||||
|
func (w *FileStreamingLogWriter) Close() error {
|
||||||
|
if w.chunkChan != nil {
|
||||||
|
close(w.chunkChan)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for async writer to finish
|
||||||
|
if w.closeChan != nil {
|
||||||
|
<-w.closeChan
|
||||||
|
w.chunkChan = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if w.file != nil {
|
||||||
|
return w.file.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// asyncWriter runs in a goroutine to handle async chunk writing.
|
||||||
|
func (w *FileStreamingLogWriter) asyncWriter() {
|
||||||
|
defer close(w.closeChan)
|
||||||
|
|
||||||
|
for chunk := range w.chunkChan {
|
||||||
|
if w.file != nil {
|
||||||
|
_, _ = w.file.Write(chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NoOpStreamingLogWriter is a no-operation implementation for when logging is disabled.
type NoOpStreamingLogWriter struct{}

// WriteChunkAsync discards the chunk.
func (w *NoOpStreamingLogWriter) WriteChunkAsync(chunk []byte) {}

// WriteStatus discards the status and headers and always succeeds.
func (w *NoOpStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error {
	return nil
}

// Close is a no-op and always succeeds.
func (w *NoOpStreamingLogWriter) Close() error { return nil }
|
||||||
6
internal/misc/codex_instructions.go
Normal file
6
internal/misc/codex_instructions.go
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
package misc

import _ "embed"

// CodexInstructions holds the contents of codex_instructions.txt, embedded at
// build time, and is injected as the "instructions" field of outgoing Codex
// (OpenAI Responses API) requests.
// NOTE(review): the translators pass this to sjson.SetRaw, which expects valid
// JSON — confirm the .txt file contains a JSON-encoded string literal.
//go:embed codex_instructions.txt
var CodexInstructions string
|
||||||
1
internal/misc/codex_instructions.txt
Normal file
1
internal/misc/codex_instructions.txt
Normal file
File diff suppressed because one or more lines are too long
@@ -1,7 +1,7 @@
|
|||||||
// Package translator provides data translation and format conversion utilities
|
// Package translator provides data translation and format conversion utilities
|
||||||
// for the CLI Proxy API. It includes MIME type mappings and other translation
|
// for the CLI Proxy API. It includes MIME type mappings and other translation
|
||||||
// functions used across different API endpoints.
|
// functions used across different API endpoints.
|
||||||
package translator
|
package misc
|
||||||
|
|
||||||
// MimeTypes is a comprehensive map of file extensions to their corresponding MIME types.
|
// MimeTypes is a comprehensive map of file extensions to their corresponding MIME types.
|
||||||
// This is used to identify the type of file being uploaded or processed.
|
// This is used to identify the type of file being uploaded or processed.
|
||||||
114
internal/translator/codex/claude/code/codex_cc_request.go
Normal file
114
internal/translator/codex/claude/code/codex_cc_request.go
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
// Package code provides request translation functionality for Claude API.
|
||||||
|
// It handles parsing and transforming Claude API requests into the internal client format,
|
||||||
|
// extracting model information, system instructions, message contents, and tool declarations.
|
||||||
|
// The package also performs JSON data cleaning and transformation to ensure compatibility
|
||||||
|
// between Claude API format and the internal client's expected format.
|
||||||
|
package code
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/misc"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PrepareClaudeRequest parses and transforms a Claude API request into internal client format.
|
||||||
|
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||||
|
// from the raw JSON request and returns them in the format expected by the internal client.
|
||||||
|
func ConvertClaudeCodeRequestToCodex(rawJSON []byte) string {
|
||||||
|
template := `{"model":"","instructions":"","input":[]}`
|
||||||
|
|
||||||
|
instructions := misc.CodexInstructions
|
||||||
|
template, _ = sjson.SetRaw(template, "instructions", instructions)
|
||||||
|
|
||||||
|
rootResult := gjson.ParseBytes(rawJSON)
|
||||||
|
modelResult := rootResult.Get("model")
|
||||||
|
template, _ = sjson.Set(template, "model", modelResult.String())
|
||||||
|
|
||||||
|
systemsResult := rootResult.Get("system")
|
||||||
|
if systemsResult.IsArray() {
|
||||||
|
systemResults := systemsResult.Array()
|
||||||
|
message := `{"type":"message","role":"user","content":[]}`
|
||||||
|
for i := 0; i < len(systemResults); i++ {
|
||||||
|
systemResult := systemResults[i]
|
||||||
|
systemTypeResult := systemResult.Get("type")
|
||||||
|
if systemTypeResult.String() == "text" {
|
||||||
|
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", i), "input_text")
|
||||||
|
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", i), systemResult.Get("text").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
template, _ = sjson.SetRaw(template, "input.-1", message)
|
||||||
|
}
|
||||||
|
|
||||||
|
messagesResult := rootResult.Get("messages")
|
||||||
|
if messagesResult.IsArray() {
|
||||||
|
messageResults := messagesResult.Array()
|
||||||
|
|
||||||
|
for i := 0; i < len(messageResults); i++ {
|
||||||
|
messageResult := messageResults[i]
|
||||||
|
|
||||||
|
messageContentsResult := messageResult.Get("content")
|
||||||
|
if messageContentsResult.IsArray() {
|
||||||
|
messageContentResults := messageContentsResult.Array()
|
||||||
|
for j := 0; j < len(messageContentResults); j++ {
|
||||||
|
messageContentResult := messageContentResults[j]
|
||||||
|
messageContentTypeResult := messageContentResult.Get("type")
|
||||||
|
if messageContentTypeResult.String() == "text" {
|
||||||
|
message := `{"type": "message","role":"","content":[]}`
|
||||||
|
messageRole := messageResult.Get("role").String()
|
||||||
|
message, _ = sjson.Set(message, "role", messageRole)
|
||||||
|
|
||||||
|
partType := "input_text"
|
||||||
|
if messageRole == "assistant" {
|
||||||
|
partType = "output_text"
|
||||||
|
}
|
||||||
|
|
||||||
|
currentIndex := len(gjson.Get(message, "content").Array())
|
||||||
|
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.type", currentIndex), partType)
|
||||||
|
message, _ = sjson.Set(message, fmt.Sprintf("content.%d.text", currentIndex), messageContentResult.Get("text").String())
|
||||||
|
template, _ = sjson.SetRaw(template, "input.-1", message)
|
||||||
|
} else if messageContentTypeResult.String() == "tool_use" {
|
||||||
|
functionCallMessage := `{"type":"function_call"}`
|
||||||
|
functionCallMessage, _ = sjson.Set(functionCallMessage, "call_id", messageContentResult.Get("id").String())
|
||||||
|
functionCallMessage, _ = sjson.Set(functionCallMessage, "name", messageContentResult.Get("name").String())
|
||||||
|
functionCallMessage, _ = sjson.Set(functionCallMessage, "arguments", messageContentResult.Get("input").Raw)
|
||||||
|
template, _ = sjson.SetRaw(template, "input.-1", functionCallMessage)
|
||||||
|
} else if messageContentTypeResult.String() == "tool_result" {
|
||||||
|
functionCallOutputMessage := `{"type":"function_call_output"}`
|
||||||
|
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "call_id", messageContentResult.Get("tool_use_id").String())
|
||||||
|
functionCallOutputMessage, _ = sjson.Set(functionCallOutputMessage, "output", messageContentResult.Get("content").String())
|
||||||
|
template, _ = sjson.SetRaw(template, "input.-1", functionCallOutputMessage)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
toolsResult := rootResult.Get("tools")
|
||||||
|
if toolsResult.IsArray() {
|
||||||
|
template, _ = sjson.SetRaw(template, "tools", `[]`)
|
||||||
|
template, _ = sjson.Set(template, "tool_choice", `auto`)
|
||||||
|
toolResults := toolsResult.Array()
|
||||||
|
for i := 0; i < len(toolResults); i++ {
|
||||||
|
toolResult := toolResults[i]
|
||||||
|
tool := toolResult.Raw
|
||||||
|
tool, _ = sjson.Set(tool, "type", "function")
|
||||||
|
tool, _ = sjson.SetRaw(tool, "parameters", toolResult.Get("input_schema").Raw)
|
||||||
|
tool, _ = sjson.Delete(tool, "input_schema")
|
||||||
|
tool, _ = sjson.Delete(tool, "parameters.$schema")
|
||||||
|
tool, _ = sjson.Set(tool, "strict", false)
|
||||||
|
template, _ = sjson.SetRaw(template, "tools.-1", tool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template, _ = sjson.Set(template, "parallel_tool_calls", true)
|
||||||
|
template, _ = sjson.Set(template, "reasoning.effort", "low")
|
||||||
|
template, _ = sjson.Set(template, "reasoning.summary", "auto")
|
||||||
|
template, _ = sjson.Set(template, "stream", true)
|
||||||
|
template, _ = sjson.Set(template, "store", false)
|
||||||
|
template, _ = sjson.Set(template, "include", []string{"reasoning.encrypted_content"})
|
||||||
|
|
||||||
|
return template
|
||||||
|
}
|
||||||
129
internal/translator/codex/claude/code/codex_cc_response.go
Normal file
129
internal/translator/codex/claude/code/codex_cc_response.go
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
// Package code provides response translation functionality for Claude API.
|
||||||
|
// This package handles the conversion of backend client responses into Claude-compatible
|
||||||
|
// Server-Sent Events (SSE) format, implementing a sophisticated state machine that manages
|
||||||
|
// different response types including text content, thinking processes, and function calls.
|
||||||
|
// The translation ensures proper sequencing of SSE events and maintains state across
|
||||||
|
// multiple response chunks to provide a seamless streaming experience.
|
||||||
|
package code
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertCodexResponseToClaude translates one Codex (OpenAI Responses API)
// streaming event into Claude-compatible Server-Sent Events (SSE).
// It dispatches on the event "type" field and emits the corresponding Claude
// event(s): message_start, content_block_start/delta/stop for text, thinking
// and tool-use blocks, and message_delta + message_stop on completion.
//
// hasToolCall is threaded through successive calls: it is set to true when a
// function_call item appears and decides the final stop_reason ("tool_use"
// vs "end_turn"). The returned string is the SSE payload to forward (empty
// for unhandled event types) along with the updated hasToolCall flag.
//
// NOTE(review): most branches terminate the SSE data line with a single "\n",
// not the blank line SSE requires between events — presumably the caller adds
// the separator when forwarding; confirm.
func ConvertCodexResponseToClaude(rawJSON []byte, hasToolCall bool) (string, bool) {
	// log.Debugf("rawJSON: %s", string(rawJSON))
	output := ""
	rootResult := gjson.ParseBytes(rawJSON)
	typeResult := rootResult.Get("type")
	typeStr := typeResult.String()
	template := ""
	if typeStr == "response.created" {
		// Stream opened: emit message_start carrying the model and response id.
		template = `{"type":"message_start","message":{"id":"","type":"message","role":"assistant","model":"claude-opus-4-1-20250805","stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0},"content":[],"stop_reason":null}}`
		template, _ = sjson.Set(template, "message.model", rootResult.Get("response.model").String())
		template, _ = sjson.Set(template, "message.id", rootResult.Get("response.id").String())

		output = "event: message_start\n"
		output += fmt.Sprintf("data: %s\n", template)
	} else if typeStr == "response.reasoning_summary_part.added" {
		// Reasoning summary opened -> Claude "thinking" content block start.
		template = `{"type":"content_block_start","index":0,"content_block":{"type":"thinking","thinking":""}}`
		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())

		output = "event: content_block_start\n"
		output += fmt.Sprintf("data: %s\n", template)
	} else if typeStr == "response.reasoning_summary_text.delta" {
		// Incremental reasoning text -> thinking_delta.
		template = `{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":""}}`
		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
		template, _ = sjson.Set(template, "delta.thinking", rootResult.Get("delta").String())

		output = "event: content_block_delta\n"
		output += fmt.Sprintf("data: %s\n", template)
	} else if typeStr == "response.reasoning_summary_part.done" {
		// Reasoning summary finished -> close the thinking block.
		template = `{"type":"content_block_stop","index":0}`
		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())

		output = "event: content_block_stop\n"
		output += fmt.Sprintf("data: %s\n", template)
	} else if typeStr == "response.content_part.added" {
		// Text content part opened -> Claude text content block start.
		template = `{"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`
		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())

		output = "event: content_block_start\n"
		output += fmt.Sprintf("data: %s\n", template)
	} else if typeStr == "response.output_text.delta" {
		// Incremental output text -> text_delta.
		template = `{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":""}}`
		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
		template, _ = sjson.Set(template, "delta.text", rootResult.Get("delta").String())

		output = "event: content_block_delta\n"
		output += fmt.Sprintf("data: %s\n", template)
	} else if typeStr == "response.content_part.done" {
		// Text content part finished -> close the text block.
		template = `{"type":"content_block_stop","index":0}`
		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())

		output = "event: content_block_stop\n"
		output += fmt.Sprintf("data: %s\n", template)
	} else if typeStr == "response.completed" {
		// Stream finished: emit message_delta with stop reason and token usage,
		// then message_stop. stop_reason depends on whether a tool call occurred.
		template = `{"type":"message_delta","delta":{"stop_reason":"tool_use","stop_sequence":null},"usage":{"input_tokens":0,"output_tokens":0}}`
		if hasToolCall {
			template, _ = sjson.Set(template, "delta.stop_reason", "tool_use")
		} else {
			template, _ = sjson.Set(template, "delta.stop_reason", "end_turn")
		}
		template, _ = sjson.Set(template, "usage.input_tokens", rootResult.Get("response.usage.input_tokens").Int())
		template, _ = sjson.Set(template, "usage.output_tokens", rootResult.Get("response.usage.output_tokens").Int())

		output = "event: message_delta\n"
		output += fmt.Sprintf("data: %s\n\n", template)
		output += "event: message_stop\n"
		output += `data: {"type":"message_stop"}`
		output += "\n\n"
	} else if typeStr == "response.output_item.added" {
		// New output item. Only function_call items are translated: they open a
		// tool_use block and an (empty) input_json_delta so clients see the block
		// immediately; argument text follows in function_call_arguments.delta events.
		itemResult := rootResult.Get("item")
		itemType := itemResult.Get("type").String()
		if itemType == "function_call" {
			hasToolCall = true
			template = `{"type":"content_block_start","index":0,"content_block":{"type":"tool_use","id":"","name":"","input":{}}}`
			template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
			template, _ = sjson.Set(template, "content_block.id", itemResult.Get("call_id").String())
			template, _ = sjson.Set(template, "content_block.name", itemResult.Get("name").String())

			output = "event: content_block_start\n"
			output += fmt.Sprintf("data: %s\n\n", template)

			template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
			template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())

			output += "event: content_block_delta\n"
			output += fmt.Sprintf("data: %s\n", template)
		}
	} else if typeStr == "response.output_item.done" {
		// Output item finished: close the tool_use block (function_call only).
		itemResult := rootResult.Get("item")
		itemType := itemResult.Get("type").String()
		if itemType == "function_call" {
			template = `{"type":"content_block_stop","index":0}`
			template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())

			output = "event: content_block_stop\n"
			output += fmt.Sprintf("data: %s\n", template)
		}
	} else if typeStr == "response.function_call_arguments.delta" {
		// Incremental tool-call argument JSON -> input_json_delta.
		template = `{"type":"content_block_delta","index":0,"delta":{"type":"input_json_delta","partial_json":""}}`
		template, _ = sjson.Set(template, "index", rootResult.Get("output_index").Int())
		template, _ = sjson.Set(template, "delta.partial_json", rootResult.Get("delta").String())

		output += "event: content_block_delta\n"
		output += fmt.Sprintf("data: %s\n", template)
	}

	return output, hasToolCall
}
|
||||||
199
internal/translator/codex/gemini/codex_gemini_request.go
Normal file
199
internal/translator/codex/gemini/codex_gemini_request.go
Normal file
@@ -0,0 +1,199 @@
|
|||||||
|
// Package code provides request translation functionality for the Gemini API,
|
||||||
|
// It handles parsing and transforming Claude API requests into the internal client format,
|
||||||
|
// extracting model information, system instructions, message contents, and tool declarations.
|
||||||
|
// The package also performs JSON data cleaning and transformation to ensure compatibility
|
||||||
|
// between Claude API format and the internal client's expected format.
|
||||||
|
package code
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"math/big"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/misc"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertGeminiRequestToCodex parses a Gemini generateContent-style request and
// transforms it into a Codex (OpenAI Responses API) request body. It maps the
// model name, system instruction, conversation contents (text parts, function
// calls and function responses), and functionDeclarations tools, then appends
// the fixed flags Codex expects.
//
// Gemini function calls carry no call id, so random ids are generated for each
// functionCall and consumed FIFO by subsequent functionResponse parts to pair
// them — this assumes responses arrive in the same order as the calls.
func ConvertGeminiRequestToCodex(rawJSON []byte) string {
	// Base Codex request template.
	out := `{"model":"","instructions":"","input":[]}`

	// Inject standard Codex instructions.
	// NOTE(review): sjson.SetRaw requires valid JSON; this assumes
	// misc.CodexInstructions is a JSON-encoded string — confirm the embedded file.
	instructions := misc.CodexInstructions
	out, _ = sjson.SetRaw(out, "instructions", instructions)

	root := gjson.ParseBytes(rawJSON)

	// FIFO queue of generated call IDs: functionCalls push, functionResponses pop,
	// preserving pairing order when several calls are in flight.
	var pendingCallIDs []string

	// genCallID creates a random call id of the form call_<24 alphanumeric chars>.
	genCallID := func() string {
		const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
		var b strings.Builder
		// 24-character random suffix, drawn with crypto/rand.
		for i := 0; i < 24; i++ {
			n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
			b.WriteByte(letters[n.Int64()])
		}
		return "call_" + b.String()
	}

	// Model name, copied through verbatim.
	if v := root.Get("model"); v.Exists() {
		out, _ = sjson.Set(out, "model", v.Value())
	}

	// System instruction -> a user message with input_text parts.
	// NOTE(review): only the snake_case "system_instruction" key is read —
	// confirm whether camelCase "systemInstruction" payloads can reach here.
	sysParts := root.Get("system_instruction.parts")
	if sysParts.IsArray() {
		msg := `{"type":"message","role":"user","content":[]}`
		arr := sysParts.Array()
		for i := 0; i < len(arr); i++ {
			p := arr[i]
			if t := p.Get("text"); t.Exists() {
				part := `{}`
				part, _ = sjson.Set(part, "type", "input_text")
				part, _ = sjson.Set(part, "text", t.String())
				msg, _ = sjson.SetRaw(msg, "content.-1", part)
			}
		}
		// Only emit the message if at least one text part was collected.
		if len(gjson.Get(msg, "content").Array()) > 0 {
			out, _ = sjson.SetRaw(out, "input.-1", msg)
		}
	}

	// Contents -> messages, function calls and function results.
	contents := root.Get("contents")
	if contents.IsArray() {
		items := contents.Array()
		for i := 0; i < len(items); i++ {
			item := items[i]
			// Gemini's "model" role maps to Codex's "assistant".
			role := item.Get("role").String()
			if role == "model" {
				role = "assistant"
			}

			parts := item.Get("parts")
			if !parts.IsArray() {
				continue
			}
			parr := parts.Array()
			for j := 0; j < len(parr); j++ {
				p := parr[j]
				// Text part: one Codex message per part; direction decides the
				// part type (input_text for user, output_text for assistant).
				if t := p.Get("text"); t.Exists() {
					msg := `{"type":"message","role":"","content":[]}`
					msg, _ = sjson.Set(msg, "role", role)
					partType := "input_text"
					if role == "assistant" {
						partType = "output_text"
					}
					part := `{}`
					part, _ = sjson.Set(part, "type", partType)
					part, _ = sjson.Set(part, "text", t.String())
					msg, _ = sjson.SetRaw(msg, "content.-1", part)
					out, _ = sjson.SetRaw(out, "input.-1", msg)
					continue
				}

				// Function call from the model -> Codex function_call.
				if fc := p.Get("functionCall"); fc.Exists() {
					fn := `{"type":"function_call"}`
					if name := fc.Get("name"); name.Exists() {
						fn, _ = sjson.Set(fn, "name", name.String())
					}
					if args := fc.Get("args"); args.Exists() {
						// "arguments" is set as a JSON string of the raw args object.
						fn, _ = sjson.Set(fn, "arguments", args.Raw)
					}
					// Generate and enqueue a call_id so the matching
					// functionResponse can pop it in FIFO order.
					id := genCallID()
					fn, _ = sjson.Set(fn, "call_id", id)
					pendingCallIDs = append(pendingCallIDs, id)
					out, _ = sjson.SetRaw(out, "input.-1", fn)
					continue
				}

				// Function response from the user -> Codex function_call_output.
				if fr := p.Get("functionResponse"); fr.Exists() {
					fno := `{"type":"function_call_output"}`
					// Prefer a string "result" field; otherwise embed the raw
					// response object as a string.
					if res := fr.Get("response.result"); res.Exists() {
						fno, _ = sjson.Set(fno, "output", res.String())
					} else if resp := fr.Get("response"); resp.Exists() {
						fno, _ = sjson.Set(fno, "output", resp.Raw)
					}
					// Pair with the oldest queued call_id; if the queue is
					// empty (unmatched response), generate a fresh id.
					var id string
					if len(pendingCallIDs) > 0 {
						id = pendingCallIDs[0]
						pendingCallIDs = pendingCallIDs[1:]
					} else {
						id = genCallID()
					}
					fno, _ = sjson.Set(fno, "call_id", id)
					out, _ = sjson.SetRaw(out, "input.-1", fno)
					continue
				}
			}
		}
	}

	// Tools mapping: Gemini functionDeclarations -> Codex function tools.
	tools := root.Get("tools")
	if tools.IsArray() {
		out, _ = sjson.SetRaw(out, "tools", `[]`)
		out, _ = sjson.Set(out, "tool_choice", "auto")
		tarr := tools.Array()
		for i := 0; i < len(tarr); i++ {
			td := tarr[i]
			fns := td.Get("functionDeclarations")
			if !fns.IsArray() {
				continue
			}
			farr := fns.Array()
			for j := 0; j < len(farr); j++ {
				fn := farr[j]
				tool := `{}`
				tool, _ = sjson.Set(tool, "type", "function")
				if v := fn.Get("name"); v.Exists() {
					tool, _ = sjson.Set(tool, "name", v.String())
				}
				if v := fn.Get("description"); v.Exists() {
					tool, _ = sjson.Set(tool, "description", v.String())
				}
				if prm := fn.Get("parameters"); prm.Exists() {
					// Strip the optional $schema field and forbid extra properties.
					cleaned := prm.Raw
					cleaned, _ = sjson.Delete(cleaned, "$schema")
					cleaned, _ = sjson.Set(cleaned, "additionalProperties", false)
					tool, _ = sjson.SetRaw(tool, "parameters", cleaned)
				}
				tool, _ = sjson.Set(tool, "strict", false)
				out, _ = sjson.SetRaw(out, "tools.-1", tool)
			}
		}
	}

	// Fixed flags aligning with Codex expectations for streamed, stateless requests.
	out, _ = sjson.Set(out, "parallel_tool_calls", true)
	out, _ = sjson.Set(out, "reasoning.effort", "low")
	out, _ = sjson.Set(out, "reasoning.summary", "auto")
	out, _ = sjson.Set(out, "stream", true)
	out, _ = sjson.Set(out, "store", false)
	out, _ = sjson.Set(out, "include", []string{"reasoning.encrypted_content"})

	return out
}
|
||||||
251
internal/translator/codex/gemini/codex_gemini_response.go
Normal file
251
internal/translator/codex/gemini/codex_gemini_response.go
Normal file
@@ -0,0 +1,251 @@
|
|||||||
|
// Package code provides response translation functionality for Gemini API.
|
||||||
|
// This package handles the conversion of Codex backend responses into Gemini-compatible
|
||||||
|
// JSON format, transforming streaming events into single-line JSON responses that include
|
||||||
|
// thinking content, regular text content, and function calls in the format expected by
|
||||||
|
// Gemini API clients.
|
||||||
|
package code
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertCodexResponseToGeminiParams carries mutable per-stream translation
// state that is threaded through successive ConvertCodexResponseToGemini calls
// while a single Codex response stream is being converted.
type ConvertCodexResponseToGeminiParams struct {
	// Model is the model name reported back to clients as Gemini "modelVersion".
	Model string
	// CreatedAt is the creation timestamp captured from "response.created_at"
	// (interpreted as unix seconds when rendered via time.Unix).
	CreatedAt int64
	// ResponseID is the Codex response id, echoed as Gemini "responseId".
	ResponseID string
	// LastStorageOutput buffers a fully built function-call chunk so that
	// consecutive function calls can be accumulated into one message before
	// being emitted ahead of the next non-function event.
	LastStorageOutput string
}
|
||||||
|
|
||||||
|
// ConvertCodexResponseToGemini converts Codex streaming response format to Gemini single-line JSON format.
|
||||||
|
// This function processes various Codex event types and transforms them into Gemini-compatible JSON responses.
|
||||||
|
// It handles thinking content, regular text content, and function calls, outputting single-line JSON
|
||||||
|
// that matches the Gemini API response format.
|
||||||
|
// The lastEventType parameter tracks the previous event type to handle consecutive function calls properly.
|
||||||
|
// ConvertCodexResponseToGemini converts Codex streaming response format to Gemini single-line JSON format.
// This function processes various Codex event types and transforms them into Gemini-compatible JSON responses.
// It handles thinking content, regular text content, and function calls, outputting single-line JSON
// that matches the Gemini API response format.
// The param struct carries mutable state (model, timestamps, buffered function-call
// output) across successive calls for the same stream; unrecognized event types
// produce an empty slice (nothing to emit).
func ConvertCodexResponseToGemini(rawJSON []byte, param *ConvertCodexResponseToGeminiParams) []string {
	rootResult := gjson.ParseBytes(rawJSON)
	typeResult := rootResult.Get("type")
	typeStr := typeResult.String()

	// Base Gemini response template; the placeholder model/createTime/responseId
	// values are overwritten below before anything is emitted.
	template := `{"candidates":[{"content":{"role":"model","parts":[]}}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"gemini-2.5-pro","createTime":"2025-08-15T02:52:03.884209Z","responseId":"06CeaPH7NaCU48APvNXDyA4"}`
	// If a function-call chunk was buffered earlier and another output item just
	// completed, continue appending into the buffered chunk instead of a fresh one.
	if param.LastStorageOutput != "" && typeStr == "response.output_item.done" {
		template = param.LastStorageOutput
	} else {
		template, _ = sjson.Set(template, "modelVersion", param.Model)
		createdAtResult := rootResult.Get("response.created_at")
		if createdAtResult.Exists() {
			// Remember the creation time and render it in RFC3339Nano as Gemini expects.
			param.CreatedAt = createdAtResult.Int()
			template, _ = sjson.Set(template, "createTime", time.Unix(param.CreatedAt, 0).Format(time.RFC3339Nano))
		}
		template, _ = sjson.Set(template, "responseId", param.ResponseID)
	}

	// Handle function call completion: the chunk is buffered (not emitted) so
	// consecutive function calls can be grouped into a single message.
	if typeStr == "response.output_item.done" {
		itemResult := rootResult.Get("item")
		itemType := itemResult.Get("type").String()
		if itemType == "function_call" {
			// Create function call part
			functionCall := `{"functionCall":{"name":"","args":{}}}`
			functionCall, _ = sjson.Set(functionCall, "functionCall.name", itemResult.Get("name").String())

			// Parse and set arguments; only a well-formed JSON object is copied
			// through verbatim, otherwise args stays {}.
			argsStr := itemResult.Get("arguments").String()
			if argsStr != "" {
				argsResult := gjson.Parse(argsStr)
				if argsResult.IsObject() {
					functionCall, _ = sjson.SetRaw(functionCall, "functionCall.args", argsStr)
				}
			}

			template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", functionCall)
			template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")

			param.LastStorageOutput = template

			// Use this return to storage message — nothing is emitted for this event.
			return []string{}
		}
	}

	if typeStr == "response.created" { // Handle response creation - set model and response ID
		template, _ = sjson.Set(template, "modelVersion", rootResult.Get("response.model").String())
		template, _ = sjson.Set(template, "responseId", rootResult.Get("response.id").String())
		param.ResponseID = rootResult.Get("response.id").String()
	} else if typeStr == "response.reasoning_summary_text.delta" { // Handle reasoning/thinking content delta
		part := `{"thought":true,"text":""}`
		part, _ = sjson.Set(part, "text", rootResult.Get("delta").String())
		template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
	} else if typeStr == "response.output_text.delta" { // Handle regular text content delta
		part := `{"text":""}`
		part, _ = sjson.Set(part, "text", rootResult.Get("delta").String())
		template, _ = sjson.SetRaw(template, "candidates.0.content.parts.-1", part)
	} else if typeStr == "response.completed" { // Handle response completion with usage metadata
		template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", rootResult.Get("response.usage.input_tokens").Int())
		template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", rootResult.Get("response.usage.output_tokens").Int())
		totalTokens := rootResult.Get("response.usage.input_tokens").Int() + rootResult.Get("response.usage.output_tokens").Int()
		template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", totalTokens)
	} else {
		// Unknown/unhandled event type: emit nothing.
		return []string{}
	}

	// Flush any buffered function-call chunk ahead of the current one.
	// NOTE(review): LastStorageOutput is not reset after being flushed here, so
	// it looks like it could be re-emitted on every later event — confirm the
	// caller clears it between emissions.
	if param.LastStorageOutput != "" {
		return []string{param.LastStorageOutput, template}
	} else {
		return []string{template}
	}
}
|
||||||
|
|
||||||
|
// ConvertCodexResponseToGeminiNonStream converts a completed Codex response to Gemini non-streaming format.
|
||||||
|
// This function processes the final response.completed event and transforms it into a complete
|
||||||
|
// Gemini-compatible JSON response that includes all content parts, function calls, and usage metadata.
|
||||||
|
func ConvertCodexResponseToGeminiNonStream(rawJSON []byte, model string) string {
|
||||||
|
rootResult := gjson.ParseBytes(rawJSON)
|
||||||
|
|
||||||
|
// Verify this is a response.completed event
|
||||||
|
if rootResult.Get("type").String() != "response.completed" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// Base Gemini response template for non-streaming
|
||||||
|
template := `{"candidates":[{"content":{"role":"model","parts":[]},"finishReason":"STOP"}],"usageMetadata":{"trafficType":"PROVISIONED_THROUGHPUT"},"modelVersion":"","createTime":"","responseId":""}`
|
||||||
|
|
||||||
|
// Set model version
|
||||||
|
template, _ = sjson.Set(template, "modelVersion", model)
|
||||||
|
|
||||||
|
// Set response metadata from the completed response
|
||||||
|
responseData := rootResult.Get("response")
|
||||||
|
if responseData.Exists() {
|
||||||
|
// Set response ID
|
||||||
|
if responseId := responseData.Get("id"); responseId.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "responseId", responseId.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set creation time
|
||||||
|
if createdAt := responseData.Get("created_at"); createdAt.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "createTime", time.Unix(createdAt.Int(), 0).Format(time.RFC3339Nano))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set usage metadata
|
||||||
|
if usage := responseData.Get("usage"); usage.Exists() {
|
||||||
|
inputTokens := usage.Get("input_tokens").Int()
|
||||||
|
outputTokens := usage.Get("output_tokens").Int()
|
||||||
|
totalTokens := inputTokens + outputTokens
|
||||||
|
|
||||||
|
template, _ = sjson.Set(template, "usageMetadata.promptTokenCount", inputTokens)
|
||||||
|
template, _ = sjson.Set(template, "usageMetadata.candidatesTokenCount", outputTokens)
|
||||||
|
template, _ = sjson.Set(template, "usageMetadata.totalTokenCount", totalTokens)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process output content to build parts array
|
||||||
|
var parts []interface{}
|
||||||
|
hasToolCall := false
|
||||||
|
var pendingFunctionCalls []interface{}
|
||||||
|
|
||||||
|
flushPendingFunctionCalls := func() {
|
||||||
|
if len(pendingFunctionCalls) > 0 {
|
||||||
|
// Add all pending function calls as individual parts
|
||||||
|
// This maintains the original Gemini API format while ensuring consecutive calls are grouped together
|
||||||
|
for _, fc := range pendingFunctionCalls {
|
||||||
|
parts = append(parts, fc)
|
||||||
|
}
|
||||||
|
pendingFunctionCalls = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if output := responseData.Get("output"); output.Exists() && output.IsArray() {
|
||||||
|
output.ForEach(func(key, value gjson.Result) bool {
|
||||||
|
itemType := value.Get("type").String()
|
||||||
|
|
||||||
|
switch itemType {
|
||||||
|
case "reasoning":
|
||||||
|
// Flush any pending function calls before adding non-function content
|
||||||
|
flushPendingFunctionCalls()
|
||||||
|
|
||||||
|
// Add thinking content
|
||||||
|
if content := value.Get("content"); content.Exists() {
|
||||||
|
part := map[string]interface{}{
|
||||||
|
"thought": true,
|
||||||
|
"text": content.String(),
|
||||||
|
}
|
||||||
|
parts = append(parts, part)
|
||||||
|
}
|
||||||
|
|
||||||
|
case "message":
|
||||||
|
// Flush any pending function calls before adding non-function content
|
||||||
|
flushPendingFunctionCalls()
|
||||||
|
|
||||||
|
// Add regular text content
|
||||||
|
if content := value.Get("content"); content.Exists() && content.IsArray() {
|
||||||
|
content.ForEach(func(_, contentItem gjson.Result) bool {
|
||||||
|
if contentItem.Get("type").String() == "output_text" {
|
||||||
|
if text := contentItem.Get("text"); text.Exists() {
|
||||||
|
part := map[string]interface{}{
|
||||||
|
"text": text.String(),
|
||||||
|
}
|
||||||
|
parts = append(parts, part)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
case "function_call":
|
||||||
|
// Collect function call for potential merging with consecutive ones
|
||||||
|
hasToolCall = true
|
||||||
|
functionCall := map[string]interface{}{
|
||||||
|
"functionCall": map[string]interface{}{
|
||||||
|
"name": value.Get("name").String(),
|
||||||
|
"args": map[string]interface{}{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Parse and set arguments
|
||||||
|
if argsStr := value.Get("arguments").String(); argsStr != "" {
|
||||||
|
argsResult := gjson.Parse(argsStr)
|
||||||
|
if argsResult.IsObject() {
|
||||||
|
var args map[string]interface{}
|
||||||
|
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
|
||||||
|
functionCall["functionCall"].(map[string]interface{})["args"] = args
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pendingFunctionCalls = append(pendingFunctionCalls, functionCall)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handle any remaining pending function calls at the end
|
||||||
|
flushPendingFunctionCalls()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the parts array
|
||||||
|
if len(parts) > 0 {
|
||||||
|
template, _ = sjson.SetRaw(template, "candidates.0.content.parts", mustMarshalJSON(parts))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set finish reason based on whether there were tool calls
|
||||||
|
if hasToolCall {
|
||||||
|
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||||
|
} else {
|
||||||
|
template, _ = sjson.Set(template, "candidates.0.finishReason", "STOP")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return template
|
||||||
|
}
|
||||||
|
|
||||||
|
// mustMarshalJSON serializes v to a compact JSON string. It panics if
// marshalling fails, which should not happen for the plain map/slice values
// this package feeds it.
func mustMarshalJSON(v interface{}) string {
	encoded, errMarshal := json.Marshal(v)
	if errMarshal != nil {
		panic(errMarshal)
	}
	return string(encoded)
}
|
||||||
227
internal/translator/codex/openai/codex_openai_request.go
Normal file
227
internal/translator/codex/openai/codex_openai_request.go
Normal file
@@ -0,0 +1,227 @@
|
|||||||
|
// Package codex provides utilities to translate OpenAI Chat Completions
|
||||||
|
// request JSON into OpenAI Responses API request JSON using gjson/sjson.
|
||||||
|
// It supports tools, multimodal text/image inputs, and Structured Outputs.
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/misc"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertOpenAIChatRequestToCodex converts an OpenAI Chat Completions request JSON
// into an OpenAI Responses API request JSON. The transformation follows the
// examples defined in docs/2.md exactly, including tools, multi-turn dialog,
// multimodal text/image handling, and Structured Outputs mapping.
// Sampling and token-limit fields are intentionally dropped (Codex does not
// accept them); the "instructions" field is always set to the fixed
// misc.CodexInstructions text rather than the request's system message.
func ConvertOpenAIChatRequestToCodex(rawJSON []byte) string {
	// Start with empty JSON object
	out := `{}`
	store := false

	// Stream must be set to true (only forwarded when the client sent the field)
	if v := gjson.GetBytes(rawJSON, "stream"); v.Exists() {
		out, _ = sjson.Set(out, "stream", true)
	}

	// Codex not support temperature, top_p, top_k, max_output_tokens, so comment them
	// if v := gjson.GetBytes(rawJSON, "temperature"); v.Exists() {
	// 	out, _ = sjson.Set(out, "temperature", v.Value())
	// }
	// if v := gjson.GetBytes(rawJSON, "top_p"); v.Exists() {
	// 	out, _ = sjson.Set(out, "top_p", v.Value())
	// }
	// if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() {
	// 	out, _ = sjson.Set(out, "top_k", v.Value())
	// }

	// Map token limits
	// if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() {
	// 	out, _ = sjson.Set(out, "max_output_tokens", v.Value())
	// }
	// if v := gjson.GetBytes(rawJSON, "max_completion_tokens"); v.Exists() {
	// 	out, _ = sjson.Set(out, "max_output_tokens", v.Value())
	// }

	// Map reasoning effort; summaries are always requested alongside it
	if v := gjson.GetBytes(rawJSON, "reasoning_effort"); v.Exists() {
		out, _ = sjson.Set(out, "reasoning.effort", v.Value())
		out, _ = sjson.Set(out, "reasoning.summary", "auto")
	}

	// Model
	if v := gjson.GetBytes(rawJSON, "model"); v.Exists() {
		out, _ = sjson.Set(out, "model", v.Value())
	}

	// Instructions are always the fixed Codex prompt; the system-message
	// extraction below was deliberately disabled.
	messages := gjson.GetBytes(rawJSON, "messages")
	instructions := misc.CodexInstructions
	out, _ = sjson.SetRaw(out, "instructions", instructions)
	// if messages.IsArray() {
	// 	arr := messages.Array()
	// 	for i := 0; i < len(arr); i++ {
	// 		m := arr[i]
	// 		if m.Get("role").String() == "system" {
	// 			c := m.Get("content")
	// 			if c.Type == gjson.String {
	// 				out, _ = sjson.Set(out, "instructions", c.String())
	// 			} else if c.IsObject() && c.Get("type").String() == "text" {
	// 				out, _ = sjson.Set(out, "instructions", c.Get("text").String())
	// 			}
	// 			break
	// 		}
	// 	}
	// }

	// Build input from messages, skipping tool/function roles
	out, _ = sjson.SetRaw(out, "input", `[]`)
	if messages.IsArray() {
		arr := messages.Array()
		for i := 0; i < len(arr); i++ {
			m := arr[i]
			role := m.Get("role").String()
			if role == "tool" || role == "function" {
				continue
			}

			// Prepare message object; system messages are demoted to user role
			msg := `{}`
			if role == "system" {
				msg, _ = sjson.Set(msg, "role", "user")
			} else {
				msg, _ = sjson.Set(msg, "role", role)
			}

			msg, _ = sjson.SetRaw(msg, "content", `[]`)

			c := m.Get("content")
			if c.Type == gjson.String {
				// Single string content: assistant text becomes output_text,
				// everything else input_text
				partType := "input_text"
				if role == "assistant" {
					partType = "output_text"
				}
				part := `{}`
				part, _ = sjson.Set(part, "type", partType)
				part, _ = sjson.Set(part, "text", c.String())
				msg, _ = sjson.SetRaw(msg, "content.-1", part)
			} else if c.IsArray() {
				// Multimodal content array: text and (user-only) image parts
				items := c.Array()
				for j := 0; j < len(items); j++ {
					it := items[j]
					t := it.Get("type").String()
					switch t {
					case "text":
						partType := "input_text"
						if role == "assistant" {
							partType = "output_text"
						}
						part := `{}`
						part, _ = sjson.Set(part, "type", partType)
						part, _ = sjson.Set(part, "text", it.Get("text").String())
						msg, _ = sjson.SetRaw(msg, "content.-1", part)
					case "image_url":
						// Map image inputs to input_image for Responses API
						if role == "user" {
							part := `{}`
							part, _ = sjson.Set(part, "type", "input_image")
							if u := it.Get("image_url.url"); u.Exists() {
								part, _ = sjson.Set(part, "image_url", u.String())
							}
							msg, _ = sjson.SetRaw(msg, "content.-1", part)
						}
					case "file":
						// Files are not specified in examples; skip for now
					}
				}
			}

			out, _ = sjson.SetRaw(out, "input.-1", msg)
		}
	}

	// Map response_format and text settings to Responses API text.format
	rf := gjson.GetBytes(rawJSON, "response_format")
	text := gjson.GetBytes(rawJSON, "text")
	if rf.Exists() {
		// Always create text object when response_format provided
		if !gjson.Get(out, "text").Exists() {
			out, _ = sjson.SetRaw(out, "text", `{}`)
		}

		rft := rf.Get("type").String()
		switch rft {
		case "text":
			out, _ = sjson.Set(out, "text.format.type", "text")
		case "json_schema":
			// Structured Outputs: copy name/strict/schema through verbatim
			js := rf.Get("json_schema")
			if js.Exists() {
				out, _ = sjson.Set(out, "text.format.type", "json_schema")
				if v := js.Get("name"); v.Exists() {
					out, _ = sjson.Set(out, "text.format.name", v.Value())
				}
				if v := js.Get("strict"); v.Exists() {
					out, _ = sjson.Set(out, "text.format.strict", v.Value())
				}
				if v := js.Get("schema"); v.Exists() {
					out, _ = sjson.SetRaw(out, "text.format.schema", v.Raw)
				}
			}
		}

		// Map verbosity if provided
		if text.Exists() {
			if v := text.Get("verbosity"); v.Exists() {
				out, _ = sjson.Set(out, "text.verbosity", v.Value())
			}
		}

		// The examples include store: true when response_format is provided
		store = true
	} else if text.Exists() {
		// If only text.verbosity present (no response_format), map verbosity
		if v := text.Get("verbosity"); v.Exists() {
			if !gjson.Get(out, "text").Exists() {
				out, _ = sjson.SetRaw(out, "text", `{}`)
			}
			out, _ = sjson.Set(out, "text.verbosity", v.Value())
		}
	}

	// Map tools (flatten function fields from Chat Completions nesting)
	tools := gjson.GetBytes(rawJSON, "tools")
	if tools.IsArray() {
		out, _ = sjson.SetRaw(out, "tools", `[]`)
		arr := tools.Array()
		for i := 0; i < len(arr); i++ {
			t := arr[i]
			if t.Get("type").String() == "function" {
				item := `{}`
				item, _ = sjson.Set(item, "type", "function")
				fn := t.Get("function")
				if fn.Exists() {
					if v := fn.Get("name"); v.Exists() {
						item, _ = sjson.Set(item, "name", v.Value())
					}
					if v := fn.Get("description"); v.Exists() {
						item, _ = sjson.Set(item, "description", v.Value())
					}
					if v := fn.Get("parameters"); v.Exists() {
						item, _ = sjson.SetRaw(item, "parameters", v.Raw)
					}
					if v := fn.Get("strict"); v.Exists() {
						item, _ = sjson.Set(item, "strict", v.Value())
					}
				}
				out, _ = sjson.SetRaw(out, "tools.-1", item)
			}
		}
		// The examples include store: true when tools and formatting are used; be conservative
		// NOTE(review): store is already true whenever rf exists (set above), so
		// this re-set appears redundant — confirm before simplifying.
		if rf.Exists() {
			store = true
		}
	}

	out, _ = sjson.Set(out, "store", store)
	return out
}
|
||||||
231
internal/translator/codex/openai/codex_openai_response.go
Normal file
231
internal/translator/codex/openai/codex_openai_response.go
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
// Package codex provides response translation functionality for converting between
|
||||||
|
// Codex API response formats and OpenAI-compatible formats. It handles both
|
||||||
|
// streaming and non-streaming responses, transforming backend client responses
|
||||||
|
// into OpenAI Server-Sent Events (SSE) format and standard JSON response formats.
|
||||||
|
// The package supports content translation, function calls, reasoning content,
|
||||||
|
// usage metadata, and various response attributes while maintaining compatibility
|
||||||
|
// with OpenAI API specifications.
|
||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertCliToOpenAIParams carries per-stream state captured from the initial
// "response.created" event and reused when translating every subsequent chunk.
type ConvertCliToOpenAIParams struct {
	// ResponseID is the Codex response id ("response.id"), echoed as the chunk "id".
	ResponseID string
	// CreatedAt is the value of "response.created_at", echoed as the chunk "created".
	CreatedAt int64
	// Model is the model name from "response.model".
	Model string
}
|
||||||
|
|
||||||
|
// ConvertCodexResponseToOpenAIChat translates a single chunk of a streaming response from the
// Codex backend client format to the OpenAI Server-Sent Events (SSE) format.
// It returns an empty string if the chunk contains no useful data.
// The params value is created on the "response.created" event and must be
// passed back in on every later call; a nil params short-circuits to no output.
func ConvertCodexResponseToOpenAIChat(rawJSON []byte, params *ConvertCliToOpenAIParams) (*ConvertCliToOpenAIParams, string) {
	// Initialize the OpenAI SSE template.
	template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`

	rootResult := gjson.ParseBytes(rawJSON)

	typeResult := rootResult.Get("type")
	dataType := typeResult.String()
	// The first event captures stream-wide state and emits nothing.
	if dataType == "response.created" {
		return &ConvertCliToOpenAIParams{
			ResponseID: rootResult.Get("response.id").String(),
			CreatedAt:  rootResult.Get("response.created_at").Int(),
			Model:      rootResult.Get("response.model").String(),
		}, ""
	}

	// Without captured state there is nothing meaningful to emit.
	if params == nil {
		return params, ""
	}

	// Extract and set the model version.
	// NOTE(review): this reads the top-level "model" key; Codex events appear to
	// carry the model under "response.model" instead — confirm whether this
	// branch ever fires (params.Model is captured but unused here).
	if modelResult := gjson.GetBytes(rawJSON, "model"); modelResult.Exists() {
		template, _ = sjson.Set(template, "model", modelResult.String())
	}

	template, _ = sjson.Set(template, "created", params.CreatedAt)

	// Extract and set the response ID.
	template, _ = sjson.Set(template, "id", params.ResponseID)

	// Extract and set usage metadata (token counts), present on terminal events.
	if usageResult := gjson.GetBytes(rawJSON, "response.usage"); usageResult.Exists() {
		if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() {
			template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int())
		}
		if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() {
			template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int())
		}
		if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
			template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
		}
		if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
			template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
		}
	}

	if dataType == "response.reasoning_summary_text.delta" {
		// Reasoning/thinking delta -> reasoning_content.
		if deltaResult := rootResult.Get("delta"); deltaResult.Exists() {
			template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
			template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", deltaResult.String())
		}
	} else if dataType == "response.reasoning_summary_text.done" {
		// Reasoning finished: emit a blank-line separator in reasoning_content.
		template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
		template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", "\n\n")
	} else if dataType == "response.output_text.delta" {
		// Regular text delta -> content.
		if deltaResult := rootResult.Get("delta"); deltaResult.Exists() {
			template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
			template, _ = sjson.Set(template, "choices.0.delta.content", deltaResult.String())
		}
	} else if dataType == "response.completed" {
		// Terminal event: mark the stream as finished.
		template, _ = sjson.Set(template, "choices.0.finish_reason", "stop")
		template, _ = sjson.Set(template, "choices.0.native_finish_reason", "stop")
	} else if dataType == "response.output_item.done" {
		// Completed output item: only function calls are forwarded, as a
		// single-element tool_calls array on the delta.
		functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
		itemResult := rootResult.Get("item")
		if itemResult.Exists() {
			if itemResult.Get("type").String() != "function_call" {
				return params, ""
			}
			template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
			functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", itemResult.Get("call_id").String())
			functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", itemResult.Get("name").String())
			functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", itemResult.Get("arguments").String())
			template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
			template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls.-1", functionCallItemTemplate)
		}

	} else {
		// Unknown/unhandled event type: emit nothing.
		return params, ""
	}

	return params, template
}
|
||||||
|
|
||||||
|
// ConvertCodexResponseToOpenAIChatNonStream aggregates response from the Codex backend client
|
||||||
|
// convert a single, non-streaming OpenAI-compatible JSON response.
|
||||||
|
func ConvertCodexResponseToOpenAIChatNonStream(rawJSON string, unixTimestamp int64) string {
|
||||||
|
template := `{"id":"","object":"chat.completion","created":123456,"model":"model","choices":[{"index":0,"message":{"role":"assistant","content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||||
|
|
||||||
|
// Extract and set the model version.
|
||||||
|
if modelResult := gjson.Get(rawJSON, "model"); modelResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "model", modelResult.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and set the creation timestamp.
|
||||||
|
if createdAtResult := gjson.Get(rawJSON, "created_at"); createdAtResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "created", createdAtResult.Int())
|
||||||
|
} else {
|
||||||
|
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and set the response ID.
|
||||||
|
if idResult := gjson.Get(rawJSON, "id"); idResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "id", idResult.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and set usage metadata (token counts).
|
||||||
|
if usageResult := gjson.Get(rawJSON, "usage"); usageResult.Exists() {
|
||||||
|
if outputTokensResult := usageResult.Get("output_tokens"); outputTokensResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokensResult.Int())
|
||||||
|
}
|
||||||
|
if totalTokensResult := usageResult.Get("total_tokens"); totalTokensResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "usage.total_tokens", totalTokensResult.Int())
|
||||||
|
}
|
||||||
|
if inputTokensResult := usageResult.Get("input_tokens"); inputTokensResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokensResult.Int())
|
||||||
|
}
|
||||||
|
if reasoningTokensResult := usageResult.Get("output_tokens_details.reasoning_tokens"); reasoningTokensResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", reasoningTokensResult.Int())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the output array for content and function calls
|
||||||
|
outputResult := gjson.Get(rawJSON, "output")
|
||||||
|
if outputResult.IsArray() {
|
||||||
|
outputArray := outputResult.Array()
|
||||||
|
var contentText string
|
||||||
|
var reasoningText string
|
||||||
|
var toolCalls []string
|
||||||
|
|
||||||
|
for _, outputItem := range outputArray {
|
||||||
|
outputType := outputItem.Get("type").String()
|
||||||
|
|
||||||
|
switch outputType {
|
||||||
|
case "reasoning":
|
||||||
|
// Extract reasoning content from summary
|
||||||
|
if summaryResult := outputItem.Get("summary"); summaryResult.IsArray() {
|
||||||
|
summaryArray := summaryResult.Array()
|
||||||
|
for _, summaryItem := range summaryArray {
|
||||||
|
if summaryItem.Get("type").String() == "summary_text" {
|
||||||
|
reasoningText = summaryItem.Get("text").String()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "message":
|
||||||
|
// Extract message content
|
||||||
|
if contentResult := outputItem.Get("content"); contentResult.IsArray() {
|
||||||
|
contentArray := contentResult.Array()
|
||||||
|
for _, contentItem := range contentArray {
|
||||||
|
if contentItem.Get("type").String() == "output_text" {
|
||||||
|
contentText = contentItem.Get("text").String()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "function_call":
|
||||||
|
// Handle function call content
|
||||||
|
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||||
|
|
||||||
|
if callIdResult := outputItem.Get("call_id"); callIdResult.Exists() {
|
||||||
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", callIdResult.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if nameResult := outputItem.Get("name"); nameResult.Exists() {
|
||||||
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", nameResult.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if argsResult := outputItem.Get("arguments"); argsResult.Exists() {
|
||||||
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", argsResult.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
toolCalls = append(toolCalls, functionCallTemplate)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set content and reasoning content if found
|
||||||
|
if contentText != "" {
|
||||||
|
template, _ = sjson.Set(template, "choices.0.message.content", contentText)
|
||||||
|
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||||
|
}
|
||||||
|
|
||||||
|
if reasoningText != "" {
|
||||||
|
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningText)
|
||||||
|
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add tool calls if any
|
||||||
|
if len(toolCalls) > 0 {
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
|
||||||
|
for _, toolCall := range toolCalls {
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", toolCall)
|
||||||
|
}
|
||||||
|
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and set the finish reason based on status
|
||||||
|
if statusResult := gjson.Get(rawJSON, "status"); statusResult.Exists() {
|
||||||
|
status := statusResult.String()
|
||||||
|
if status == "completed" {
|
||||||
|
template, _ = sjson.Set(template, "choices.0.finish_reason", "stop")
|
||||||
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "stop")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return template
|
||||||
|
}
|
||||||
@@ -8,16 +8,17 @@ package code
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
"strings"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// PrepareClaudeRequest parses and transforms a Claude API request into internal client format.
|
// ConvertClaudeCodeRequestToCli parses and transforms a Claude API request into internal client format.
|
||||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||||
// from the raw JSON request and returns them in the format expected by the internal client.
|
// from the raw JSON request and returns them in the format expected by the internal client.
|
||||||
func PrepareClaudeRequest(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
|
func ConvertClaudeCodeRequestToCli(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
|
||||||
var pathsToDelete []string
|
var pathsToDelete []string
|
||||||
root := gjson.ParseBytes(rawJSON)
|
root := gjson.ParseBytes(rawJSON)
|
||||||
walk(root, "", "additionalProperties", &pathsToDelete)
|
walk(root, "", "additionalProperties", &pathsToDelete)
|
||||||
@@ -9,19 +9,20 @@ package code
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
"time"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConvertCliToClaude performs sophisticated streaming response format conversion.
|
// ConvertCliResponseToClaudeCode performs sophisticated streaming response format conversion.
|
||||||
// This function implements a complex state machine that translates backend client responses
|
// This function implements a complex state machine that translates backend client responses
|
||||||
// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types
|
// into Claude-compatible Server-Sent Events (SSE) format. It manages different response types
|
||||||
// and handles state transitions between content blocks, thinking processes, and function calls.
|
// and handles state transitions between content blocks, thinking processes, and function calls.
|
||||||
//
|
//
|
||||||
// Response type states: 0=none, 1=content, 2=thinking, 3=function
|
// Response type states: 0=none, 1=content, 2=thinking, 3=function
|
||||||
// The function maintains state across multiple calls to ensure proper SSE event sequencing.
|
// The function maintains state across multiple calls to ensure proper SSE event sequencing.
|
||||||
func ConvertCliToClaude(rawJSON []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string {
|
func ConvertCliResponseToClaudeCode(rawJSON []byte, isGlAPIKey, hasFirstResponse bool, responseType, responseIndex *int) string {
|
||||||
// Normalize the response format for different API key types
|
// Normalize the response format for different API key types
|
||||||
// Generative Language API keys have a different response structure
|
// Generative Language API keys have a different response structure
|
||||||
if isGlAPIKey {
|
if isGlAPIKey {
|
||||||
@@ -9,6 +9,7 @@ package cli
|
|||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
@@ -6,18 +6,31 @@ package openai
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/misc"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PrepareRequest translates a raw JSON request from an OpenAI-compatible format
|
// ConvertOpenAIChatRequestToCli translates a raw JSON request from an OpenAI-compatible format
|
||||||
// to the internal format expected by the backend client. It parses messages,
|
// to the internal format expected by the backend client. It parses messages,
|
||||||
// roles, content types (text, image, file), and tool calls.
|
// roles, content types (text, image, file), and tool calls.
|
||||||
func PrepareRequest(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
|
//
|
||||||
|
// This function handles the complex task of converting between the OpenAI message
|
||||||
|
// format and the internal format used by the Gemini client. It processes different
|
||||||
|
// message types (system, user, assistant, tool) and content types (text, images, files).
|
||||||
|
//
|
||||||
|
// Parameters:
|
||||||
|
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
|
||||||
|
//
|
||||||
|
// Returns:
|
||||||
|
// - string: The model name to use
|
||||||
|
// - *client.Content: System instruction content (if any)
|
||||||
|
// - []client.Content: The conversation contents in internal format
|
||||||
|
// - []client.ToolDeclaration: Tool declarations from the request
|
||||||
|
func ConvertOpenAIChatRequestToCli(rawJSON []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
|
||||||
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
|
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
|
||||||
modelName := "gemini-2.5-pro"
|
modelName := "gemini-2.5-pro"
|
||||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||||
@@ -126,7 +139,7 @@ func PrepareRequest(rawJSON []byte) (string, *client.Content, []client.Content,
|
|||||||
if split := strings.Split(filename, "."); len(split) > 1 {
|
if split := strings.Split(filename, "."); len(split) > 1 {
|
||||||
ext = split[len(split)-1]
|
ext = split[len(split)-1]
|
||||||
}
|
}
|
||||||
if mimeType, ok := translator.MimeTypes[ext]; ok {
|
if mimeType, ok := misc.MimeTypes[ext]; ok {
|
||||||
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
||||||
MimeType: mimeType,
|
MimeType: mimeType,
|
||||||
Data: fileData,
|
Data: fileData,
|
||||||
@@ -15,10 +15,10 @@ import (
|
|||||||
"github.com/tidwall/sjson"
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConvertCliToOpenAI translates a single chunk of a streaming response from the
|
// ConvertCliResponseToOpenAIChat translates a single chunk of a streaming response from the
|
||||||
// backend client format to the OpenAI Server-Sent Events (SSE) format.
|
// backend client format to the OpenAI Server-Sent Events (SSE) format.
|
||||||
// It returns an empty string if the chunk contains no useful data.
|
// It returns an empty string if the chunk contains no useful data.
|
||||||
func ConvertCliToOpenAI(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string {
|
func ConvertCliResponseToOpenAIChat(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string {
|
||||||
if isGlAPIKey {
|
if isGlAPIKey {
|
||||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
|
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
|
||||||
}
|
}
|
||||||
@@ -109,9 +109,9 @@ func ConvertCliToOpenAI(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) st
|
|||||||
return template
|
return template
|
||||||
}
|
}
|
||||||
|
|
||||||
// ConvertCliToOpenAINonStream aggregates response from the backend client
|
// ConvertCliResponseToOpenAIChatNonStream aggregates response from the backend client
|
||||||
// convert a single, non-streaming OpenAI-compatible JSON response.
|
// convert a single, non-streaming OpenAI-compatible JSON response.
|
||||||
func ConvertCliToOpenAINonStream(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string {
|
func ConvertCliResponseToOpenAIChatNonStream(rawJSON []byte, unixTimestamp int64, isGlAPIKey bool) string {
|
||||||
if isGlAPIKey {
|
if isGlAPIKey {
|
||||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
|
rawJSON, _ = sjson.SetRawBytes(rawJSON, "response", rawJSON)
|
||||||
}
|
}
|
||||||
24
internal/util/provider.go
Normal file
24
internal/util/provider.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
// Package util provides utility functions used across the CLIProxyAPI application.
|
||||||
|
// These functions handle common tasks such as determining AI service providers
|
||||||
|
// from model names and managing HTTP proxies.
|
||||||
|
package util
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetProviderName determines the AI service provider based on the model name.
|
||||||
|
// It analyzes the model name string to identify which service provider it belongs to.
|
||||||
|
//
|
||||||
|
// Supported providers:
|
||||||
|
// - "gemini" for Google's Gemini models
|
||||||
|
// - "gpt" for OpenAI's GPT models
|
||||||
|
// - "unknow" for unrecognized model names
|
||||||
|
func GetProviderName(modelName string) string {
|
||||||
|
if strings.Contains(modelName, "gemini") {
|
||||||
|
return "gemini"
|
||||||
|
} else if strings.Contains(modelName, "gpt") {
|
||||||
|
return "gpt"
|
||||||
|
}
|
||||||
|
return "unknow"
|
||||||
|
}
|
||||||
@@ -5,17 +5,19 @@ package util
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
// SetProxy configures the provided HTTP client with proxy settings from the configuration.
|
// SetProxy configures the provided HTTP client with proxy settings from the configuration.
|
||||||
// It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport
|
// It supports SOCKS5, HTTP, and HTTPS proxies. The function modifies the client's transport
|
||||||
// to route requests through the configured proxy server.
|
// to route requests through the configured proxy server.
|
||||||
func SetProxy(cfg *config.Config, httpClient *http.Client) (*http.Client, error) {
|
func SetProxy(cfg *config.Config, httpClient *http.Client) *http.Client {
|
||||||
var transport *http.Transport
|
var transport *http.Transport
|
||||||
proxyURL, errParse := url.Parse(cfg.ProxyURL)
|
proxyURL, errParse := url.Parse(cfg.ProxyURL)
|
||||||
if errParse == nil {
|
if errParse == nil {
|
||||||
@@ -25,7 +27,8 @@ func SetProxy(cfg *config.Config, httpClient *http.Client) (*http.Client, error)
|
|||||||
proxyAuth := &proxy.Auth{User: username, Password: password}
|
proxyAuth := &proxy.Auth{User: username, Password: password}
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
|
||||||
if errSOCKS5 != nil {
|
if errSOCKS5 != nil {
|
||||||
return nil, errSOCKS5
|
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||||
|
return httpClient
|
||||||
}
|
}
|
||||||
transport = &http.Transport{
|
transport = &http.Transport{
|
||||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
@@ -39,5 +42,5 @@ func SetProxy(cfg *config.Config, httpClient *http.Client) (*http.Client, error)
|
|||||||
if transport != nil {
|
if transport != nil {
|
||||||
httpClient.Transport = transport
|
httpClient.Transport = transport
|
||||||
}
|
}
|
||||||
return httpClient, nil
|
return httpClient
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,12 +7,6 @@ package watcher
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"github.com/fsnotify/fsnotify"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"io/fs"
|
"io/fs"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -20,6 +14,15 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/auth/codex"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/auth/gemini"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Watcher manages file watching for configuration and authentication files
|
// Watcher manages file watching for configuration and authentication files
|
||||||
@@ -27,14 +30,14 @@ type Watcher struct {
|
|||||||
configPath string
|
configPath string
|
||||||
authDir string
|
authDir string
|
||||||
config *config.Config
|
config *config.Config
|
||||||
clients []*client.Client
|
clients []client.Client
|
||||||
clientsMutex sync.RWMutex
|
clientsMutex sync.RWMutex
|
||||||
reloadCallback func([]*client.Client, *config.Config)
|
reloadCallback func([]client.Client, *config.Config)
|
||||||
watcher *fsnotify.Watcher
|
watcher *fsnotify.Watcher
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewWatcher creates a new file watcher instance
|
// NewWatcher creates a new file watcher instance
|
||||||
func NewWatcher(configPath, authDir string, reloadCallback func([]*client.Client, *config.Config)) (*Watcher, error) {
|
func NewWatcher(configPath, authDir string, reloadCallback func([]client.Client, *config.Config)) (*Watcher, error) {
|
||||||
watcher, errNewWatcher := fsnotify.NewWatcher()
|
watcher, errNewWatcher := fsnotify.NewWatcher()
|
||||||
if errNewWatcher != nil {
|
if errNewWatcher != nil {
|
||||||
return nil, errNewWatcher
|
return nil, errNewWatcher
|
||||||
@@ -83,7 +86,7 @@ func (w *Watcher) SetConfig(cfg *config.Config) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// SetClients updates the current client list
|
// SetClients updates the current client list
|
||||||
func (w *Watcher) SetClients(clients []*client.Client) {
|
func (w *Watcher) SetClients(clients []client.Client) {
|
||||||
w.clientsMutex.Lock()
|
w.clientsMutex.Lock()
|
||||||
defer w.clientsMutex.Unlock()
|
defer w.clientsMutex.Unlock()
|
||||||
w.clients = clients
|
w.clients = clients
|
||||||
@@ -193,7 +196,7 @@ func (w *Watcher) reloadClients() {
|
|||||||
log.Debugf("scanning auth directory: %s", cfg.AuthDir)
|
log.Debugf("scanning auth directory: %s", cfg.AuthDir)
|
||||||
|
|
||||||
// Create new client list
|
// Create new client list
|
||||||
newClients := make([]*client.Client, 0)
|
newClients := make([]client.Client, 0)
|
||||||
authFileCount := 0
|
authFileCount := 0
|
||||||
successfulAuthCount := 0
|
successfulAuthCount := 0
|
||||||
|
|
||||||
@@ -209,37 +212,57 @@ func (w *Watcher) reloadClients() {
|
|||||||
authFileCount++
|
authFileCount++
|
||||||
log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path))
|
log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path))
|
||||||
|
|
||||||
f, errOpen := os.Open(path)
|
data, errReadFile := os.ReadFile(path)
|
||||||
if errOpen != nil {
|
if errReadFile != nil {
|
||||||
log.Errorf("failed to open token file %s: %v", path, errOpen)
|
return errReadFile
|
||||||
return nil // Continue processing other files
|
}
|
||||||
|
|
||||||
|
tokenType := "gemini"
|
||||||
|
typeResult := gjson.GetBytes(data, "type")
|
||||||
|
if typeResult.Exists() {
|
||||||
|
tokenType = typeResult.String()
|
||||||
}
|
}
|
||||||
defer func() {
|
|
||||||
errClose := f.Close()
|
|
||||||
if errClose != nil {
|
|
||||||
log.Errorf("failed to close token file %s: %v", path, errClose)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Decode the token storage file
|
// Decode the token storage file
|
||||||
var ts auth.TokenStorage
|
if tokenType == "gemini" {
|
||||||
if errDecode := json.NewDecoder(f).Decode(&ts); errDecode == nil {
|
var ts gemini.GeminiTokenStorage
|
||||||
// For each valid token, create an authenticated client
|
if err = json.Unmarshal(data, &ts); err == nil {
|
||||||
clientCtx := context.Background()
|
// For each valid token, create an authenticated client
|
||||||
log.Debugf(" initializing authentication for token from %s...", filepath.Base(path))
|
clientCtx := context.Background()
|
||||||
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
log.Debugf(" initializing gemini authentication for token from %s...", filepath.Base(path))
|
||||||
if errGetClient != nil {
|
geminiAuth := gemini.NewGeminiAuth()
|
||||||
log.Errorf(" failed to get authenticated client for token %s: %v", path, errGetClient)
|
httpClient, errGetClient := geminiAuth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
||||||
return nil // Continue processing other files
|
if errGetClient != nil {
|
||||||
}
|
log.Errorf(" failed to get authenticated client for token %s: %v", path, errGetClient)
|
||||||
log.Debugf(" authentication successful for token from %s", filepath.Base(path))
|
return nil // Continue processing other files
|
||||||
|
}
|
||||||
|
log.Debugf(" authentication successful for token from %s", filepath.Base(path))
|
||||||
|
|
||||||
// Add the new client to the pool
|
// Add the new client to the pool
|
||||||
cliClient := client.NewClient(httpClient, &ts, cfg)
|
cliClient := client.NewGeminiClient(httpClient, &ts, cfg)
|
||||||
newClients = append(newClients, cliClient)
|
newClients = append(newClients, cliClient)
|
||||||
successfulAuthCount++
|
successfulAuthCount++
|
||||||
} else {
|
} else {
|
||||||
log.Errorf(" failed to decode token file %s: %v", path, errDecode)
|
log.Errorf(" failed to decode token file %s: %v", path, err)
|
||||||
|
}
|
||||||
|
} else if tokenType == "codex" {
|
||||||
|
var ts codex.CodexTokenStorage
|
||||||
|
if err = json.Unmarshal(data, &ts); err == nil {
|
||||||
|
// For each valid token, create an authenticated client
|
||||||
|
log.Debugf(" initializing codex authentication for token from %s...", filepath.Base(path))
|
||||||
|
codexClient, errGetClient := client.NewCodexClient(cfg, &ts)
|
||||||
|
if errGetClient != nil {
|
||||||
|
log.Errorf(" failed to get authenticated client for token %s: %v", path, errGetClient)
|
||||||
|
return nil // Continue processing other files
|
||||||
|
}
|
||||||
|
log.Debugf(" authentication successful for token from %s", filepath.Base(path))
|
||||||
|
|
||||||
|
// Add the new client to the pool
|
||||||
|
newClients = append(newClients, codexClient)
|
||||||
|
successfulAuthCount++
|
||||||
|
} else {
|
||||||
|
log.Errorf(" failed to decode token file %s: %v", path, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -256,14 +279,10 @@ func (w *Watcher) reloadClients() {
|
|||||||
if len(cfg.GlAPIKey) > 0 {
|
if len(cfg.GlAPIKey) > 0 {
|
||||||
log.Debugf("processing %d Generative Language API keys", len(cfg.GlAPIKey))
|
log.Debugf("processing %d Generative Language API keys", len(cfg.GlAPIKey))
|
||||||
for i := 0; i < len(cfg.GlAPIKey); i++ {
|
for i := 0; i < len(cfg.GlAPIKey); i++ {
|
||||||
httpClient, errSetProxy := util.SetProxy(cfg, &http.Client{})
|
httpClient := util.SetProxy(cfg, &http.Client{})
|
||||||
if errSetProxy != nil {
|
|
||||||
log.Errorf("set proxy failed for GL API key %d: %v", i+1, errSetProxy)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf(" initializing with Generative Language API key %d...", i+1)
|
log.Debugf(" initializing with Generative Language API key %d...", i+1)
|
||||||
cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
|
cliClient := client.NewGeminiClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
|
||||||
newClients = append(newClients, cliClient)
|
newClients = append(newClients, cliClient)
|
||||||
glAPIKeyCount++
|
glAPIKeyCount++
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user