mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 04:20:50 +08:00
Add Qwen support
This commit is contained in:
30
README.md
30
README.md
@@ -8,19 +8,23 @@ It now also supports OpenAI Codex (GPT models) and Claude Code via OAuth.
|
||||
|
||||
so you can use local or multi‑account CLI access with OpenAI‑compatible clients and SDKs.
|
||||
|
||||
Now, We added the first Chinese provider: [Qwen Code](https://github.com/QwenLM/qwen-code).
|
||||
|
||||
## Features
|
||||
|
||||
- OpenAI/Gemini/Claude compatible API endpoints for CLI models
|
||||
- OpenAI Codex support (GPT models) via OAuth login
|
||||
- Claude Code support via OAuth login
|
||||
- Qwen Code support via OAuth login
|
||||
- Streaming and non-streaming responses
|
||||
- Function calling/tools support
|
||||
- Multimodal input support (text and images)
|
||||
- Multiple accounts with round‑robin load balancing (Gemini, OpenAI, and Claude)
|
||||
- Simple CLI authentication flows (Gemini, OpenAI, and Claude)
|
||||
- Multiple accounts with round‑robin load balancing (Gemini, OpenAI, Claude and Qwen)
|
||||
- Simple CLI authentication flows (Gemini, OpenAI, Claude and Qwen)
|
||||
- Generative Language API Key support
|
||||
- Gemini CLI multi‑account load balancing
|
||||
- Claude Code multi‑account load balancing
|
||||
- Qwen Code multi‑account load balancing
|
||||
|
||||
## Installation
|
||||
|
||||
@@ -30,6 +34,7 @@ so you can use local or multi‑account CLI access with OpenAI‑compatible clie
|
||||
- A Google account with access to Gemini CLI models (optional)
|
||||
- An OpenAI account for Codex/GPT access (optional)
|
||||
- An Anthropic account for Claude Code access (optional)
|
||||
- A Qwen Chat account for Qwen Code access (optional)
|
||||
|
||||
### Building from Source
|
||||
|
||||
@@ -60,6 +65,8 @@ You can authenticate for Gemini, OpenAI, and/or Claude. All can coexist in the s
|
||||
```
|
||||
The local OAuth callback uses port `8085`.
|
||||
|
||||
Options: add `--no-browser` to print the login URL instead of opening a browser. The local OAuth callback uses port `1455`.
|
||||
|
||||
- OpenAI (Codex/GPT via OAuth):
|
||||
```bash
|
||||
./cli-proxy-api --codex-login
|
||||
@@ -72,6 +79,13 @@ You can authenticate for Gemini, OpenAI, and/or Claude. All can coexist in the s
|
||||
```
|
||||
Options: add `--no-browser` to print the login URL instead of opening a browser. The local OAuth callback uses port `54545`.
|
||||
|
||||
- Qwen (Qwen Chat via OAuth):
|
||||
```bash
|
||||
./cli-proxy-api --qwen-login
|
||||
```
|
||||
Options: add `--no-browser` to print the login URL instead of opening a browser. Use the Qwen Chat's OAuth device flow.
|
||||
|
||||
|
||||
### Starting the Server
|
||||
|
||||
Once authenticated, start the server:
|
||||
@@ -112,7 +126,7 @@ Request body example:
|
||||
```
|
||||
|
||||
Notes:
|
||||
- Use a `gemini-*` model for Gemini (e.g., `gemini-2.5-pro`), a `gpt-*` model for OpenAI (e.g., `gpt-5`), or a `claude-*` model for Claude (e.g., `claude-3-5-sonnet-20241022`). The proxy will route to the correct provider automatically.
|
||||
- Use a `gemini-*` model for Gemini (e.g., `gemini-2.5-pro`), a `gpt-*` model for OpenAI (e.g., `gpt-5`), a `claude-*` model for Claude (e.g., `claude-3-5-sonnet-20241022`), or a `qwen-*` model for Qwen (e.g., `qwen3-coder-plus`). The proxy will route to the correct provider automatically.
|
||||
|
||||
#### Claude Messages (SSE-compatible)
|
||||
|
||||
@@ -210,6 +224,8 @@ console.log(await claudeResponse.json());
|
||||
- claude-sonnet-4-20250514
|
||||
- claude-3-7-sonnet-20250219
|
||||
- claude-3-5-haiku-20241022
|
||||
- qwen3-coder-plus
|
||||
- qwen3-coder-flash
|
||||
- Gemini models auto‑switch to preview variants when needed
|
||||
|
||||
## Configuration
|
||||
@@ -338,6 +354,14 @@ export ANTHROPIC_MODEL=claude-sonnet-4-20250514
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=claude-3-5-haiku-20241022
|
||||
```
|
||||
|
||||
Using Claude models:
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
|
||||
export ANTHROPIC_AUTH_TOKEN=sk-dummy
|
||||
export ANTHROPIC_MODEL=qwen3-coder-plus
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=qwen3-coder-flash
|
||||
```
|
||||
|
||||
## Run with Docker
|
||||
|
||||
Run the following command to login (Gemini OAuth on port 8085):
|
||||
|
||||
27
README_CN.md
27
README_CN.md
@@ -8,19 +8,23 @@
|
||||
|
||||
可与本地或多账户方式配合,使用任何 OpenAI 兼容的客户端与 SDK。
|
||||
|
||||
现在,我们添加了第一个中国提供商:[Qwen Code](https://github.com/QwenLM/qwen-code)。
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 为 CLI 模型提供 OpenAI/Gemini/Claude 兼容的 API 端点
|
||||
- 新增 OpenAI Codex(GPT 系列)支持(OAuth 登录)
|
||||
- 新增 Claude Code 支持(OAuth 登录)
|
||||
- 新增 Qwen Code 支持(OAuth 登录)
|
||||
- 支持流式与非流式响应
|
||||
- 函数调用/工具支持
|
||||
- 多模态输入(文本、图片)
|
||||
- 多账户支持与轮询负载均衡(Gemini、OpenAI 与 Claude)
|
||||
- 简单的 CLI 身份验证流程(Gemini、OpenAI 与 Claude)
|
||||
- 多账户支持与轮询负载均衡(Gemini、OpenAI、Claude 与 Qwen)
|
||||
- 简单的 CLI 身份验证流程(Gemini、OpenAI、Claude 与 Qwen)
|
||||
- 支持 Gemini AIStudio API 密钥
|
||||
- 支持 Gemini CLI 多账户轮询
|
||||
- 支持 Claude Code 多账户轮询
|
||||
- 支持 Qwen Code 多账户轮询
|
||||
|
||||
## 安装
|
||||
|
||||
@@ -30,6 +34,7 @@
|
||||
- 有权访问 Gemini CLI 模型的 Google 账户(可选)
|
||||
- 有权访问 OpenAI Codex/GPT 的 OpenAI 账户(可选)
|
||||
- 有权访问 Claude Code 的 Anthropic 账户(可选)
|
||||
- 有权访问 Qwen Code 的 Qwen Chat 账户(可选)
|
||||
|
||||
### 从源码构建
|
||||
|
||||
@@ -72,6 +77,12 @@
|
||||
```
|
||||
选项:加上 `--no-browser` 可打印登录地址而不自动打开浏览器。本地 OAuth 回调端口为 `54545`。
|
||||
|
||||
- Qwen(Qwen Chat,OAuth):
|
||||
```bash
|
||||
./cli-proxy-api --qwen-login
|
||||
```
|
||||
选项:加上 `--no-browser` 可打印登录地址而不自动打开浏览器。使用 Qwen Chat 的 OAuth 设备登录流程。
|
||||
|
||||
### 启动服务器
|
||||
|
||||
身份验证完成后,启动服务器:
|
||||
@@ -112,7 +123,7 @@ POST http://localhost:8317/v1/chat/completions
|
||||
```
|
||||
|
||||
说明:
|
||||
- 使用 `gemini-*` 模型(如 `gemini-2.5-pro`)走 Gemini,使用 `gpt-*` 模型(如 `gpt-5`)走 OpenAI,使用 `claude-*` 模型(如 `claude-3-5-sonnet-20241022`)走 Claude,服务会自动路由到对应提供商。
|
||||
- 使用 `gemini-*` 模型(如 `gemini-2.5-pro`)走 Gemini,使用 `gpt-*` 模型(如 `gpt-5`)走 OpenAI,使用 `claude-*` 模型(如 `claude-3-5-sonnet-20241022`)走 Claude,使用 `qwen-*` 模型(如 `qwen3-coder-plus`)走 Qwen,服务会自动路由到对应提供商。
|
||||
|
||||
#### Claude 消息(SSE 兼容)
|
||||
|
||||
@@ -210,6 +221,8 @@ console.log(await claudeResponse.json());
|
||||
- claude-sonnet-4-20250514
|
||||
- claude-3-7-sonnet-20250219
|
||||
- claude-3-5-haiku-20241022
|
||||
- qwen3-coder-plus
|
||||
- qwen3-coder-flash
|
||||
- Gemini 模型在需要时自动切换到对应的 preview 版本
|
||||
|
||||
## 配置
|
||||
@@ -338,6 +351,14 @@ export ANTHROPIC_MODEL=claude-sonnet-4-20250514
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=claude-3-5-haiku-20241022
|
||||
```
|
||||
|
||||
使用 Qwen 模型:
|
||||
```bash
|
||||
export ANTHROPIC_BASE_URL=http://127.0.0.1:8317
|
||||
export ANTHROPIC_AUTH_TOKEN=sk-dummy
|
||||
export ANTHROPIC_MODEL=qwen3-coder-plus
|
||||
export ANTHROPIC_SMALL_FAST_MODEL=qwen3-coder-flash
|
||||
```
|
||||
|
||||
|
||||
## 使用 Docker 运行
|
||||
|
||||
|
||||
@@ -60,6 +60,7 @@ func main() {
|
||||
var login bool
|
||||
var codexLogin bool
|
||||
var claudeLogin bool
|
||||
var qwenLogin bool
|
||||
var noBrowser bool
|
||||
var projectID string
|
||||
var configPath string
|
||||
@@ -68,6 +69,7 @@ func main() {
|
||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||
flag.StringVar(&configPath, "config", "", "Configure File Path")
|
||||
@@ -132,6 +134,8 @@ func main() {
|
||||
} else if claudeLogin {
|
||||
// Handle Claude login
|
||||
cmd.DoClaudeLogin(cfg, options)
|
||||
} else if qwenLogin {
|
||||
cmd.DoQwenLogin(cfg, options)
|
||||
} else {
|
||||
// Start the main proxy service
|
||||
cmd.StartService(cfg, configFilePath)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -18,6 +19,7 @@ import (
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
translatorClaudeCodeToCodex "github.com/luispater/CLIProxyAPI/internal/translator/codex/claude/code"
|
||||
translatorClaudeCodeToGeminiCli "github.com/luispater/CLIProxyAPI/internal/translator/gemini-cli/claude/code"
|
||||
translatorClaudeCodeToQwen "github.com/luispater/CLIProxyAPI/internal/translator/openai/claude"
|
||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -62,7 +64,7 @@ func (h *ClaudeCodeAPIHandlers) ClaudeMessages(c *gin.Context) {
|
||||
|
||||
// Check if the client requested a streaming response.
|
||||
streamResult := gjson.GetBytes(rawJSON, "stream")
|
||||
if streamResult.Type == gjson.False {
|
||||
if !streamResult.Exists() || streamResult.Type == gjson.False {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -72,6 +74,8 @@ func (h *ClaudeCodeAPIHandlers) ClaudeMessages(c *gin.Context) {
|
||||
h.handleCodexStreamingResponse(c, rawJSON)
|
||||
} else if provider == "claude" {
|
||||
h.handleClaudeStreamingResponse(c, rawJSON)
|
||||
} else if provider == "qwen" {
|
||||
h.handleQwenStreamingResponse(c, rawJSON)
|
||||
} else {
|
||||
h.handleGeminiStreamingResponse(c, rawJSON)
|
||||
}
|
||||
@@ -518,3 +522,149 @@ outLoop:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleQwenStreamingResponse streams Claude-compatible responses backed by OpenAI.
|
||||
// It converts the Claude request into Qwen responses format, establishes SSE,
|
||||
// and translates streaming chunks back into Claude Code events.
|
||||
func (h *ClaudeCodeAPIHandlers) handleQwenStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
// Set up Server-Sent Events (SSE) headers for streaming response
|
||||
// These headers are essential for maintaining a persistent connection
|
||||
// and enabling real-time streaming of chat completions
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
// This is crucial for streaming as it allows immediate sending of data chunks
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Parse and prepare the Claude request, extracting model name, system instructions,
|
||||
// conversation contents, and available tools from the raw JSON
|
||||
newRequestJSON := translatorClaudeCodeToQwen.ConvertAnthropicRequestToOpenAI(rawJSON)
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
|
||||
newRequestJSON, _ = sjson.Set(newRequestJSON, "model", modelName)
|
||||
// log.Debugf(string(rawJSON))
|
||||
// log.Debugf(newRequestJSON)
|
||||
// return
|
||||
// Create a cancellable context for the backend client request
|
||||
// This allows proper cleanup and cancellation of ongoing requests
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(c, context.Background())
|
||||
|
||||
var cliClient client.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
// This prevents deadlocks and ensures proper resource cleanup
|
||||
if cliClient != nil {
|
||||
cliClient.GetRequestMutex().Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// Main client rotation loop with quota management
|
||||
// This loop implements a sophisticated load balancing and failover mechanism
|
||||
outLoop:
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Request use qwen account: %s", cliClient.GetEmail())
|
||||
|
||||
// Initiate streaming communication with the backend client
|
||||
// This returns two channels: one for response chunks and one for errors
|
||||
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "")
|
||||
|
||||
// Track response state for proper Claude format conversion
|
||||
|
||||
params := &translatorClaudeCodeToQwen.ConvertOpenAIResponseToAnthropicParams{
|
||||
MessageID: "",
|
||||
Model: "",
|
||||
CreatedAt: 0,
|
||||
ContentAccumulator: strings.Builder{},
|
||||
ToolCallsAccumulator: nil,
|
||||
}
|
||||
|
||||
// Main streaming loop - handles multiple concurrent events using Go channels
|
||||
// This select statement manages four different types of events simultaneously
|
||||
for {
|
||||
select {
|
||||
// Case 1: Handle client disconnection
|
||||
// Detects when the HTTP client has disconnected and cleans up resources
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request to prevent resource leaks
|
||||
return
|
||||
}
|
||||
|
||||
// Case 2: Process incoming response chunks from the backend
|
||||
// This handles the actual streaming data from the AI model
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
h.AddAPIResponseData(c, chunk)
|
||||
h.AddAPIResponseData(c, []byte("\n"))
|
||||
|
||||
// Convert the backend response to Claude-compatible format
|
||||
// This translation layer ensures API compatibility
|
||||
if bytes.HasPrefix(chunk, []byte("data: ")) {
|
||||
jsonData := chunk[6:]
|
||||
outputs := translatorClaudeCodeToQwen.ConvertOpenAIResponseToAnthropic(jsonData, params)
|
||||
if len(outputs) > 0 {
|
||||
for i := 0; i < len(outputs); i++ {
|
||||
_, _ = c.Writer.Write([]byte("data: "))
|
||||
_, _ = c.Writer.Write([]byte(outputs[i]))
|
||||
}
|
||||
}
|
||||
flusher.Flush() // Immediately send the chunk to the client
|
||||
// hasFirstResponse = true
|
||||
} else {
|
||||
// log.Debugf("chunk: %s", string(chunk))
|
||||
}
|
||||
// Case 3: Handle errors from the backend
|
||||
// This manages various error conditions and implements retry logic
|
||||
case errInfo, okError := <-errChan:
|
||||
if okError {
|
||||
// log.Debugf("Code: %d, Error: %v", errInfo.StatusCode, errInfo.Error)
|
||||
// Special handling for quota exceeded errors
|
||||
// If configured, attempt to switch to a different project/client
|
||||
if errInfo.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
log.Debugf("quota exceeded, switch client")
|
||||
continue outLoop // Restart the client selection process
|
||||
} else {
|
||||
// Forward other errors directly to the client
|
||||
c.Status(errInfo.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errInfo.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel(errInfo.Error)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Case 4: Send periodic keep-alive signals
|
||||
// Prevents connection timeouts during long-running requests
|
||||
case <-time.After(3000 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
translatorGeminiToClaude "github.com/luispater/CLIProxyAPI/internal/translator/claude/gemini"
|
||||
translatorGeminiToCodex "github.com/luispater/CLIProxyAPI/internal/translator/codex/gemini"
|
||||
translatorGeminiToQwen "github.com/luispater/CLIProxyAPI/internal/translator/openai/gemini"
|
||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -64,6 +65,8 @@ func (h *GeminiCLIAPIHandlers) CLIHandler(c *gin.Context) {
|
||||
h.handleCodexInternalGenerateContent(c, rawJSON)
|
||||
} else if provider == "claude" {
|
||||
h.handleClaudeInternalGenerateContent(c, rawJSON)
|
||||
} else if provider == "qwen" {
|
||||
h.handleQwenInternalGenerateContent(c, rawJSON)
|
||||
}
|
||||
} else if requestRawURI == "/v1internal:streamGenerateContent" {
|
||||
if provider == "gemini" || provider == "unknow" {
|
||||
@@ -72,6 +75,8 @@ func (h *GeminiCLIAPIHandlers) CLIHandler(c *gin.Context) {
|
||||
h.handleCodexInternalStreamGenerateContent(c, rawJSON)
|
||||
} else if provider == "claude" {
|
||||
h.handleClaudeInternalStreamGenerateContent(c, rawJSON)
|
||||
} else if provider == "qwen" {
|
||||
h.handleQwenInternalStreamGenerateContent(c, rawJSON)
|
||||
}
|
||||
} else {
|
||||
reqBody := bytes.NewBuffer(rawJSON)
|
||||
@@ -733,3 +738,180 @@ outLoop:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GeminiCLIAPIHandlers) handleQwenInternalStreamGenerateContent(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String())
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(gjson.GetBytes(rawJSON, "systemInstruction").Raw))
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction")
|
||||
|
||||
// Prepare the request for the backend client.
|
||||
newRequestJSON := translatorGeminiToQwen.ConvertGeminiRequestToOpenAI(rawJSON)
|
||||
newRequestJSON, _ = sjson.Set(newRequestJSON, "stream", true)
|
||||
|
||||
// log.Debugf("Request: %s", string(rawJSON))
|
||||
// return
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model")
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(c, context.Background())
|
||||
|
||||
var cliClient client.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
if cliClient != nil {
|
||||
cliClient.GetRequestMutex().Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
outLoop:
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName.String())
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Request qwen use account: %s", cliClient.(*client.QwenClient).GetEmail())
|
||||
|
||||
// Send the message and receive response chunks and errors via channels.
|
||||
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "")
|
||||
|
||||
params := &translatorGeminiToQwen.ConvertOpenAIResponseToGeminiParams{
|
||||
ToolCallsAccumulator: nil,
|
||||
ContentAccumulator: strings.Builder{},
|
||||
IsFirstChunk: false,
|
||||
}
|
||||
for {
|
||||
select {
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
// Process incoming response chunks.
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
h.AddAPIResponseData(c, chunk)
|
||||
h.AddAPIResponseData(c, []byte("\n\n"))
|
||||
|
||||
if bytes.HasPrefix(chunk, []byte("data: ")) {
|
||||
jsonData := chunk[6:]
|
||||
// log.Debugf(string(jsonData))
|
||||
outputs := translatorGeminiToQwen.ConvertOpenAIResponseToGemini(jsonData, params)
|
||||
if len(outputs) > 0 {
|
||||
for i := 0; i < len(outputs); i++ {
|
||||
outputs[i], _ = sjson.SetRaw("{}", "response", outputs[i])
|
||||
_, _ = c.Writer.Write([]byte("data: "))
|
||||
_, _ = c.Writer.Write([]byte(outputs[i]))
|
||||
_, _ = c.Writer.Write([]byte("\n\n"))
|
||||
}
|
||||
}
|
||||
// log.Debugf(string(jsonData))
|
||||
}
|
||||
flusher.Flush()
|
||||
// Handle errors from the backend.
|
||||
case err, okError := <-errChan:
|
||||
if okError {
|
||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
continue outLoop
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel(err.Error)
|
||||
}
|
||||
return
|
||||
}
|
||||
// Send a keep-alive signal to the client.
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GeminiCLIAPIHandlers) handleQwenInternalGenerateContent(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String())
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "system_instruction", []byte(gjson.GetBytes(rawJSON, "systemInstruction").Raw))
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, "systemInstruction")
|
||||
|
||||
// Prepare the request for the backend client.
|
||||
newRequestJSON := translatorGeminiToQwen.ConvertGeminiRequestToOpenAI(rawJSON)
|
||||
// log.Debugf("Request: %s", newRequestJSON)
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model")
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(c, context.Background())
|
||||
|
||||
var cliClient client.Client
|
||||
defer func() {
|
||||
if cliClient != nil {
|
||||
cliClient.GetRequestMutex().Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName.String())
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Request use qwen account: %s", cliClient.GetEmail())
|
||||
|
||||
resp, err := cliClient.SendRawMessage(cliCtx, []byte(newRequestJSON), "")
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
continue
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||
cliCancel(err.Error)
|
||||
}
|
||||
break
|
||||
} else {
|
||||
h.AddAPIResponseData(c, resp)
|
||||
h.AddAPIResponseData(c, []byte("\n"))
|
||||
|
||||
newResp := translatorGeminiToQwen.ConvertOpenAINonStreamResponseToGemini(resp)
|
||||
_, _ = c.Writer.Write([]byte(newResp))
|
||||
cliCancel(resp)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
translatorGeminiToClaude "github.com/luispater/CLIProxyAPI/internal/translator/claude/gemini"
|
||||
translatorGeminiToCodex "github.com/luispater/CLIProxyAPI/internal/translator/codex/gemini"
|
||||
translatorGeminiToGeminiCli "github.com/luispater/CLIProxyAPI/internal/translator/gemini-cli/gemini/cli"
|
||||
translatorGeminiToQwen "github.com/luispater/CLIProxyAPI/internal/translator/openai/gemini"
|
||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -241,6 +242,13 @@ func (h *GeminiAPIHandlers) GeminiHandler(c *gin.Context) {
|
||||
case "streamGenerateContent":
|
||||
h.handleClaudeStreamGenerateContent(c, rawJSON)
|
||||
}
|
||||
} else if provider == "qwen" {
|
||||
switch method {
|
||||
case "generateContent":
|
||||
h.handleQwenGenerateContent(c, rawJSON)
|
||||
case "streamGenerateContent":
|
||||
h.handleQwenStreamGenerateContent(c, rawJSON)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -961,3 +969,163 @@ outLoop:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GeminiAPIHandlers) handleQwenStreamGenerateContent(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare the request for the backend client.
|
||||
newRequestJSON := translatorGeminiToQwen.ConvertGeminiRequestToOpenAI(rawJSON)
|
||||
newRequestJSON, _ = sjson.Set(newRequestJSON, "stream", true)
|
||||
// log.Debugf("Request: %s", newRequestJSON)
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model")
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(c, context.Background())
|
||||
|
||||
var cliClient client.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
if cliClient != nil {
|
||||
cliClient.GetRequestMutex().Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
outLoop:
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName.String())
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Request use qwen account: %s", cliClient.GetEmail())
|
||||
|
||||
// Send the message and receive response chunks and errors via channels.
|
||||
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, []byte(newRequestJSON), "")
|
||||
|
||||
params := &translatorGeminiToQwen.ConvertOpenAIResponseToGeminiParams{
|
||||
ToolCallsAccumulator: nil,
|
||||
ContentAccumulator: strings.Builder{},
|
||||
IsFirstChunk: false,
|
||||
}
|
||||
for {
|
||||
select {
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("CodexClient disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
// Process incoming response chunks.
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
h.AddAPIResponseData(c, chunk)
|
||||
h.AddAPIResponseData(c, []byte("\n\n"))
|
||||
if bytes.HasPrefix(chunk, []byte("data: ")) {
|
||||
jsonData := chunk[6:]
|
||||
outputs := translatorGeminiToQwen.ConvertOpenAIResponseToGemini(jsonData, params)
|
||||
if len(outputs) > 0 {
|
||||
for i := 0; i < len(outputs); i++ {
|
||||
_, _ = c.Writer.Write([]byte("data: "))
|
||||
_, _ = c.Writer.Write([]byte(outputs[i]))
|
||||
_, _ = c.Writer.Write([]byte("\n\n"))
|
||||
}
|
||||
}
|
||||
// log.Debugf(string(jsonData))
|
||||
}
|
||||
flusher.Flush()
|
||||
// Handle errors from the backend.
|
||||
case err, okError := <-errChan:
|
||||
if okError {
|
||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
continue outLoop
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel(err.Error)
|
||||
}
|
||||
return
|
||||
}
|
||||
// Send a keep-alive signal to the client.
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *GeminiAPIHandlers) handleQwenGenerateContent(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
// Prepare the request for the backend client.
|
||||
newRequestJSON := translatorGeminiToQwen.ConvertGeminiRequestToOpenAI(rawJSON)
|
||||
// log.Debugf("Request: %s", newRequestJSON)
|
||||
|
||||
modelName := gjson.GetBytes(rawJSON, "model")
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(c, context.Background())
|
||||
|
||||
var cliClient client.Client
|
||||
defer func() {
|
||||
if cliClient != nil {
|
||||
cliClient.GetRequestMutex().Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName.String())
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Request use qwen account: %s", cliClient.GetEmail())
|
||||
|
||||
resp, err := cliClient.SendRawMessage(cliCtx, []byte(newRequestJSON), "")
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
continue
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||
cliCancel(err.Error)
|
||||
}
|
||||
break
|
||||
} else {
|
||||
h.AddAPIResponseData(c, resp)
|
||||
h.AddAPIResponseData(c, []byte("\n"))
|
||||
|
||||
newResp := translatorGeminiToQwen.ConvertOpenAINonStreamResponseToGemini(resp)
|
||||
_, _ = c.Writer.Write([]byte(newResp))
|
||||
cliCancel(resp)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -118,6 +118,12 @@ func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (cl
|
||||
clients = append(clients, cli)
|
||||
}
|
||||
}
|
||||
} else if provider == "qwen" {
|
||||
for i := 0; i < len(h.CliClients); i++ {
|
||||
if cli, ok := h.CliClients[i].(*client.QwenClient); ok {
|
||||
clients = append(clients, cli)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, hasKey := h.LastUsedClientIndex[provider]; !hasKey {
|
||||
@@ -150,6 +156,8 @@ func (h *APIHandlers) GetClient(modelName string, isGenerateContent ...bool) (cl
|
||||
log.Debugf("Codex Model %s is quota exceeded for account %s", modelName, cliClient.GetEmail())
|
||||
} else if provider == "claude" {
|
||||
log.Debugf("Claude Model %s is quota exceeded for account %s", modelName, cliClient.GetEmail())
|
||||
} else if provider == "qwen" {
|
||||
log.Debugf("Qwen Model %s is quota exceeded for account %s", modelName, cliClient.GetEmail())
|
||||
}
|
||||
cliClient = nil
|
||||
continue
|
||||
|
||||
@@ -171,6 +171,13 @@ func (h *OpenAIAPIHandlers) ChatCompletions(c *gin.Context) {
|
||||
} else {
|
||||
h.handleClaudeNonStreamingResponse(c, rawJSON)
|
||||
}
|
||||
} else if provider == "qwen" {
|
||||
// qwen3-coder-plus / qwen3-coder-flash
|
||||
if streamResult.Type == gjson.True {
|
||||
h.handleQwenStreamingResponse(c, rawJSON)
|
||||
} else {
|
||||
h.handleQwenNonStreamingResponse(c, rawJSON)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -761,3 +768,155 @@ outLoop:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleQwenNonStreamingResponse handles non-streaming chat completion responses
|
||||
// for Qwen models. It selects a client from the pool, sends the request, and
|
||||
// aggregates the response before sending it back to the client in OpenAI format.
|
||||
//
|
||||
// Parameters:
|
||||
// - c: The Gin context containing the HTTP request and response
|
||||
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
|
||||
func (h *OpenAIAPIHandlers) handleQwenNonStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
modelName := modelResult.String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(c, context.Background())
|
||||
|
||||
var cliClient client.Client
|
||||
defer func() {
|
||||
if cliClient != nil {
|
||||
cliClient.GetRequestMutex().Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Request qwen use account: %s", cliClient.(*client.QwenClient).GetEmail())
|
||||
|
||||
resp, err := cliClient.SendRawMessage(cliCtx, rawJSON, modelName)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
continue
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||
cliCancel(err.Error)
|
||||
}
|
||||
break
|
||||
} else {
|
||||
_, _ = c.Writer.Write(resp)
|
||||
cliCancel(resp)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleQwenStreamingResponse handles streaming responses for Qwen models.
|
||||
// It establishes a streaming connection with the backend service and forwards
|
||||
// the response chunks to the client in real-time using Server-Sent Events.
|
||||
//
|
||||
// Parameters:
|
||||
// - c: The Gin context containing the HTTP request and response
|
||||
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
|
||||
func (h *OpenAIAPIHandlers) handleQwenStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
Message: "Streaming not supported",
|
||||
Type: "server_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Prepare the request for the backend client.
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
modelName := modelResult.String()
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(c, context.Background())
|
||||
|
||||
var cliClient client.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
if cliClient != nil {
|
||||
cliClient.GetRequestMutex().Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
outLoop:
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.GetClient(modelName)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("Request qwen use account: %s", cliClient.(*client.QwenClient).GetEmail())
|
||||
|
||||
// Send the message and receive response chunks and errors via channels.
|
||||
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJSON, modelName)
|
||||
|
||||
for {
|
||||
select {
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("GeminiClient disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
// Process incoming response chunks.
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
|
||||
h.AddAPIResponseData(c, chunk)
|
||||
h.AddAPIResponseData(c, []byte("\n"))
|
||||
|
||||
// Convert the chunk to OpenAI format and send it to the client.
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
|
||||
flusher.Flush()
|
||||
// Handle errors from the backend.
|
||||
case err, okError := <-errChan:
|
||||
if okError {
|
||||
if err.StatusCode == 429 && h.Cfg.QuotaExceeded.SwitchProject {
|
||||
continue outLoop
|
||||
} else {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||
flusher.Flush()
|
||||
cliCancel(err.Error)
|
||||
}
|
||||
return
|
||||
}
|
||||
// Send a keep-alive signal to the client.
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -66,7 +66,7 @@ func NewServer(cfg *config.Config, cliClients []client.Client) *Server {
|
||||
requestLogger := logging.NewFileRequestLogger(cfg.RequestLog, "logs")
|
||||
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
|
||||
|
||||
// engine.Use(corsMiddleware())
|
||||
engine.Use(corsMiddleware())
|
||||
|
||||
// Create server instance
|
||||
s := &Server{
|
||||
|
||||
337
internal/auth/qwen/qwen_auth.go
Normal file
337
internal/auth/qwen/qwen_auth.go
Normal file
@@ -0,0 +1,337 @@
|
||||
package qwen
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// OAuth Configuration
|
||||
QwenOAuthDeviceCodeEndpoint = "https://chat.qwen.ai/api/v1/oauth2/device/code"
|
||||
QwenOAuthTokenEndpoint = "https://chat.qwen.ai/api/v1/oauth2/token"
|
||||
QwenOAuthClientID = "f0304373b74a44d2b584a3fb70ca9e56"
|
||||
QwenOAuthScope = "openid profile email model.completion"
|
||||
QwenOAuthGrantType = "urn:ietf:params:oauth:grant-type:device_code"
|
||||
)
|
||||
|
||||
// QwenTokenData represents OAuth credentials
|
||||
type QwenTokenData struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
TokenType string `json:"token_type"`
|
||||
ResourceURL string `json:"resource_url,omitempty"`
|
||||
Expire string `json:"expiry_date,omitempty"`
|
||||
}
|
||||
|
||||
// DeviceFlow represents device flow response
|
||||
type DeviceFlow struct {
|
||||
DeviceCode string `json:"device_code"`
|
||||
UserCode string `json:"user_code"`
|
||||
VerificationURI string `json:"verification_uri"`
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
Interval int `json:"interval"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
}
|
||||
|
||||
// QwenTokenResponse represents token response
|
||||
type QwenTokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
TokenType string `json:"token_type"`
|
||||
ResourceURL string `json:"resource_url,omitempty"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
|
||||
// QwenAuth manages authentication and credentials
|
||||
type QwenAuth struct {
|
||||
httpClient *http.Client
|
||||
}
|
||||
|
||||
// NewQwenAuth creates a new QwenAuth
|
||||
func NewQwenAuth(cfg *config.Config) *QwenAuth {
|
||||
return &QwenAuth{
|
||||
httpClient: util.SetProxy(cfg, &http.Client{}),
|
||||
}
|
||||
}
|
||||
|
||||
// generateCodeVerifier generates a random code verifier for PKCE
|
||||
func (qa *QwenAuth) generateCodeVerifier() (string, error) {
|
||||
bytes := make([]byte, 32)
|
||||
if _, err := rand.Read(bytes); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return base64.RawURLEncoding.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// generateCodeChallenge generates a code challenge from a code verifier using SHA-256
|
||||
func (qa *QwenAuth) generateCodeChallenge(codeVerifier string) string {
|
||||
hash := sha256.Sum256([]byte(codeVerifier))
|
||||
return base64.RawURLEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// generatePKCEPair generates PKCE code verifier and challenge pair
|
||||
func (qa *QwenAuth) generatePKCEPair() (string, string, error) {
|
||||
codeVerifier, err := qa.generateCodeVerifier()
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
codeChallenge := qa.generateCodeChallenge(codeVerifier)
|
||||
return codeVerifier, codeChallenge, nil
|
||||
}
|
||||
|
||||
// RefreshTokens refreshes the access token using refresh token
|
||||
func (qa *QwenAuth) RefreshTokens(ctx context.Context, refreshToken string) (*QwenTokenData, error) {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", "refresh_token")
|
||||
data.Set("refresh_token", refreshToken)
|
||||
data.Set("client_id", QwenOAuthClientID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthTokenEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := qa.httpClient.Do(req)
|
||||
|
||||
// resp, err := qa.httpClient.PostForm(QwenOAuthTokenEndpoint, data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("token refresh request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
var errorData map[string]interface{}
|
||||
if err = json.Unmarshal(body, &errorData); err == nil {
|
||||
return nil, fmt.Errorf("token refresh failed: %v - %v", errorData["error"], errorData["error_description"])
|
||||
}
|
||||
return nil, fmt.Errorf("token refresh failed: %s", string(body))
|
||||
}
|
||||
|
||||
var tokenData QwenTokenResponse
|
||||
if err = json.Unmarshal(body, &tokenData); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
return &QwenTokenData{
|
||||
AccessToken: tokenData.AccessToken,
|
||||
TokenType: tokenData.TokenType,
|
||||
RefreshToken: tokenData.RefreshToken,
|
||||
ResourceURL: tokenData.ResourceURL,
|
||||
Expire: time.Now().Add(time.Duration(tokenData.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InitiateDeviceFlow initiates the OAuth device flow
|
||||
func (qa *QwenAuth) InitiateDeviceFlow(ctx context.Context) (*DeviceFlow, error) {
|
||||
// Generate PKCE code verifier and challenge
|
||||
codeVerifier, codeChallenge, err := qa.generatePKCEPair()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate PKCE pair: %w", err)
|
||||
}
|
||||
|
||||
data := url.Values{}
|
||||
data.Set("client_id", QwenOAuthClientID)
|
||||
data.Set("scope", QwenOAuthScope)
|
||||
data.Set("code_challenge", codeChallenge)
|
||||
data.Set("code_challenge_method", "S256")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", QwenOAuthDeviceCodeEndpoint, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create token request: %w", err)
|
||||
}
|
||||
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := qa.httpClient.Do(req)
|
||||
|
||||
// resp, err := qa.httpClient.PostForm(QwenOAuthDeviceCodeEndpoint, data)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("device authorization request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read response body: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("device authorization failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body))
|
||||
}
|
||||
|
||||
var result DeviceFlow
|
||||
if err = json.Unmarshal(body, &result); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse device flow response: %w", err)
|
||||
}
|
||||
|
||||
// Check if the response indicates success
|
||||
if result.DeviceCode == "" {
|
||||
return nil, fmt.Errorf("device authorization failed: device_code not found in response")
|
||||
}
|
||||
|
||||
// Add the code_verifier to the result so it can be used later for polling
|
||||
result.CodeVerifier = codeVerifier
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
|
||||
// PollForToken polls for the access token using device code
|
||||
func (qa *QwenAuth) PollForToken(deviceCode, codeVerifier string) (*QwenTokenData, error) {
|
||||
pollInterval := 5 * time.Second
|
||||
maxAttempts := 60 // 5 minutes max
|
||||
|
||||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||||
data := url.Values{}
|
||||
data.Set("grant_type", QwenOAuthGrantType)
|
||||
data.Set("client_id", QwenOAuthClientID)
|
||||
data.Set("device_code", deviceCode)
|
||||
data.Set("code_verifier", codeVerifier)
|
||||
|
||||
resp, err := http.PostForm(QwenOAuthTokenEndpoint, data)
|
||||
if err != nil {
|
||||
fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err)
|
||||
time.Sleep(pollInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
fmt.Printf("Polling attempt %d/%d failed: %v\n", attempt+1, maxAttempts, err)
|
||||
time.Sleep(pollInterval)
|
||||
continue
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
// Parse the response as JSON to check for OAuth RFC 8628 standard errors
|
||||
var errorData map[string]interface{}
|
||||
if err = json.Unmarshal(body, &errorData); err == nil {
|
||||
// According to OAuth RFC 8628, handle standard polling responses
|
||||
if resp.StatusCode == http.StatusBadRequest {
|
||||
errorType, _ := errorData["error"].(string)
|
||||
switch errorType {
|
||||
case "authorization_pending":
|
||||
// User has not yet approved the authorization request. Continue polling.
|
||||
log.Infof("Polling attempt %d/%d...\n", attempt+1, maxAttempts)
|
||||
time.Sleep(pollInterval)
|
||||
continue
|
||||
case "slow_down":
|
||||
// Client is polling too frequently. Increase poll interval.
|
||||
pollInterval = time.Duration(float64(pollInterval) * 1.5)
|
||||
if pollInterval > 10*time.Second {
|
||||
pollInterval = 10 * time.Second
|
||||
}
|
||||
log.Infof("Server requested to slow down, increasing poll interval to %v\n", pollInterval)
|
||||
time.Sleep(pollInterval)
|
||||
continue
|
||||
case "expired_token":
|
||||
return nil, fmt.Errorf("device code expired. Please restart the authentication process")
|
||||
case "access_denied":
|
||||
return nil, fmt.Errorf("authorization denied by user. Please restart the authentication process")
|
||||
}
|
||||
}
|
||||
|
||||
// For other errors, return with proper error information
|
||||
errorType, _ := errorData["error"].(string)
|
||||
errorDesc, _ := errorData["error_description"].(string)
|
||||
return nil, fmt.Errorf("device token poll failed: %s - %s", errorType, errorDesc)
|
||||
}
|
||||
|
||||
// If JSON parsing fails, fall back to text response
|
||||
return nil, fmt.Errorf("device token poll failed: %d %s. Response: %s", resp.StatusCode, resp.Status, string(body))
|
||||
}
|
||||
log.Debugf(string(body))
|
||||
// Success - parse token data
|
||||
var response QwenTokenResponse
|
||||
if err = json.Unmarshal(body, &response); err != nil {
|
||||
return nil, fmt.Errorf("failed to parse token response: %w", err)
|
||||
}
|
||||
|
||||
// Convert to QwenTokenData format and save
|
||||
tokenData := &QwenTokenData{
|
||||
AccessToken: response.AccessToken,
|
||||
RefreshToken: response.RefreshToken,
|
||||
TokenType: response.TokenType,
|
||||
ResourceURL: response.ResourceURL,
|
||||
Expire: time.Now().Add(time.Duration(response.ExpiresIn) * time.Second).Format(time.RFC3339),
|
||||
}
|
||||
|
||||
return tokenData, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("authentication timeout. Please restart the authentication process")
|
||||
}
|
||||
|
||||
// RefreshTokensWithRetry refreshes tokens with automatic retry logic
|
||||
func (o *QwenAuth) RefreshTokensWithRetry(ctx context.Context, refreshToken string, maxRetries int) (*QwenTokenData, error) {
|
||||
var lastErr error
|
||||
|
||||
for attempt := 0; attempt < maxRetries; attempt++ {
|
||||
if attempt > 0 {
|
||||
// Wait before retry
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(time.Duration(attempt) * time.Second):
|
||||
}
|
||||
}
|
||||
|
||||
tokenData, err := o.RefreshTokens(ctx, refreshToken)
|
||||
if err == nil {
|
||||
return tokenData, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Warnf("Token refresh attempt %d failed: %v", attempt+1, err)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("token refresh failed after %d attempts: %w", maxRetries, lastErr)
|
||||
}
|
||||
|
||||
func (o *QwenAuth) CreateTokenStorage(tokenData *QwenTokenData) *QwenTokenStorage {
|
||||
storage := &QwenTokenStorage{
|
||||
AccessToken: tokenData.AccessToken,
|
||||
RefreshToken: tokenData.RefreshToken,
|
||||
LastRefresh: time.Now().Format(time.RFC3339),
|
||||
ResourceURL: tokenData.ResourceURL,
|
||||
Expire: tokenData.Expire,
|
||||
}
|
||||
|
||||
return storage
|
||||
}
|
||||
|
||||
// UpdateTokenStorage updates an existing token storage with new token data
|
||||
func (o *QwenAuth) UpdateTokenStorage(storage *QwenTokenStorage, tokenData *QwenTokenData) {
|
||||
storage.AccessToken = tokenData.AccessToken
|
||||
storage.RefreshToken = tokenData.RefreshToken
|
||||
storage.LastRefresh = time.Now().Format(time.RFC3339)
|
||||
storage.ResourceURL = tokenData.ResourceURL
|
||||
storage.Expire = tokenData.Expire
|
||||
}
|
||||
61
internal/auth/qwen/qwen_token.go
Normal file
61
internal/auth/qwen/qwen_token.go
Normal file
@@ -0,0 +1,61 @@
|
||||
// Package gemini provides authentication and token management functionality
|
||||
// for Google's Gemini AI services. It handles OAuth2 token storage, serialization,
|
||||
// and retrieval for maintaining authenticated sessions with the Gemini API.
|
||||
package qwen
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
)
|
||||
|
||||
// QwenTokenStorage defines the structure for storing OAuth2 token information,
|
||||
// along with associated user and project details. This data is typically
|
||||
// serialized to a JSON file for persistence.
|
||||
type QwenTokenStorage struct {
|
||||
// AccessToken is the OAuth2 access token for API access
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is used to obtain new access tokens
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// LastRefresh is the timestamp of the last token refresh
|
||||
LastRefresh string `json:"last_refresh"`
|
||||
// ResourceURL is the request base url
|
||||
ResourceURL string `json:"resource_url"`
|
||||
// Email is the OpenAI account email
|
||||
Email string `json:"email"`
|
||||
// Type indicates the type (gemini, chatgpt, claude) of token storage.
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp of the token expire
|
||||
Expire string `json:"expired"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path. It ensures the file is
|
||||
// properly closed after writing.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the operation fails, nil otherwise
|
||||
func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
ts.Type = "qwen"
|
||||
if err := os.MkdirAll(path.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
288
internal/client/qwen_client.go
Normal file
288
internal/client/qwen_client.go
Normal file
@@ -0,0 +1,288 @@
|
||||
package client
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth/qwen"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
qwenEndpoint = "https://portal.qwen.ai/v1"
|
||||
)
|
||||
|
||||
// QwenClient implements the Client interface for OpenAI API
|
||||
type QwenClient struct {
|
||||
ClientBase
|
||||
qwenAuth *qwen.QwenAuth
|
||||
}
|
||||
|
||||
// NewQwenClient creates a new OpenAI client instance
|
||||
func NewQwenClient(cfg *config.Config, ts *qwen.QwenTokenStorage) *QwenClient {
|
||||
httpClient := util.SetProxy(cfg, &http.Client{})
|
||||
client := &QwenClient{
|
||||
ClientBase: ClientBase{
|
||||
RequestMutex: &sync.Mutex{},
|
||||
httpClient: httpClient,
|
||||
cfg: cfg,
|
||||
modelQuotaExceeded: make(map[string]*time.Time),
|
||||
tokenStorage: ts,
|
||||
},
|
||||
qwenAuth: qwen.NewQwenAuth(cfg),
|
||||
}
|
||||
|
||||
return client
|
||||
}
|
||||
|
||||
// GetUserAgent returns the user agent string for OpenAI API requests
|
||||
func (c *QwenClient) GetUserAgent() string {
|
||||
return "google-api-nodejs-client/9.15.1"
|
||||
}
|
||||
|
||||
func (c *QwenClient) TokenStorage() auth.TokenStorage {
|
||||
return c.tokenStorage
|
||||
}
|
||||
|
||||
// SendMessage sends a message to OpenAI API (non-streaming)
|
||||
func (c *QwenClient) SendMessage(_ context.Context, _ []byte, _ string, _ *Content, _ []Content, _ []ToolDeclaration) ([]byte, *ErrorMessage) {
|
||||
// For now, return an error as OpenAI integration is not fully implemented
|
||||
return nil, &ErrorMessage{
|
||||
StatusCode: http.StatusNotImplemented,
|
||||
Error: fmt.Errorf("qwen message sending not yet implemented"),
|
||||
}
|
||||
}
|
||||
|
||||
// SendMessageStream sends a streaming message to OpenAI API
|
||||
func (c *QwenClient) SendMessageStream(_ context.Context, _ []byte, _ string, _ *Content, _ []Content, _ []ToolDeclaration, _ ...bool) (<-chan []byte, <-chan *ErrorMessage) {
|
||||
errChan := make(chan *ErrorMessage, 1)
|
||||
errChan <- &ErrorMessage{
|
||||
StatusCode: http.StatusNotImplemented,
|
||||
Error: fmt.Errorf("qwen streaming not yet implemented"),
|
||||
}
|
||||
close(errChan)
|
||||
|
||||
return nil, errChan
|
||||
}
|
||||
|
||||
// SendRawMessage sends a raw message to OpenAI API
|
||||
func (c *QwenClient) SendRawMessage(ctx context.Context, rawJSON []byte, alt string) ([]byte, *ErrorMessage) {
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
model := modelResult.String()
|
||||
modelName := model
|
||||
|
||||
respBody, err := c.APIRequest(ctx, "/chat/completions", rawJSON, alt, false)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
bodyBytes, errReadAll := io.ReadAll(respBody)
|
||||
if errReadAll != nil {
|
||||
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
|
||||
}
|
||||
return bodyBytes, nil
|
||||
|
||||
}
|
||||
|
||||
// SendRawMessageStream sends a raw streaming message to OpenAI API
|
||||
func (c *QwenClient) SendRawMessageStream(ctx context.Context, rawJSON []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) {
|
||||
errChan := make(chan *ErrorMessage)
|
||||
dataChan := make(chan []byte)
|
||||
go func() {
|
||||
defer close(errChan)
|
||||
defer close(dataChan)
|
||||
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
model := modelResult.String()
|
||||
modelName := model
|
||||
var stream io.ReadCloser
|
||||
for {
|
||||
var err *ErrorMessage
|
||||
stream, err = c.APIRequest(ctx, "/chat/completions", rawJSON, alt, true)
|
||||
if err != nil {
|
||||
if err.StatusCode == 429 {
|
||||
now := time.Now()
|
||||
c.modelQuotaExceeded[modelName] = &now
|
||||
}
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
delete(c.modelQuotaExceeded, modelName)
|
||||
break
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(stream)
|
||||
buffer := make([]byte, 10240*1024)
|
||||
scanner.Buffer(buffer, 10240*1024)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
dataChan <- line
|
||||
}
|
||||
|
||||
if errScanner := scanner.Err(); errScanner != nil {
|
||||
errChan <- &ErrorMessage{500, errScanner, nil}
|
||||
_ = stream.Close()
|
||||
return
|
||||
}
|
||||
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
// SendRawTokenCount sends a token count request to OpenAI API
|
||||
func (c *QwenClient) SendRawTokenCount(_ context.Context, _ []byte, _ string) ([]byte, *ErrorMessage) {
|
||||
return nil, &ErrorMessage{
|
||||
StatusCode: http.StatusNotImplemented,
|
||||
Error: fmt.Errorf("qwen token counting not yet implemented"),
|
||||
}
|
||||
}
|
||||
|
||||
// SaveTokenToFile persists the token storage to disk
|
||||
func (c *QwenClient) SaveTokenToFile() error {
|
||||
fileName := filepath.Join(c.cfg.AuthDir, fmt.Sprintf("qwen-%s.json", c.tokenStorage.(*qwen.QwenTokenStorage).Email))
|
||||
return c.tokenStorage.SaveTokenToFile(fileName)
|
||||
}
|
||||
|
||||
// RefreshTokens refreshes the access tokens if needed
|
||||
func (c *QwenClient) RefreshTokens(ctx context.Context) error {
|
||||
if c.tokenStorage == nil || c.tokenStorage.(*qwen.QwenTokenStorage).RefreshToken == "" {
|
||||
return fmt.Errorf("no refresh token available")
|
||||
}
|
||||
|
||||
// Refresh tokens using the auth service
|
||||
newTokenData, err := c.qwenAuth.RefreshTokensWithRetry(ctx, c.tokenStorage.(*qwen.QwenTokenStorage).RefreshToken, 3)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to refresh tokens: %w", err)
|
||||
}
|
||||
|
||||
// Update token storage
|
||||
c.qwenAuth.UpdateTokenStorage(c.tokenStorage.(*qwen.QwenTokenStorage), newTokenData)
|
||||
|
||||
// Save updated tokens
|
||||
if err = c.SaveTokenToFile(); err != nil {
|
||||
log.Warnf("Failed to save refreshed tokens: %v", err)
|
||||
}
|
||||
|
||||
log.Debug("qwen tokens refreshed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
// APIRequest handles making requests to the CLI API endpoints.
|
||||
func (c *QwenClient) APIRequest(ctx context.Context, endpoint string, body interface{}, _ string, _ bool) (io.ReadCloser, *ErrorMessage) {
|
||||
var jsonBody []byte
|
||||
var err error
|
||||
if byteBody, ok := body.([]byte); ok {
|
||||
jsonBody = byteBody
|
||||
} else {
|
||||
jsonBody, err = json.Marshal(body)
|
||||
if err != nil {
|
||||
return nil, &ErrorMessage{500, fmt.Errorf("failed to marshal request body: %w", err), nil}
|
||||
}
|
||||
}
|
||||
|
||||
streamResult := gjson.GetBytes(jsonBody, "stream")
|
||||
if streamResult.Exists() && streamResult.Type == gjson.True {
|
||||
jsonBody, _ = sjson.SetBytes(jsonBody, "stream_options.include_usage", true)
|
||||
}
|
||||
|
||||
var url string
|
||||
if c.tokenStorage.(*qwen.QwenTokenStorage).ResourceURL == "" {
|
||||
url = fmt.Sprintf("https://%s/v1%s", c.tokenStorage.(*qwen.QwenTokenStorage).ResourceURL, endpoint)
|
||||
} else {
|
||||
url = fmt.Sprintf("%s%s", qwenEndpoint, endpoint)
|
||||
}
|
||||
|
||||
// log.Debug(string(jsonBody))
|
||||
// log.Debug(url)
|
||||
reqBody := bytes.NewBuffer(jsonBody)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
|
||||
if err != nil {
|
||||
return nil, &ErrorMessage{500, fmt.Errorf("failed to create request: %v", err), nil}
|
||||
}
|
||||
|
||||
// Set headers
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("User-Agent", c.GetUserAgent())
|
||||
req.Header.Set("X-Goog-Api-Client", "gl-node/22.17.0")
|
||||
req.Header.Set("Client-Metadata", c.getClientMetadataString())
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", c.tokenStorage.(*qwen.QwenTokenStorage).AccessToken))
|
||||
|
||||
if ginContext, ok := ctx.Value("gin").(*gin.Context); ok {
|
||||
ginContext.Set("API_REQUEST", jsonBody)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, &ErrorMessage{500, fmt.Errorf("failed to execute request: %v", err), nil}
|
||||
}
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
defer func() {
|
||||
if err = resp.Body.Close(); err != nil {
|
||||
log.Printf("warn: failed to close response body: %v", err)
|
||||
}
|
||||
}()
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
// log.Debug(string(jsonBody))
|
||||
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes)), nil}
|
||||
}
|
||||
|
||||
return resp.Body, nil
|
||||
}
|
||||
|
||||
func (c *QwenClient) getClientMetadata() map[string]string {
|
||||
return map[string]string{
|
||||
"ideType": "IDE_UNSPECIFIED",
|
||||
"platform": "PLATFORM_UNSPECIFIED",
|
||||
"pluginType": "GEMINI",
|
||||
// "pluginVersion": pluginVersion,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *QwenClient) getClientMetadataString() string {
|
||||
md := c.getClientMetadata()
|
||||
parts := make([]string, 0, len(md))
|
||||
for k, v := range md {
|
||||
parts = append(parts, fmt.Sprintf("%s=%s", k, v))
|
||||
}
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
func (c *QwenClient) GetEmail() string {
|
||||
return c.tokenStorage.(*qwen.QwenTokenStorage).Email
|
||||
}
|
||||
|
||||
// IsModelQuotaExceeded returns true if the specified model has exceeded its quota
|
||||
// and no fallback options are available.
|
||||
func (c *QwenClient) IsModelQuotaExceeded(model string) bool {
|
||||
if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
|
||||
duration := time.Now().Sub(*lastExceededTime)
|
||||
if duration > 30*time.Minute {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
85
internal/cmd/qwen_login.go
Normal file
85
internal/cmd/qwen_login.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth/qwen"
|
||||
"github.com/luispater/CLIProxyAPI/internal/browser"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoQwenLogin handles the Qwen OAuth login process
|
||||
func DoQwenLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
log.Info("Initializing Qwen authentication...")
|
||||
|
||||
// Initialize Qwen auth service
|
||||
qwenAuth := qwen.NewQwenAuth(cfg)
|
||||
|
||||
// Generate authorization URL
|
||||
deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to generate authorization URL: %v", err)
|
||||
return
|
||||
}
|
||||
authURL := deviceFlow.VerificationURIComplete
|
||||
|
||||
// Open browser or display URL
|
||||
if !options.NoBrowser {
|
||||
log.Info("Opening browser for authentication...")
|
||||
|
||||
// Check if browser is available
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available on this system")
|
||||
log.Infof("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||
} else {
|
||||
if err = browser.OpenURL(authURL); err != nil {
|
||||
log.Infof("Please manually open this URL in your browser:\n\n%s\n", authURL)
|
||||
|
||||
// Log platform info for debugging
|
||||
platformInfo := browser.GetPlatformInfo()
|
||||
log.Debugf("Browser platform info: %+v", platformInfo)
|
||||
} else {
|
||||
log.Debug("Browser opened successfully")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
log.Infof("Please open this URL in your browser:\n\n%s\n", authURL)
|
||||
}
|
||||
|
||||
log.Info("Waiting for authentication...")
|
||||
tokenData, err := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
||||
if err != nil {
|
||||
fmt.Printf("Authentication failed: %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Create token storage
|
||||
tokenStorage := qwenAuth.CreateTokenStorage(tokenData)
|
||||
|
||||
// Initialize Qwen client
|
||||
qwenClient := client.NewQwenClient(cfg, tokenStorage)
|
||||
|
||||
fmt.Println("\nPlease input your email address or any alias:")
|
||||
var email string
|
||||
_, _ = fmt.Scanln(&email)
|
||||
tokenStorage.Email = email
|
||||
|
||||
// Save token storage
|
||||
if err = qwenClient.SaveTokenToFile(); err != nil {
|
||||
log.Fatalf("Failed to save authentication tokens: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("Authentication successful!")
|
||||
log.Info("You can now use Qwen services through this CLI")
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth/claude"
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth/codex"
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth/gemini"
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth/qwen"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||
@@ -102,6 +103,15 @@ func StartService(cfg *config.Config, configPath string) {
|
||||
log.Info("Authentication successful.")
|
||||
cliClients = append(cliClients, claudeClient)
|
||||
}
|
||||
} else if tokenType == "qwen" {
|
||||
var ts qwen.QwenTokenStorage
|
||||
if err = json.Unmarshal(data, &ts); err == nil {
|
||||
// For each valid token, create an authenticated client.
|
||||
log.Info("Initializing qwen authentication for token...")
|
||||
qwenClient := client.NewQwenClient(cfg, &ts)
|
||||
log.Info("Authentication successful.")
|
||||
cliClients = append(cliClients, qwenClient)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -200,12 +210,23 @@ func StartService(cfg *config.Config, configPath string) {
|
||||
if ts != nil && ts.Expire != "" {
|
||||
if expTime, errParse := time.Parse(time.RFC3339, ts.Expire); errParse == nil {
|
||||
if time.Until(expTime) <= 4*time.Hour {
|
||||
log.Debugf("refreshing codex tokens for %s", claudeCli.GetEmail())
|
||||
log.Debugf("refreshing claude tokens for %s", claudeCli.GetEmail())
|
||||
_ = claudeCli.RefreshTokens(ctxRefresh)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if qwenCli, isQwenOK := cliClients[i].(*client.QwenClient); isQwenOK {
|
||||
if ts, isQwenTS := qwenCli.TokenStorage().(*qwen.QwenTokenStorage); isQwenTS {
|
||||
if ts != nil && ts.Expire != "" {
|
||||
if expTime, errParse := time.Parse(time.RFC3339, ts.Expire); errParse == nil {
|
||||
if time.Until(expTime) <= 3*time.Hour {
|
||||
log.Debugf("refreshing qwen tokens for %s", qwenCli.GetEmail())
|
||||
_ = qwenCli.RefreshTokens(ctxRefresh)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
253
internal/translator/openai/claude/openai_claude_request.go
Normal file
253
internal/translator/openai/claude/openai_claude_request.go
Normal file
@@ -0,0 +1,253 @@
|
||||
// Package claude provides request translation functionality for Anthropic to OpenAI API.
|
||||
// It handles parsing and transforming Anthropic API requests into OpenAI Chat Completions API format,
|
||||
// extracting model information, system instructions, message contents, and tool declarations.
|
||||
// The package performs JSON data transformation to ensure compatibility
|
||||
// between Anthropic API format and OpenAI API's expected format.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ConvertAnthropicRequestToOpenAI parses and transforms an Anthropic API request into OpenAI Chat Completions API format.
|
||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the OpenAI API.
|
||||
func ConvertAnthropicRequestToOpenAI(rawJSON []byte) string {
|
||||
// Base OpenAI Chat Completions API template
|
||||
out := `{"model":"","messages":[]}`
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
// Model mapping
|
||||
if model := root.Get("model"); model.Exists() {
|
||||
modelStr := model.String()
|
||||
out, _ = sjson.Set(out, "model", modelStr)
|
||||
}
|
||||
|
||||
// Max tokens
|
||||
if maxTokens := root.Get("max_tokens"); maxTokens.Exists() {
|
||||
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
|
||||
}
|
||||
|
||||
// Temperature
|
||||
if temp := root.Get("temperature"); temp.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", temp.Float())
|
||||
}
|
||||
|
||||
// Top P
|
||||
if topP := root.Get("top_p"); topP.Exists() {
|
||||
out, _ = sjson.Set(out, "top_p", topP.Float())
|
||||
}
|
||||
|
||||
// Stop sequences -> stop
|
||||
if stopSequences := root.Get("stop_sequences"); stopSequences.Exists() {
|
||||
if stopSequences.IsArray() {
|
||||
var stops []string
|
||||
stopSequences.ForEach(func(_, value gjson.Result) bool {
|
||||
stops = append(stops, value.String())
|
||||
return true
|
||||
})
|
||||
if len(stops) > 0 {
|
||||
if len(stops) == 1 {
|
||||
out, _ = sjson.Set(out, "stop", stops[0])
|
||||
} else {
|
||||
out, _ = sjson.Set(out, "stop", stops)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream
|
||||
if stream := root.Get("stream"); stream.Exists() {
|
||||
out, _ = sjson.Set(out, "stream", stream.Bool())
|
||||
}
|
||||
|
||||
// Process messages and system
|
||||
var openAIMessages []interface{}
|
||||
|
||||
// Handle system message first
|
||||
if system := root.Get("system"); system.Exists() && system.String() != "" {
|
||||
systemMsg := map[string]interface{}{
|
||||
"role": "system",
|
||||
"content": system.String(),
|
||||
}
|
||||
openAIMessages = append(openAIMessages, systemMsg)
|
||||
}
|
||||
|
||||
// Process Anthropic messages
|
||||
if messages := root.Get("messages"); messages.Exists() && messages.IsArray() {
|
||||
messages.ForEach(func(_, message gjson.Result) bool {
|
||||
role := message.Get("role").String()
|
||||
contentResult := message.Get("content")
|
||||
|
||||
msg := map[string]interface{}{
|
||||
"role": role,
|
||||
}
|
||||
|
||||
// Handle content
|
||||
if contentResult.Exists() && contentResult.IsArray() {
|
||||
var textParts []string
|
||||
var toolCalls []interface{}
|
||||
var toolResults []interface{}
|
||||
|
||||
contentResult.ForEach(func(_, part gjson.Result) bool {
|
||||
partType := part.Get("type").String()
|
||||
|
||||
switch partType {
|
||||
case "text":
|
||||
textParts = append(textParts, part.Get("text").String())
|
||||
|
||||
case "image":
|
||||
// Convert Anthropic image format to OpenAI format
|
||||
if source := part.Get("source"); source.Exists() {
|
||||
sourceType := source.Get("type").String()
|
||||
if sourceType == "base64" {
|
||||
mediaType := source.Get("media_type").String()
|
||||
data := source.Get("data").String()
|
||||
imageURL := "data:" + mediaType + ";base64," + data
|
||||
|
||||
// For now, add as text since OpenAI image handling is complex
|
||||
// In a real implementation, you'd need to handle this properly
|
||||
textParts = append(textParts, "[Image: "+imageURL+"]")
|
||||
}
|
||||
}
|
||||
|
||||
case "tool_use":
|
||||
// Convert to OpenAI tool call format
|
||||
toolCall := map[string]interface{}{
|
||||
"id": part.Get("id").String(),
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": part.Get("name").String(),
|
||||
},
|
||||
}
|
||||
|
||||
// Convert input to arguments JSON string
|
||||
if input := part.Get("input"); input.Exists() {
|
||||
if inputJSON, err := json.Marshal(input.Value()); err == nil {
|
||||
if function, ok := toolCall["function"].(map[string]interface{}); ok {
|
||||
function["arguments"] = string(inputJSON)
|
||||
}
|
||||
} else {
|
||||
if function, ok := toolCall["function"].(map[string]interface{}); ok {
|
||||
function["arguments"] = "{}"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if function, ok := toolCall["function"].(map[string]interface{}); ok {
|
||||
function["arguments"] = "{}"
|
||||
}
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
|
||||
case "tool_result":
|
||||
// Convert to OpenAI tool message format
|
||||
toolResult := map[string]interface{}{
|
||||
"role": "tool",
|
||||
"tool_call_id": part.Get("tool_use_id").String(),
|
||||
"content": part.Get("content").String(),
|
||||
}
|
||||
toolResults = append(toolResults, toolResult)
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Set content
|
||||
if len(textParts) > 0 {
|
||||
msg["content"] = strings.Join(textParts, "")
|
||||
} else {
|
||||
msg["content"] = ""
|
||||
}
|
||||
|
||||
// Set tool calls for assistant messages
|
||||
if role == "assistant" && len(toolCalls) > 0 {
|
||||
msg["tool_calls"] = toolCalls
|
||||
}
|
||||
|
||||
openAIMessages = append(openAIMessages, msg)
|
||||
|
||||
// Add tool result messages separately
|
||||
for _, toolResult := range toolResults {
|
||||
openAIMessages = append(openAIMessages, toolResult)
|
||||
}
|
||||
|
||||
} else if contentResult.Exists() && contentResult.Type == gjson.String {
|
||||
// Simple string content
|
||||
msg["content"] = contentResult.String()
|
||||
openAIMessages = append(openAIMessages, msg)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Set messages
|
||||
if len(openAIMessages) > 0 {
|
||||
messagesJSON, _ := json.Marshal(openAIMessages)
|
||||
out, _ = sjson.SetRaw(out, "messages", string(messagesJSON))
|
||||
}
|
||||
|
||||
// Process tools - convert Anthropic tools to OpenAI functions
|
||||
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
|
||||
var openAITools []interface{}
|
||||
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
openAITool := map[string]interface{}{
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": tool.Get("name").String(),
|
||||
"description": tool.Get("description").String(),
|
||||
},
|
||||
}
|
||||
|
||||
// Convert Anthropic input_schema to OpenAI function parameters
|
||||
if inputSchema := tool.Get("input_schema"); inputSchema.Exists() {
|
||||
if function, ok := openAITool["function"].(map[string]interface{}); ok {
|
||||
function["parameters"] = inputSchema.Value()
|
||||
}
|
||||
}
|
||||
|
||||
openAITools = append(openAITools, openAITool)
|
||||
return true
|
||||
})
|
||||
|
||||
if len(openAITools) > 0 {
|
||||
toolsJSON, _ := json.Marshal(openAITools)
|
||||
out, _ = sjson.SetRaw(out, "tools", string(toolsJSON))
|
||||
}
|
||||
}
|
||||
|
||||
// Tool choice mapping - convert Anthropic tool_choice to OpenAI format
|
||||
if toolChoice := root.Get("tool_choice"); toolChoice.Exists() {
|
||||
switch toolChoice.Get("type").String() {
|
||||
case "auto":
|
||||
out, _ = sjson.Set(out, "tool_choice", "auto")
|
||||
case "any":
|
||||
out, _ = sjson.Set(out, "tool_choice", "required")
|
||||
case "tool":
|
||||
// Specific tool choice
|
||||
toolName := toolChoice.Get("name").String()
|
||||
out, _ = sjson.Set(out, "tool_choice", map[string]interface{}{
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": toolName,
|
||||
},
|
||||
})
|
||||
default:
|
||||
// Default to auto if not specified
|
||||
out, _ = sjson.Set(out, "tool_choice", "auto")
|
||||
}
|
||||
}
|
||||
|
||||
// Handle user parameter (for tracking)
|
||||
if user := root.Get("user"); user.Exists() {
|
||||
out, _ = sjson.Set(out, "user", user.String())
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
389
internal/translator/openai/claude/openai_claude_response.go
Normal file
389
internal/translator/openai/claude/openai_claude_response.go
Normal file
@@ -0,0 +1,389 @@
|
||||
// Package claude provides response translation functionality for OpenAI to Anthropic API.
|
||||
// This package handles the conversion of OpenAI Chat Completions API responses into Anthropic API-compatible
|
||||
// JSON format, transforming streaming events and non-streaming responses into the format
|
||||
// expected by Anthropic API clients. It supports both streaming and non-streaming modes,
|
||||
// handling text content, tool calls, and usage metadata appropriately.
|
||||
package claude
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// ConvertOpenAIResponseToAnthropicParams holds parameters for response conversion
|
||||
type ConvertOpenAIResponseToAnthropicParams struct {
|
||||
MessageID string
|
||||
Model string
|
||||
CreatedAt int64
|
||||
// Content accumulator for streaming
|
||||
ContentAccumulator strings.Builder
|
||||
// Tool calls accumulator for streaming
|
||||
ToolCallsAccumulator map[int]*ToolCallAccumulator
|
||||
// Track if text content block has been started
|
||||
TextContentBlockStarted bool
|
||||
// Track finish reason for later use
|
||||
FinishReason string
|
||||
// Track if content blocks have been stopped
|
||||
ContentBlocksStopped bool
|
||||
// Track if message_delta has been sent
|
||||
MessageDeltaSent bool
|
||||
}
|
||||
|
||||
// ToolCallAccumulator holds the state for accumulating tool call data
|
||||
type ToolCallAccumulator struct {
|
||||
ID string
|
||||
Name string
|
||||
Arguments strings.Builder
|
||||
}
|
||||
|
||||
// ConvertOpenAIResponseToAnthropic converts OpenAI streaming response format to Anthropic API format.
|
||||
// This function processes OpenAI streaming chunks and transforms them into Anthropic-compatible JSON responses.
|
||||
// It handles text content, tool calls, and usage metadata, outputting responses that match the Anthropic API format.
|
||||
func ConvertOpenAIResponseToAnthropic(rawJSON []byte, param *ConvertOpenAIResponseToAnthropicParams) []string {
|
||||
// Check if this is the [DONE] marker
|
||||
rawStr := strings.TrimSpace(string(rawJSON))
|
||||
if rawStr == "[DONE]" {
|
||||
return convertOpenAIDoneToAnthropic(param)
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
// Check if this is a streaming chunk or non-streaming response
|
||||
objectType := root.Get("object").String()
|
||||
|
||||
if objectType == "chat.completion.chunk" {
|
||||
// Handle streaming response
|
||||
return convertOpenAIStreamingChunkToAnthropic(rawJSON, param)
|
||||
} else if objectType == "chat.completion" {
|
||||
// Handle non-streaming response
|
||||
return convertOpenAINonStreamingToAnthropic(rawJSON)
|
||||
}
|
||||
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// convertOpenAIStreamingChunkToAnthropic converts OpenAI streaming chunk to Anthropic streaming events
|
||||
func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAIResponseToAnthropicParams) []string {
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
var results []string
|
||||
|
||||
// Initialize parameters if needed
|
||||
if param.MessageID == "" {
|
||||
param.MessageID = root.Get("id").String()
|
||||
}
|
||||
if param.Model == "" {
|
||||
param.Model = root.Get("model").String()
|
||||
}
|
||||
if param.CreatedAt == 0 {
|
||||
param.CreatedAt = root.Get("created").Int()
|
||||
}
|
||||
|
||||
// Check if this is the first chunk (has role)
|
||||
if delta := root.Get("choices.0.delta"); delta.Exists() {
|
||||
if role := delta.Get("role"); role.Exists() && role.String() == "assistant" {
|
||||
// Send message_start event
|
||||
messageStart := map[string]interface{}{
|
||||
"type": "message_start",
|
||||
"message": map[string]interface{}{
|
||||
"id": param.MessageID,
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": param.Model,
|
||||
"content": []interface{}{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
},
|
||||
},
|
||||
}
|
||||
messageStartJSON, _ := json.Marshal(messageStart)
|
||||
results = append(results, "event: message_start\ndata: "+string(messageStartJSON)+"\n\n")
|
||||
|
||||
// Don't send content_block_start for text here - wait for actual content
|
||||
}
|
||||
|
||||
// Handle content delta
|
||||
if content := delta.Get("content"); content.Exists() && content.String() != "" {
|
||||
// Send content_block_start for text if not already sent
|
||||
if !param.TextContentBlockStarted {
|
||||
contentBlockStart := map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": 0,
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": "",
|
||||
},
|
||||
}
|
||||
contentBlockStartJSON, _ := json.Marshal(contentBlockStart)
|
||||
results = append(results, "event: content_block_start\ndata: "+string(contentBlockStartJSON)+"\n\n")
|
||||
param.TextContentBlockStarted = true
|
||||
}
|
||||
|
||||
contentDelta := map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": 0,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "text_delta",
|
||||
"text": content.String(),
|
||||
},
|
||||
}
|
||||
contentDeltaJSON, _ := json.Marshal(contentDelta)
|
||||
results = append(results, "event: content_block_delta\ndata: "+string(contentDeltaJSON)+"\n\n")
|
||||
|
||||
// Accumulate content
|
||||
param.ContentAccumulator.WriteString(content.String())
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
if toolCalls := delta.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
|
||||
if param.ToolCallsAccumulator == nil {
|
||||
param.ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
|
||||
}
|
||||
|
||||
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
||||
index := int(toolCall.Get("index").Int())
|
||||
|
||||
// Initialize accumulator if needed
|
||||
if _, exists := param.ToolCallsAccumulator[index]; !exists {
|
||||
param.ToolCallsAccumulator[index] = &ToolCallAccumulator{}
|
||||
}
|
||||
|
||||
accumulator := param.ToolCallsAccumulator[index]
|
||||
|
||||
// Handle tool call ID
|
||||
if id := toolCall.Get("id"); id.Exists() {
|
||||
accumulator.ID = id.String()
|
||||
}
|
||||
|
||||
// Handle function name
|
||||
if function := toolCall.Get("function"); function.Exists() {
|
||||
if name := function.Get("name"); name.Exists() {
|
||||
accumulator.Name = name.String()
|
||||
|
||||
// Send content_block_start for tool_use
|
||||
contentBlockStart := map[string]interface{}{
|
||||
"type": "content_block_start",
|
||||
"index": index + 1, // Offset by 1 since text is at index 0
|
||||
"content_block": map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": accumulator.ID,
|
||||
"name": accumulator.Name,
|
||||
"input": map[string]interface{}{},
|
||||
},
|
||||
}
|
||||
contentBlockStartJSON, _ := json.Marshal(contentBlockStart)
|
||||
results = append(results, "event: content_block_start\ndata: "+string(contentBlockStartJSON)+"\n\n")
|
||||
}
|
||||
|
||||
// Handle function arguments
|
||||
if args := function.Get("arguments"); args.Exists() {
|
||||
argsText := args.String()
|
||||
accumulator.Arguments.WriteString(argsText)
|
||||
|
||||
// Send input_json_delta
|
||||
inputDelta := map[string]interface{}{
|
||||
"type": "content_block_delta",
|
||||
"index": index + 1,
|
||||
"delta": map[string]interface{}{
|
||||
"type": "input_json_delta",
|
||||
"partial_json": argsText,
|
||||
},
|
||||
}
|
||||
inputDeltaJSON, _ := json.Marshal(inputDelta)
|
||||
results = append(results, "event: content_block_delta\ndata: "+string(inputDeltaJSON)+"\n\n")
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Handle finish_reason (but don't send message_delta/message_stop yet)
|
||||
if finishReason := root.Get("choices.0.finish_reason"); finishReason.Exists() && finishReason.String() != "" {
|
||||
reason := finishReason.String()
|
||||
param.FinishReason = reason
|
||||
|
||||
// Send content_block_stop for text if text content block was started
|
||||
if param.TextContentBlockStarted && !param.ContentBlocksStopped {
|
||||
contentBlockStop := map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": 0,
|
||||
}
|
||||
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
|
||||
results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
|
||||
}
|
||||
|
||||
// Send content_block_stop for any tool calls
|
||||
if !param.ContentBlocksStopped {
|
||||
for index := range param.ToolCallsAccumulator {
|
||||
contentBlockStop := map[string]interface{}{
|
||||
"type": "content_block_stop",
|
||||
"index": index + 1,
|
||||
}
|
||||
contentBlockStopJSON, _ := json.Marshal(contentBlockStop)
|
||||
results = append(results, "event: content_block_stop\ndata: "+string(contentBlockStopJSON)+"\n\n")
|
||||
}
|
||||
param.ContentBlocksStopped = true
|
||||
}
|
||||
|
||||
// Don't send message_delta here - wait for usage info or [DONE]
|
||||
}
|
||||
|
||||
// Handle usage information separately (this comes in a later chunk)
|
||||
// Only process if usage has actual values (not null)
|
||||
if usage := root.Get("usage"); usage.Exists() && usage.Type != gjson.Null && param.FinishReason != "" {
|
||||
// Check if usage has actual token counts
|
||||
promptTokens := usage.Get("prompt_tokens")
|
||||
completionTokens := usage.Get("completion_tokens")
|
||||
|
||||
if promptTokens.Exists() && completionTokens.Exists() {
|
||||
// Send message_delta with usage
|
||||
messageDelta := map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{
|
||||
"stop_reason": mapOpenAIFinishReasonToAnthropic(param.FinishReason),
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": promptTokens.Int(),
|
||||
"output_tokens": completionTokens.Int(),
|
||||
},
|
||||
}
|
||||
|
||||
messageDeltaJSON, _ := json.Marshal(messageDelta)
|
||||
results = append(results, "event: message_delta\ndata: "+string(messageDeltaJSON)+"\n\n")
|
||||
param.MessageDeltaSent = true
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// convertOpenAIDoneToAnthropic handles the [DONE] marker and sends final events
|
||||
func convertOpenAIDoneToAnthropic(param *ConvertOpenAIResponseToAnthropicParams) []string {
|
||||
var results []string
|
||||
|
||||
// If we haven't sent message_delta yet (no usage info was received), send it now
|
||||
if param.FinishReason != "" && !param.MessageDeltaSent {
|
||||
messageDelta := map[string]interface{}{
|
||||
"type": "message_delta",
|
||||
"delta": map[string]interface{}{
|
||||
"stop_reason": mapOpenAIFinishReasonToAnthropic(param.FinishReason),
|
||||
"stop_sequence": nil,
|
||||
},
|
||||
}
|
||||
|
||||
messageDeltaJSON, _ := json.Marshal(messageDelta)
|
||||
results = append(results, "event: message_delta\ndata: "+string(messageDeltaJSON)+"\n\n")
|
||||
param.MessageDeltaSent = true
|
||||
}
|
||||
|
||||
// Send message_stop
|
||||
results = append(results, "event: message_stop\ndata: {\"type\":\"message_stop\"}\n\n")
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
// convertOpenAINonStreamingToAnthropic converts OpenAI non-streaming response to Anthropic format
|
||||
func convertOpenAINonStreamingToAnthropic(rawJSON []byte) []string {
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
// Build Anthropic response
|
||||
response := map[string]interface{}{
|
||||
"id": root.Get("id").String(),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": root.Get("model").String(),
|
||||
"content": []interface{}{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
},
|
||||
}
|
||||
|
||||
// Process message content and tool calls
|
||||
var contentBlocks []interface{}
|
||||
|
||||
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
|
||||
choice := choices.Array()[0] // Take first choice
|
||||
|
||||
// Handle text content
|
||||
if content := choice.Get("message.content"); content.Exists() && content.String() != "" {
|
||||
textBlock := map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": content.String(),
|
||||
}
|
||||
contentBlocks = append(contentBlocks, textBlock)
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
if toolCalls := choice.Get("message.tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
|
||||
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
||||
toolUseBlock := map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": toolCall.Get("id").String(),
|
||||
"name": toolCall.Get("function.name").String(),
|
||||
}
|
||||
|
||||
// Parse arguments
|
||||
argsStr := toolCall.Get("function.arguments").String()
|
||||
if argsStr != "" {
|
||||
var args interface{}
|
||||
if err := json.Unmarshal([]byte(argsStr), &args); err == nil {
|
||||
toolUseBlock["input"] = args
|
||||
} else {
|
||||
toolUseBlock["input"] = map[string]interface{}{}
|
||||
}
|
||||
} else {
|
||||
toolUseBlock["input"] = map[string]interface{}{}
|
||||
}
|
||||
|
||||
contentBlocks = append(contentBlocks, toolUseBlock)
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Set stop reason
|
||||
if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
|
||||
response["stop_reason"] = mapOpenAIFinishReasonToAnthropic(finishReason.String())
|
||||
}
|
||||
}
|
||||
|
||||
response["content"] = contentBlocks
|
||||
|
||||
// Set usage information
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
response["usage"] = map[string]interface{}{
|
||||
"input_tokens": usage.Get("prompt_tokens").Int(),
|
||||
"output_tokens": usage.Get("completion_tokens").Int(),
|
||||
}
|
||||
}
|
||||
|
||||
responseJSON, _ := json.Marshal(response)
|
||||
return []string{string(responseJSON)}
|
||||
}
|
||||
|
||||
// mapOpenAIFinishReasonToAnthropic maps OpenAI finish reasons to Anthropic equivalents
|
||||
func mapOpenAIFinishReasonToAnthropic(openAIReason string) string {
|
||||
switch openAIReason {
|
||||
case "stop":
|
||||
return "end_turn"
|
||||
case "length":
|
||||
return "max_tokens"
|
||||
case "tool_calls":
|
||||
return "tool_use"
|
||||
case "content_filter":
|
||||
return "end_turn" // Anthropic doesn't have direct equivalent
|
||||
case "function_call": // Legacy OpenAI
|
||||
return "tool_use"
|
||||
default:
|
||||
return "end_turn"
|
||||
}
|
||||
}
|
||||
359
internal/translator/openai/gemini/openai_gemini_request.go
Normal file
359
internal/translator/openai/gemini/openai_gemini_request.go
Normal file
@@ -0,0 +1,359 @@
|
||||
// Package gemini provides request translation functionality for Gemini to OpenAI API.
|
||||
// It handles parsing and transforming Gemini API requests into OpenAI Chat Completions API format,
|
||||
// extracting model information, generation config, message contents, and tool declarations.
|
||||
// The package performs JSON data transformation to ensure compatibility
|
||||
// between Gemini API format and OpenAI API's expected format.
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/json"
|
||||
"math/big"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ConvertGeminiRequestToOpenAI parses and transforms a Gemini API request into OpenAI Chat Completions API format.
|
||||
// It extracts the model name, generation config, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the OpenAI API.
|
||||
func ConvertGeminiRequestToOpenAI(rawJSON []byte) string {
|
||||
// Base OpenAI Chat Completions API template
|
||||
out := `{"model":"","messages":[]}`
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
// Helper for generating tool call IDs in the form: call_<alphanum>
|
||||
genToolCallID := func() string {
|
||||
const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
var b strings.Builder
|
||||
// 24 chars random suffix
|
||||
for i := 0; i < 24; i++ {
|
||||
n, _ := rand.Int(rand.Reader, big.NewInt(int64(len(letters))))
|
||||
b.WriteByte(letters[n.Int64()])
|
||||
}
|
||||
return "call_" + b.String()
|
||||
}
|
||||
|
||||
// Model mapping
|
||||
if model := root.Get("model"); model.Exists() {
|
||||
modelStr := model.String()
|
||||
out, _ = sjson.Set(out, "model", modelStr)
|
||||
}
|
||||
|
||||
// Generation config mapping
|
||||
if genConfig := root.Get("generationConfig"); genConfig.Exists() {
|
||||
// Temperature
|
||||
if temp := genConfig.Get("temperature"); temp.Exists() {
|
||||
out, _ = sjson.Set(out, "temperature", temp.Float())
|
||||
}
|
||||
|
||||
// Max tokens
|
||||
if maxTokens := genConfig.Get("maxOutputTokens"); maxTokens.Exists() {
|
||||
out, _ = sjson.Set(out, "max_tokens", maxTokens.Int())
|
||||
}
|
||||
|
||||
// Top P
|
||||
if topP := genConfig.Get("topP"); topP.Exists() {
|
||||
out, _ = sjson.Set(out, "top_p", topP.Float())
|
||||
}
|
||||
|
||||
// Top K (OpenAI doesn't have direct equivalent, but we can map it)
|
||||
if topK := genConfig.Get("topK"); topK.Exists() {
|
||||
// Store as custom parameter for potential use
|
||||
out, _ = sjson.Set(out, "top_k", topK.Int())
|
||||
}
|
||||
|
||||
// Stop sequences
|
||||
if stopSequences := genConfig.Get("stopSequences"); stopSequences.Exists() && stopSequences.IsArray() {
|
||||
var stops []string
|
||||
stopSequences.ForEach(func(_, value gjson.Result) bool {
|
||||
stops = append(stops, value.String())
|
||||
return true
|
||||
})
|
||||
if len(stops) > 0 {
|
||||
out, _ = sjson.Set(out, "stop", stops)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream parameter
|
||||
if stream := root.Get("stream"); stream.Exists() {
|
||||
out, _ = sjson.Set(out, "stream", stream.Bool())
|
||||
}
|
||||
|
||||
// Process contents (Gemini messages) -> OpenAI messages
|
||||
var openAIMessages []interface{}
|
||||
var toolCallIDs []string // Track tool call IDs for matching with tool results
|
||||
|
||||
if contents := root.Get("contents"); contents.Exists() && contents.IsArray() {
|
||||
contents.ForEach(func(_, content gjson.Result) bool {
|
||||
role := content.Get("role").String()
|
||||
parts := content.Get("parts")
|
||||
|
||||
// Convert role: model -> assistant
|
||||
if role == "model" {
|
||||
role = "assistant"
|
||||
}
|
||||
|
||||
// Create OpenAI message
|
||||
msg := map[string]interface{}{
|
||||
"role": role,
|
||||
"content": "",
|
||||
}
|
||||
|
||||
var contentParts []string
|
||||
var toolCalls []interface{}
|
||||
|
||||
if parts.Exists() && parts.IsArray() {
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
// Handle text parts
|
||||
if text := part.Get("text"); text.Exists() {
|
||||
contentParts = append(contentParts, text.String())
|
||||
}
|
||||
|
||||
// Handle function calls (Gemini) -> tool calls (OpenAI)
|
||||
if functionCall := part.Get("functionCall"); functionCall.Exists() {
|
||||
toolCallID := genToolCallID()
|
||||
toolCallIDs = append(toolCallIDs, toolCallID)
|
||||
|
||||
toolCall := map[string]interface{}{
|
||||
"id": toolCallID,
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": functionCall.Get("name").String(),
|
||||
},
|
||||
}
|
||||
|
||||
// Convert args to arguments JSON string
|
||||
if args := functionCall.Get("args"); args.Exists() {
|
||||
argsJSON, _ := json.Marshal(args.Value())
|
||||
toolCall["function"].(map[string]interface{})["arguments"] = string(argsJSON)
|
||||
} else {
|
||||
toolCall["function"].(map[string]interface{})["arguments"] = "{}"
|
||||
}
|
||||
|
||||
toolCalls = append(toolCalls, toolCall)
|
||||
}
|
||||
|
||||
// Handle function responses (Gemini) -> tool role messages (OpenAI)
|
||||
if functionResponse := part.Get("functionResponse"); functionResponse.Exists() {
|
||||
// Create tool message for function response
|
||||
toolMsg := map[string]interface{}{
|
||||
"role": "tool",
|
||||
"tool_call_id": "", // Will be set based on context
|
||||
"content": "",
|
||||
}
|
||||
|
||||
// Convert response.content to JSON string
|
||||
if response := functionResponse.Get("response"); response.Exists() {
|
||||
if content = response.Get("content"); content.Exists() {
|
||||
// Use the content field from the response
|
||||
contentJSON, _ := json.Marshal(content.Value())
|
||||
toolMsg["content"] = string(contentJSON)
|
||||
} else {
|
||||
// Fallback to entire response
|
||||
responseJSON, _ := json.Marshal(response.Value())
|
||||
toolMsg["content"] = string(responseJSON)
|
||||
}
|
||||
}
|
||||
|
||||
// Try to match with previous tool call ID
|
||||
_ = functionResponse.Get("name").String() // functionName not used for now
|
||||
if len(toolCallIDs) > 0 {
|
||||
// Use the last tool call ID (simple matching by function name)
|
||||
// In a real implementation, you might want more sophisticated matching
|
||||
toolMsg["tool_call_id"] = toolCallIDs[len(toolCallIDs)-1]
|
||||
} else {
|
||||
// Generate a tool call ID if none available
|
||||
toolMsg["tool_call_id"] = genToolCallID()
|
||||
}
|
||||
|
||||
openAIMessages = append(openAIMessages, toolMsg)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Set content
|
||||
if len(contentParts) > 0 {
|
||||
msg["content"] = strings.Join(contentParts, "")
|
||||
}
|
||||
|
||||
// Set tool calls if any
|
||||
if len(toolCalls) > 0 {
|
||||
msg["tool_calls"] = toolCalls
|
||||
}
|
||||
|
||||
openAIMessages = append(openAIMessages, msg)
|
||||
|
||||
// switch role {
|
||||
// case "user", "model":
|
||||
// // Convert role: model -> assistant
|
||||
// if role == "model" {
|
||||
// role = "assistant"
|
||||
// }
|
||||
//
|
||||
// // Create OpenAI message
|
||||
// msg := map[string]interface{}{
|
||||
// "role": role,
|
||||
// "content": "",
|
||||
// }
|
||||
//
|
||||
// var contentParts []string
|
||||
// var toolCalls []interface{}
|
||||
//
|
||||
// if parts.Exists() && parts.IsArray() {
|
||||
// parts.ForEach(func(_, part gjson.Result) bool {
|
||||
// // Handle text parts
|
||||
// if text := part.Get("text"); text.Exists() {
|
||||
// contentParts = append(contentParts, text.String())
|
||||
// }
|
||||
//
|
||||
// // Handle function calls (Gemini) -> tool calls (OpenAI)
|
||||
// if functionCall := part.Get("functionCall"); functionCall.Exists() {
|
||||
// toolCallID := genToolCallID()
|
||||
// toolCallIDs = append(toolCallIDs, toolCallID)
|
||||
//
|
||||
// toolCall := map[string]interface{}{
|
||||
// "id": toolCallID,
|
||||
// "type": "function",
|
||||
// "function": map[string]interface{}{
|
||||
// "name": functionCall.Get("name").String(),
|
||||
// },
|
||||
// }
|
||||
//
|
||||
// // Convert args to arguments JSON string
|
||||
// if args := functionCall.Get("args"); args.Exists() {
|
||||
// argsJSON, _ := json.Marshal(args.Value())
|
||||
// toolCall["function"].(map[string]interface{})["arguments"] = string(argsJSON)
|
||||
// } else {
|
||||
// toolCall["function"].(map[string]interface{})["arguments"] = "{}"
|
||||
// }
|
||||
//
|
||||
// toolCalls = append(toolCalls, toolCall)
|
||||
// }
|
||||
//
|
||||
// return true
|
||||
// })
|
||||
// }
|
||||
//
|
||||
// // Set content
|
||||
// if len(contentParts) > 0 {
|
||||
// msg["content"] = strings.Join(contentParts, "")
|
||||
// }
|
||||
//
|
||||
// // Set tool calls if any
|
||||
// if len(toolCalls) > 0 {
|
||||
// msg["tool_calls"] = toolCalls
|
||||
// }
|
||||
//
|
||||
// openAIMessages = append(openAIMessages, msg)
|
||||
//
|
||||
// case "function":
|
||||
// // Handle Gemini function role -> OpenAI tool role
|
||||
// if parts.Exists() && parts.IsArray() {
|
||||
// parts.ForEach(func(_, part gjson.Result) bool {
|
||||
// // Handle function responses (Gemini) -> tool role messages (OpenAI)
|
||||
// if functionResponse := part.Get("functionResponse"); functionResponse.Exists() {
|
||||
// // Create tool message for function response
|
||||
// toolMsg := map[string]interface{}{
|
||||
// "role": "tool",
|
||||
// "tool_call_id": "", // Will be set based on context
|
||||
// "content": "",
|
||||
// }
|
||||
//
|
||||
// // Convert response.content to JSON string
|
||||
// if response := functionResponse.Get("response"); response.Exists() {
|
||||
// if content = response.Get("content"); content.Exists() {
|
||||
// // Use the content field from the response
|
||||
// contentJSON, _ := json.Marshal(content.Value())
|
||||
// toolMsg["content"] = string(contentJSON)
|
||||
// } else {
|
||||
// // Fallback to entire response
|
||||
// responseJSON, _ := json.Marshal(response.Value())
|
||||
// toolMsg["content"] = string(responseJSON)
|
||||
// }
|
||||
// }
|
||||
//
|
||||
// // Try to match with previous tool call ID
|
||||
// _ = functionResponse.Get("name").String() // functionName not used for now
|
||||
// if len(toolCallIDs) > 0 {
|
||||
// // Use the last tool call ID (simple matching by function name)
|
||||
// // In a real implementation, you might want more sophisticated matching
|
||||
// toolMsg["tool_call_id"] = toolCallIDs[len(toolCallIDs)-1]
|
||||
// } else {
|
||||
// // Generate a tool call ID if none available
|
||||
// toolMsg["tool_call_id"] = genToolCallID()
|
||||
// }
|
||||
//
|
||||
// openAIMessages = append(openAIMessages, toolMsg)
|
||||
// }
|
||||
//
|
||||
// return true
|
||||
// })
|
||||
// }
|
||||
// }
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Set messages
|
||||
if len(openAIMessages) > 0 {
|
||||
messagesJSON, _ := json.Marshal(openAIMessages)
|
||||
out, _ = sjson.SetRaw(out, "messages", string(messagesJSON))
|
||||
}
|
||||
|
||||
// Tools mapping: Gemini tools -> OpenAI tools
|
||||
if tools := root.Get("tools"); tools.Exists() && tools.IsArray() {
|
||||
var openAITools []interface{}
|
||||
tools.ForEach(func(_, tool gjson.Result) bool {
|
||||
if functionDeclarations := tool.Get("functionDeclarations"); functionDeclarations.Exists() && functionDeclarations.IsArray() {
|
||||
functionDeclarations.ForEach(func(_, funcDecl gjson.Result) bool {
|
||||
openAITool := map[string]interface{}{
|
||||
"type": "function",
|
||||
"function": map[string]interface{}{
|
||||
"name": funcDecl.Get("name").String(),
|
||||
"description": funcDecl.Get("description").String(),
|
||||
},
|
||||
}
|
||||
|
||||
// Convert parameters schema
|
||||
if parameters := funcDecl.Get("parameters"); parameters.Exists() {
|
||||
openAITool["function"].(map[string]interface{})["parameters"] = parameters.Value()
|
||||
} else if parameters = funcDecl.Get("parametersJsonSchema"); parameters.Exists() {
|
||||
openAITool["function"].(map[string]interface{})["parameters"] = parameters.Value()
|
||||
}
|
||||
|
||||
openAITools = append(openAITools, openAITool)
|
||||
return true
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
if len(openAITools) > 0 {
|
||||
toolsJSON, _ := json.Marshal(openAITools)
|
||||
out, _ = sjson.SetRaw(out, "tools", string(toolsJSON))
|
||||
}
|
||||
}
|
||||
|
||||
// Tool choice mapping (Gemini doesn't have direct equivalent, but we can handle it)
|
||||
if toolConfig := root.Get("toolConfig"); toolConfig.Exists() {
|
||||
if functionCallingConfig := toolConfig.Get("functionCallingConfig"); functionCallingConfig.Exists() {
|
||||
mode := functionCallingConfig.Get("mode").String()
|
||||
switch mode {
|
||||
case "NONE":
|
||||
out, _ = sjson.Set(out, "tool_choice", "none")
|
||||
case "AUTO":
|
||||
out, _ = sjson.Set(out, "tool_choice", "auto")
|
||||
case "ANY":
|
||||
out, _ = sjson.Set(out, "tool_choice", "required")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
353
internal/translator/openai/gemini/openai_gemini_response.go
Normal file
353
internal/translator/openai/gemini/openai_gemini_response.go
Normal file
@@ -0,0 +1,353 @@
|
||||
// Package gemini provides response translation functionality for OpenAI to Gemini API.
|
||||
// This package handles the conversion of OpenAI Chat Completions API responses into Gemini API-compatible
|
||||
// JSON format, transforming streaming events and non-streaming responses into the format
|
||||
// expected by Gemini API clients. It supports both streaming and non-streaming modes,
|
||||
// handling text content, tool calls, and usage metadata appropriately.
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ConvertOpenAIResponseToGeminiParams holds parameters for response conversion
|
||||
type ConvertOpenAIResponseToGeminiParams struct {
|
||||
// Tool calls accumulator for streaming
|
||||
ToolCallsAccumulator map[int]*ToolCallAccumulator
|
||||
// Content accumulator for streaming
|
||||
ContentAccumulator strings.Builder
|
||||
// Track if this is the first chunk
|
||||
IsFirstChunk bool
|
||||
}
|
||||
|
||||
// ToolCallAccumulator holds the state for accumulating tool call data
|
||||
type ToolCallAccumulator struct {
|
||||
ID string
|
||||
Name string
|
||||
Arguments strings.Builder
|
||||
}
|
||||
|
||||
// ConvertOpenAIResponseToGemini converts OpenAI Chat Completions streaming response format to Gemini API format.
|
||||
// This function processes OpenAI streaming chunks and transforms them into Gemini-compatible JSON responses.
|
||||
// It handles text content, tool calls, and usage metadata, outputting responses that match the Gemini API format.
|
||||
func ConvertOpenAIResponseToGemini(rawJSON []byte, param *ConvertOpenAIResponseToGeminiParams) []string {
|
||||
// Handle [DONE] marker
|
||||
if strings.TrimSpace(string(rawJSON)) == "[DONE]" {
|
||||
return []string{}
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
// Initialize accumulators if needed
|
||||
if param.ToolCallsAccumulator == nil {
|
||||
param.ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
|
||||
}
|
||||
|
||||
// Process choices
|
||||
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
|
||||
// Handle empty choices array (usage-only chunk)
|
||||
if len(choices.Array()) == 0 {
|
||||
// This is a usage-only chunk, handle usage and return
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
template := `{"candidates":[],"usageMetadata":{}}`
|
||||
|
||||
// Set model if available
|
||||
if model := root.Get("model"); model.Exists() {
|
||||
template, _ = sjson.Set(template, "model", model.String())
|
||||
}
|
||||
|
||||
usageObj := map[string]interface{}{
|
||||
"promptTokenCount": usage.Get("prompt_tokens").Int(),
|
||||
"candidatesTokenCount": usage.Get("completion_tokens").Int(),
|
||||
"totalTokenCount": usage.Get("total_tokens").Int(),
|
||||
}
|
||||
template, _ = sjson.Set(template, "usageMetadata", usageObj)
|
||||
return []string{template}
|
||||
}
|
||||
return []string{}
|
||||
}
|
||||
|
||||
var results []string
|
||||
|
||||
choices.ForEach(func(choiceIndex, choice gjson.Result) bool {
|
||||
// Base Gemini response template
|
||||
template := `{"candidates":[{"content":{"parts":[],"role":"model"},"finishReason":"STOP","index":0}]}`
|
||||
|
||||
// Set model if available
|
||||
if model := root.Get("model"); model.Exists() {
|
||||
template, _ = sjson.Set(template, "model", model.String())
|
||||
}
|
||||
|
||||
_ = int(choice.Get("index").Int()) // choiceIdx not used in streaming
|
||||
delta := choice.Get("delta")
|
||||
|
||||
// Handle role (only in first chunk)
|
||||
if role := delta.Get("role"); role.Exists() && param.IsFirstChunk {
|
||||
// OpenAI assistant -> Gemini model
|
||||
if role.String() == "assistant" {
|
||||
template, _ = sjson.Set(template, "candidates.0.content.role", "model")
|
||||
}
|
||||
param.IsFirstChunk = false
|
||||
results = append(results, template)
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle content delta
|
||||
if content := delta.Get("content"); content.Exists() && content.String() != "" {
|
||||
contentText := content.String()
|
||||
param.ContentAccumulator.WriteString(contentText)
|
||||
|
||||
// Create text part for this delta
|
||||
parts := []interface{}{
|
||||
map[string]interface{}{
|
||||
"text": contentText,
|
||||
},
|
||||
}
|
||||
template, _ = sjson.Set(template, "candidates.0.content.parts", parts)
|
||||
results = append(results, template)
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle tool calls delta
|
||||
if toolCalls := delta.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
|
||||
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
||||
toolIndex := int(toolCall.Get("index").Int())
|
||||
toolID := toolCall.Get("id").String()
|
||||
toolType := toolCall.Get("type").String()
|
||||
|
||||
if toolType == "function" {
|
||||
function := toolCall.Get("function")
|
||||
functionName := function.Get("name").String()
|
||||
functionArgs := function.Get("arguments").String()
|
||||
|
||||
// Initialize accumulator if needed
|
||||
if _, exists := param.ToolCallsAccumulator[toolIndex]; !exists {
|
||||
param.ToolCallsAccumulator[toolIndex] = &ToolCallAccumulator{
|
||||
ID: toolID,
|
||||
Name: functionName,
|
||||
}
|
||||
}
|
||||
|
||||
// Update ID if provided
|
||||
if toolID != "" {
|
||||
param.ToolCallsAccumulator[toolIndex].ID = toolID
|
||||
}
|
||||
|
||||
// Update name if provided
|
||||
if functionName != "" {
|
||||
param.ToolCallsAccumulator[toolIndex].Name = functionName
|
||||
}
|
||||
|
||||
// Accumulate arguments
|
||||
if functionArgs != "" {
|
||||
param.ToolCallsAccumulator[toolIndex].Arguments.WriteString(functionArgs)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Don't output anything for tool call deltas - wait for completion
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle finish reason
|
||||
if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
|
||||
geminiFinishReason := mapOpenAIFinishReasonToGemini(finishReason.String())
|
||||
template, _ = sjson.Set(template, "candidates.0.finishReason", geminiFinishReason)
|
||||
|
||||
// If we have accumulated tool calls, output them now
|
||||
if len(param.ToolCallsAccumulator) > 0 {
|
||||
var parts []interface{}
|
||||
for _, accumulator := range param.ToolCallsAccumulator {
|
||||
argsStr := accumulator.Arguments.String()
|
||||
var argsMap map[string]interface{}
|
||||
|
||||
if argsStr != "" && argsStr != "{}" {
|
||||
// Handle malformed JSON by trying to fix common issues
|
||||
fixedArgs := argsStr
|
||||
// Fix unquoted keys and values (common in the sample)
|
||||
if strings.Contains(fixedArgs, "北京") && !strings.Contains(fixedArgs, "\"北京\"") {
|
||||
fixedArgs = strings.ReplaceAll(fixedArgs, "北京", "\"北京\"")
|
||||
}
|
||||
if strings.Contains(fixedArgs, "celsius") && !strings.Contains(fixedArgs, "\"celsius\"") {
|
||||
fixedArgs = strings.ReplaceAll(fixedArgs, "celsius", "\"celsius\"")
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(fixedArgs), &argsMap); err != nil {
|
||||
// If still fails, try to parse as raw string
|
||||
if err2 := json.Unmarshal([]byte("\""+argsStr+"\""), &argsMap); err2 != nil {
|
||||
// Last resort: use empty object
|
||||
argsMap = map[string]interface{}{}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
argsMap = map[string]interface{}{}
|
||||
}
|
||||
|
||||
functionCallPart := map[string]interface{}{
|
||||
"functionCall": map[string]interface{}{
|
||||
"name": accumulator.Name,
|
||||
"args": argsMap,
|
||||
},
|
||||
}
|
||||
parts = append(parts, functionCallPart)
|
||||
}
|
||||
|
||||
if len(parts) > 0 {
|
||||
template, _ = sjson.Set(template, "candidates.0.content.parts", parts)
|
||||
}
|
||||
|
||||
// Clear accumulators
|
||||
param.ToolCallsAccumulator = make(map[int]*ToolCallAccumulator)
|
||||
}
|
||||
|
||||
results = append(results, template)
|
||||
return true
|
||||
}
|
||||
|
||||
// Handle usage information
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
usageObj := map[string]interface{}{
|
||||
"promptTokenCount": usage.Get("prompt_tokens").Int(),
|
||||
"candidatesTokenCount": usage.Get("completion_tokens").Int(),
|
||||
"totalTokenCount": usage.Get("total_tokens").Int(),
|
||||
}
|
||||
template, _ = sjson.Set(template, "usageMetadata", usageObj)
|
||||
results = append(results, template)
|
||||
return true
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
return results
|
||||
}
|
||||
return []string{}
|
||||
}
|
||||
|
||||
// mapOpenAIFinishReasonToGemini maps OpenAI finish reasons to Gemini finish reasons
|
||||
func mapOpenAIFinishReasonToGemini(openAIReason string) string {
|
||||
switch openAIReason {
|
||||
case "stop":
|
||||
return "STOP"
|
||||
case "length":
|
||||
return "MAX_TOKENS"
|
||||
case "tool_calls":
|
||||
return "STOP" // Gemini doesn't have a specific tool_calls finish reason
|
||||
case "content_filter":
|
||||
return "SAFETY"
|
||||
default:
|
||||
return "STOP"
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertOpenAINonStreamResponseToGemini converts OpenAI non-streaming response to Gemini format
|
||||
func ConvertOpenAINonStreamResponseToGemini(rawJSON []byte) string {
|
||||
root := gjson.ParseBytes(rawJSON)
|
||||
|
||||
// Base Gemini response template
|
||||
out := `{"candidates":[{"content":{"parts":[],"role":"model"},"finishReason":"STOP","index":0}]}`
|
||||
|
||||
// Set model if available
|
||||
if model := root.Get("model"); model.Exists() {
|
||||
out, _ = sjson.Set(out, "model", model.String())
|
||||
}
|
||||
|
||||
// Process choices
|
||||
if choices := root.Get("choices"); choices.Exists() && choices.IsArray() {
|
||||
choices.ForEach(func(choiceIndex, choice gjson.Result) bool {
|
||||
choiceIdx := int(choice.Get("index").Int())
|
||||
message := choice.Get("message")
|
||||
|
||||
// Set role
|
||||
if role := message.Get("role"); role.Exists() {
|
||||
if role.String() == "assistant" {
|
||||
out, _ = sjson.Set(out, "candidates.0.content.role", "model")
|
||||
}
|
||||
}
|
||||
|
||||
var parts []interface{}
|
||||
|
||||
// Handle content first
|
||||
if content := message.Get("content"); content.Exists() && content.String() != "" {
|
||||
parts = append(parts, map[string]interface{}{
|
||||
"text": content.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// Handle tool calls
|
||||
if toolCalls := message.Get("tool_calls"); toolCalls.Exists() && toolCalls.IsArray() {
|
||||
toolCalls.ForEach(func(_, toolCall gjson.Result) bool {
|
||||
if toolCall.Get("type").String() == "function" {
|
||||
function := toolCall.Get("function")
|
||||
functionName := function.Get("name").String()
|
||||
functionArgs := function.Get("arguments").String()
|
||||
|
||||
// Parse arguments
|
||||
var argsMap map[string]interface{}
|
||||
if functionArgs != "" && functionArgs != "{}" {
|
||||
// Handle malformed JSON by trying to fix common issues
|
||||
fixedArgs := functionArgs
|
||||
// Fix unquoted keys and values (common in the sample)
|
||||
if strings.Contains(fixedArgs, "北京") && !strings.Contains(fixedArgs, "\"北京\"") {
|
||||
fixedArgs = strings.ReplaceAll(fixedArgs, "北京", "\"北京\"")
|
||||
}
|
||||
if strings.Contains(fixedArgs, "celsius") && !strings.Contains(fixedArgs, "\"celsius\"") {
|
||||
fixedArgs = strings.ReplaceAll(fixedArgs, "celsius", "\"celsius\"")
|
||||
}
|
||||
|
||||
if err := json.Unmarshal([]byte(fixedArgs), &argsMap); err != nil {
|
||||
// If still fails, try to parse as raw string
|
||||
if err2 := json.Unmarshal([]byte("\""+functionArgs+"\""), &argsMap); err2 != nil {
|
||||
// Last resort: use empty object
|
||||
argsMap = map[string]interface{}{}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
argsMap = map[string]interface{}{}
|
||||
}
|
||||
|
||||
functionCallPart := map[string]interface{}{
|
||||
"functionCall": map[string]interface{}{
|
||||
"name": functionName,
|
||||
"args": argsMap,
|
||||
},
|
||||
}
|
||||
parts = append(parts, functionCallPart)
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Set parts
|
||||
if len(parts) > 0 {
|
||||
out, _ = sjson.Set(out, "candidates.0.content.parts", parts)
|
||||
}
|
||||
|
||||
// Handle finish reason
|
||||
if finishReason := choice.Get("finish_reason"); finishReason.Exists() {
|
||||
geminiFinishReason := mapOpenAIFinishReasonToGemini(finishReason.String())
|
||||
out, _ = sjson.Set(out, "candidates.0.finishReason", geminiFinishReason)
|
||||
}
|
||||
|
||||
// Set index
|
||||
out, _ = sjson.Set(out, "candidates.0.index", choiceIdx)
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
// Handle usage information
|
||||
if usage := root.Get("usage"); usage.Exists() {
|
||||
usageObj := map[string]interface{}{
|
||||
"promptTokenCount": usage.Get("prompt_tokens").Int(),
|
||||
"candidatesTokenCount": usage.Get("completion_tokens").Int(),
|
||||
"totalTokenCount": usage.Get("total_tokens").Int(),
|
||||
}
|
||||
out, _ = sjson.Set(out, "usageMetadata", usageObj)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
@@ -23,6 +23,8 @@ func GetProviderName(modelName string) string {
|
||||
return "gpt"
|
||||
} else if strings.HasPrefix(modelName, "claude") {
|
||||
return "claude"
|
||||
} else if strings.HasPrefix(modelName, "qwen") {
|
||||
return "qwen"
|
||||
}
|
||||
return "unknow"
|
||||
}
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth/claude"
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth/codex"
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth/gemini"
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth/qwen"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
"github.com/luispater/CLIProxyAPI/internal/util"
|
||||
@@ -281,6 +282,20 @@ func (w *Watcher) reloadClients() {
|
||||
} else {
|
||||
log.Errorf(" failed to decode token file %s: %v", path, err)
|
||||
}
|
||||
} else if tokenType == "qwen" {
|
||||
var ts qwen.QwenTokenStorage
|
||||
if err = json.Unmarshal(data, &ts); err == nil {
|
||||
// For each valid token, create an authenticated client
|
||||
log.Debugf(" initializing qwen authentication for token from %s...", filepath.Base(path))
|
||||
qwenClient := client.NewQwenClient(cfg, &ts)
|
||||
log.Debugf(" authentication successful for token from %s", filepath.Base(path))
|
||||
|
||||
// Add the new client to the pool
|
||||
newClients = append(newClients, qwenClient)
|
||||
successfulAuthCount++
|
||||
} else {
|
||||
log.Errorf(" failed to decode token file %s: %v", path, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user