mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
Numerous Comments Added and Extensive Optimization Performed using Roo-Code with CLIProxyAPI itself.
This commit is contained in:
@@ -2,14 +2,12 @@ package api
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -21,13 +19,15 @@ var (
|
||||
lastUsedClientIndex = 0
|
||||
)
|
||||
|
||||
// APIHandlers contains the handlers for API endpoints
|
||||
// APIHandlers contains the handlers for API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service.
|
||||
type APIHandlers struct {
|
||||
cliClients []*client.Client
|
||||
debug bool
|
||||
}
|
||||
|
||||
// NewAPIHandlers creates a new API handlers instance
|
||||
// NewAPIHandlers creates a new API handlers instance.
|
||||
// It takes a slice of clients and a debug flag as input.
|
||||
func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers {
|
||||
return &APIHandlers{
|
||||
cliClients: cliClients,
|
||||
@@ -35,6 +35,8 @@ func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers {
|
||||
}
|
||||
}
|
||||
|
||||
// Models handles the /v1/models endpoint.
|
||||
// It returns a hardcoded list of available AI models.
|
||||
func (h *APIHandlers) Models(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"data": []map[string]any{
|
||||
@@ -162,15 +164,23 @@ func (h *APIHandlers) Models(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// ChatCompletions handles the /v1/chat/completions endpoint
|
||||
// ChatCompletions handles the /v1/chat/completions endpoint.
|
||||
// It determines whether the request is for a streaming or non-streaming response
|
||||
// and calls the appropriate handler.
|
||||
func (h *APIHandlers) ChatCompletions(c *gin.Context) {
|
||||
rawJson, err := c.GetRawData()
|
||||
// If data retrieval fails, return 400 error
|
||||
// If data retrieval fails, return a 400 Bad Request error.
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request: %v", err), "code": 400})
|
||||
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||
Type: "invalid_request_error",
|
||||
},
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Check if the client requested a streaming response.
|
||||
streamResult := gjson.GetBytes(rawJson, "stream")
|
||||
if streamResult.Type == gjson.True {
|
||||
h.handleStreamingResponse(c, rawJson)
|
||||
@@ -179,184 +189,9 @@ func (h *APIHandlers) ChatCompletions(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *APIHandlers) prepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) {
|
||||
// log.Debug(string(rawJson))
|
||||
modelName := "gemini-2.5-pro"
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
if modelResult.Type == gjson.String {
|
||||
modelName = modelResult.String()
|
||||
}
|
||||
|
||||
contents := make([]client.Content, 0)
|
||||
messagesResult := gjson.GetBytes(rawJson, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messagesResults := messagesResult.Array()
|
||||
for i := 0; i < len(messagesResults); i++ {
|
||||
messageResult := messagesResults[i]
|
||||
roleResult := messageResult.Get("role")
|
||||
contentResult := messageResult.Get("content")
|
||||
if roleResult.Type == gjson.String {
|
||||
if roleResult.String() == "system" {
|
||||
if contentResult.Type == gjson.String {
|
||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||
} else if contentResult.IsObject() {
|
||||
contentTypeResult := contentResult.Get("type")
|
||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
contentTextResult := contentResult.Get("text")
|
||||
if contentTextResult.Type == gjson.String {
|
||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
|
||||
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if roleResult.String() == "user" {
|
||||
if contentResult.Type == gjson.String {
|
||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||
} else if contentResult.IsObject() {
|
||||
contentTypeResult := contentResult.Get("type")
|
||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
contentTextResult := contentResult.Get("text")
|
||||
if contentTextResult.Type == gjson.String {
|
||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
|
||||
}
|
||||
}
|
||||
} else if contentResult.IsArray() {
|
||||
contentItemResults := contentResult.Array()
|
||||
parts := make([]client.Part, 0)
|
||||
for j := 0; j < len(contentItemResults); j++ {
|
||||
contentItemResult := contentItemResults[j]
|
||||
contentTypeResult := contentItemResult.Get("type")
|
||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
contentTextResult := contentItemResult.Get("text")
|
||||
if contentTextResult.Type == gjson.String {
|
||||
parts = append(parts, client.Part{Text: contentTextResult.String()})
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image_url" {
|
||||
imageURLResult := contentItemResult.Get("image_url.url")
|
||||
if imageURLResult.Type == gjson.String {
|
||||
imageURL := imageURLResult.String()
|
||||
if len(imageURL) > 5 {
|
||||
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
|
||||
if len(imageURLs) == 2 {
|
||||
if len(imageURLs[1]) > 7 {
|
||||
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
||||
MimeType: imageURLs[0],
|
||||
Data: imageURLs[1][7:],
|
||||
}})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "file" {
|
||||
filenameResult := contentItemResult.Get("file.filename")
|
||||
fileDataResult := contentItemResult.Get("file.file_data")
|
||||
if filenameResult.Type == gjson.String && fileDataResult.Type == gjson.String {
|
||||
filename := filenameResult.String()
|
||||
splitFilename := strings.Split(filename, ".")
|
||||
ext := splitFilename[len(splitFilename)-1]
|
||||
|
||||
mimeType, ok := MimeTypes[ext]
|
||||
if !ok {
|
||||
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
|
||||
continue
|
||||
}
|
||||
|
||||
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
||||
MimeType: mimeType,
|
||||
Data: fileDataResult.String(),
|
||||
}})
|
||||
}
|
||||
}
|
||||
}
|
||||
contents = append(contents, client.Content{Role: "user", Parts: parts})
|
||||
}
|
||||
} else if roleResult.String() == "assistant" {
|
||||
if contentResult.Type == gjson.String {
|
||||
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||
} else if contentResult.IsObject() {
|
||||
contentTypeResult := contentResult.Get("type")
|
||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
contentTextResult := contentResult.Get("text")
|
||||
if contentTextResult.Type == gjson.String {
|
||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
|
||||
}
|
||||
}
|
||||
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
|
||||
toolCallsResult := messageResult.Get("tool_calls")
|
||||
if toolCallsResult.IsArray() {
|
||||
tcsResult := toolCallsResult.Array()
|
||||
for j := 0; j < len(tcsResult); j++ {
|
||||
tcResult := tcsResult[j]
|
||||
functionNameResult := tcResult.Get("function.name")
|
||||
functionArguments := tcResult.Get("function.arguments")
|
||||
if functionNameResult.Exists() && functionNameResult.Type == gjson.String && functionArguments.Exists() && functionArguments.Type == gjson.String {
|
||||
var args map[string]any
|
||||
err := json.Unmarshal([]byte(functionArguments.String()), &args)
|
||||
if err == nil {
|
||||
contents = append(contents, client.Content{
|
||||
Role: "model", Parts: []client.Part{
|
||||
{
|
||||
FunctionCall: &client.FunctionCall{
|
||||
Name: functionNameResult.String(),
|
||||
Args: args,
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else if roleResult.String() == "tool" {
|
||||
toolCallIDResult := messageResult.Get("tool_call_id")
|
||||
if toolCallIDResult.Exists() && toolCallIDResult.Type == gjson.String {
|
||||
if contentResult.Type == gjson.String {
|
||||
functionResponse := client.FunctionResponse{Name: toolCallIDResult.String(), Response: map[string]interface{}{"result": contentResult.String()}}
|
||||
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
|
||||
} else if contentResult.IsObject() {
|
||||
contentTypeResult := contentResult.Get("type")
|
||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
contentTextResult := contentResult.Get("text")
|
||||
if contentTextResult.Type == gjson.String {
|
||||
functionResponse := client.FunctionResponse{Name: toolCallIDResult.String(), Response: map[string]interface{}{"result": contentResult.String()}}
|
||||
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var tools []client.ToolDeclaration
|
||||
toolsResult := gjson.GetBytes(rawJson, "tools")
|
||||
if toolsResult.IsArray() {
|
||||
tools = make([]client.ToolDeclaration, 1)
|
||||
tools[0].FunctionDeclarations = make([]any, 0)
|
||||
toolsResults := toolsResult.Array()
|
||||
for i := 0; i < len(toolsResults); i++ {
|
||||
toolTypeResult := toolsResults[i].Get("type")
|
||||
if toolTypeResult.Type != gjson.String || toolTypeResult.String() != "function" {
|
||||
continue
|
||||
}
|
||||
functionTypeResult := toolsResults[i].Get("function")
|
||||
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
|
||||
var functionDeclaration any
|
||||
err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration)
|
||||
if err == nil {
|
||||
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tools = make([]client.ToolDeclaration, 0)
|
||||
}
|
||||
return modelName, contents, tools
|
||||
}
|
||||
|
||||
// handleNonStreamingResponse handles non-streaming responses
|
||||
// handleNonStreamingResponse handles non-streaming chat completion responses.
|
||||
// It selects a client from the pool, sends the request, and aggregates the response
|
||||
// before sending it back to the client.
|
||||
func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
@@ -372,7 +207,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
|
||||
return
|
||||
}
|
||||
|
||||
modelName, contents, tools := h.prepareRequest(rawJson)
|
||||
modelName, contents, tools := translator.PrepareRequest(rawJson)
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
@@ -425,19 +260,13 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
|
||||
cliCancel()
|
||||
return
|
||||
} else {
|
||||
jsonTemplate = h.convertCliToOpenAINonStream(jsonTemplate, chunk)
|
||||
jsonTemplate = translator.ConvertCliToOpenAINonStream(jsonTemplate, chunk)
|
||||
}
|
||||
case err, okError := <-errChan:
|
||||
if okError {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||
flusher.Flush()
|
||||
// c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
// Error: ErrorDetail{
|
||||
// Message: err.Error(),
|
||||
// Type: "server_error",
|
||||
// },
|
||||
// })
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
@@ -455,7 +284,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Handle streaming manually
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
@@ -466,28 +295,33 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
|
||||
})
|
||||
return
|
||||
}
|
||||
modelName, contents, tools := h.prepareRequest(rawJson)
|
||||
|
||||
// Prepare the request for the backend client.
|
||||
modelName, contents, tools := translator.PrepareRequest(rawJson)
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
// Ensure the client's mutex is unlocked on function exit.
|
||||
if cliClient != nil {
|
||||
cliClient.RequestMutex.Unlock()
|
||||
}
|
||||
}()
|
||||
|
||||
// Lock the mutex to update the last used page index
|
||||
// Use a round-robin approach to select the next available client.
|
||||
// This distributes the load among the available clients.
|
||||
mutex.Lock()
|
||||
startIndex := lastUsedClientIndex
|
||||
currentIndex := (startIndex + 1) % len(h.cliClients)
|
||||
lastUsedClientIndex = currentIndex
|
||||
mutex.Unlock()
|
||||
|
||||
// Reorder the pages to start from the last used index
|
||||
// Reorder the clients to start from the next client in the rotation.
|
||||
reorderedPages := make([]*client.Client, len(h.cliClients))
|
||||
for i := 0; i < len(h.cliClients); i++ {
|
||||
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
||||
}
|
||||
|
||||
// Attempt to lock a client for the request.
|
||||
locked := false
|
||||
for i := 0; i < len(reorderedPages); i++ {
|
||||
cliClient = reorderedPages[i]
|
||||
@@ -496,235 +330,52 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
|
||||
break
|
||||
}
|
||||
}
|
||||
// If no client is available, block and wait for the first client.
|
||||
if !locked {
|
||||
cliClient = h.cliClients[0]
|
||||
cliClient.RequestMutex.Lock()
|
||||
}
|
||||
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
// Send the message and receive response chunks and errors via channels.
|
||||
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
|
||||
for {
|
||||
select {
|
||||
// Handle client disconnection.
|
||||
case <-c.Request.Context().Done():
|
||||
if c.Request.Context().Err().Error() == "context canceled" {
|
||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||
cliCancel()
|
||||
cliCancel() // Cancel the backend request.
|
||||
return
|
||||
}
|
||||
// Process incoming response chunks.
|
||||
case chunk, okStream := <-respChan:
|
||||
if !okStream {
|
||||
// Stream is closed, send the final [DONE] message.
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
return
|
||||
} else {
|
||||
openAIFormat := h.convertCliToOpenAI(chunk)
|
||||
// Convert the chunk to OpenAI format and send it to the client.
|
||||
openAIFormat := translator.ConvertCliToOpenAI(chunk)
|
||||
if openAIFormat != "" {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
// Handle errors from the backend.
|
||||
case err, okError := <-errChan:
|
||||
if okError {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||
flusher.Flush()
|
||||
// c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||
// Error: ErrorDetail{
|
||||
// Message: err.Error(),
|
||||
// Type: "server_error",
|
||||
// },
|
||||
// })
|
||||
cliCancel()
|
||||
return
|
||||
}
|
||||
// Send a keep-alive signal to the client.
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *APIHandlers) convertCliToOpenAI(rawJson []byte) string {
|
||||
// log.Debugf(string(rawJson))
|
||||
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
|
||||
modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion")
|
||||
if modelVersionResult.Exists() && modelVersionResult.Type == gjson.String {
|
||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||
}
|
||||
|
||||
createTimeResult := gjson.GetBytes(rawJson, "response.createTime")
|
||||
if createTimeResult.Exists() && createTimeResult.Type == gjson.String {
|
||||
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
||||
var unixTimestamp int64
|
||||
if err == nil {
|
||||
unixTimestamp = t.Unix()
|
||||
} else {
|
||||
unixTimestamp = time.Now().Unix()
|
||||
}
|
||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||
}
|
||||
|
||||
responseIdResult := gjson.GetBytes(rawJson, "response.responseId")
|
||||
if responseIdResult.Exists() && responseIdResult.Type == gjson.String {
|
||||
template, _ = sjson.Set(template, "id", responseIdResult.String())
|
||||
}
|
||||
|
||||
finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason")
|
||||
if finishReasonResult.Exists() && finishReasonResult.Type == gjson.String {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
||||
}
|
||||
|
||||
usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
|
||||
candidatesTokenCountResult := usageResult.Get("candidatesTokenCount")
|
||||
if candidatesTokenCountResult.Exists() && candidatesTokenCountResult.Type == gjson.Number {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
totalTokenCountResult := usageResult.Get("totalTokenCount")
|
||||
if totalTokenCountResult.Exists() && totalTokenCountResult.Type == gjson.Number {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
thoughtsTokenCountResult := usageResult.Get("thoughtsTokenCount")
|
||||
promptTokenCountResult := usageResult.Get("promptTokenCount")
|
||||
if promptTokenCountResult.Exists() && promptTokenCountResult.Type == gjson.Number {
|
||||
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int()+thoughtsTokenCountResult.Int())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int())
|
||||
}
|
||||
}
|
||||
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCountResult.Int())
|
||||
}
|
||||
|
||||
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
|
||||
if partTextResult.Exists() && partTextResult.Type == gjson.String {
|
||||
partThoughtResult := partResult.Get("thought")
|
||||
if partThoughtResult.Exists() && partThoughtResult.Type == gjson.True {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]`
|
||||
fcNameResult := functionCallResult.Get("name")
|
||||
if fcNameResult.Exists() && fcNameResult.Type == gjson.String {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcNameResult.String())
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcNameResult.String())
|
||||
}
|
||||
fcArgsResult := functionCallResult.Get("args")
|
||||
if fcArgsResult.Exists() && fcArgsResult.IsObject() {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate)
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
|
||||
return template
|
||||
}
|
||||
|
||||
func (h *APIHandlers) convertCliToOpenAINonStream(template string, rawJson []byte) string {
|
||||
modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion")
|
||||
if modelVersionResult.Exists() && modelVersionResult.Type == gjson.String {
|
||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||
}
|
||||
|
||||
createTimeResult := gjson.GetBytes(rawJson, "response.createTime")
|
||||
if createTimeResult.Exists() && createTimeResult.Type == gjson.String {
|
||||
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
||||
var unixTimestamp int64
|
||||
if err == nil {
|
||||
unixTimestamp = t.Unix()
|
||||
} else {
|
||||
unixTimestamp = time.Now().Unix()
|
||||
}
|
||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||
}
|
||||
|
||||
responseIdResult := gjson.GetBytes(rawJson, "response.responseId")
|
||||
if responseIdResult.Exists() && responseIdResult.Type == gjson.String {
|
||||
template, _ = sjson.Set(template, "id", responseIdResult.String())
|
||||
}
|
||||
|
||||
finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason")
|
||||
if finishReasonResult.Exists() && finishReasonResult.Type == gjson.String {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
||||
}
|
||||
|
||||
usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
|
||||
candidatesTokenCountResult := usageResult.Get("candidatesTokenCount")
|
||||
if candidatesTokenCountResult.Exists() && candidatesTokenCountResult.Type == gjson.Number {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
totalTokenCountResult := usageResult.Get("totalTokenCount")
|
||||
if totalTokenCountResult.Exists() && totalTokenCountResult.Type == gjson.Number {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
thoughtsTokenCountResult := usageResult.Get("thoughtsTokenCount")
|
||||
promptTokenCountResult := usageResult.Get("promptTokenCount")
|
||||
if promptTokenCountResult.Exists() && promptTokenCountResult.Type == gjson.Number {
|
||||
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int()+thoughtsTokenCountResult.Int())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int())
|
||||
}
|
||||
}
|
||||
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCountResult.Int())
|
||||
}
|
||||
|
||||
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
|
||||
if partTextResult.Exists() && partTextResult.Type == gjson.String {
|
||||
partThoughtResult := partResult.Get("thought")
|
||||
if partThoughtResult.Exists() && partThoughtResult.Type == gjson.True {
|
||||
reasoningContentResult := gjson.Get(template, "choices.0.message.reasoning_content")
|
||||
if reasoningContentResult.Type == gjson.String {
|
||||
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningContentResult.String()+partTextResult.String())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String())
|
||||
}
|
||||
} else {
|
||||
reasoningContentResult := gjson.Get(template, "choices.0.message.content")
|
||||
if reasoningContentResult.Type == gjson.String {
|
||||
template, _ = sjson.Set(template, "choices.0.message.content", reasoningContentResult.String()+partTextResult.String())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String())
|
||||
}
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls")
|
||||
if !toolCallsResult.Exists() || toolCallsResult.Type == gjson.Null {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
|
||||
}
|
||||
|
||||
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||
fcNameResult := functionCallResult.Get("name")
|
||||
if fcNameResult.Exists() && fcNameResult.Type == gjson.String {
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcNameResult.String())
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcNameResult.String())
|
||||
}
|
||||
fcArgsResult := functionCallResult.Get("args")
|
||||
if fcArgsResult.Exists() && fcArgsResult.IsObject() {
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
|
||||
} else {
|
||||
return ""
|
||||
}
|
||||
|
||||
return template
|
||||
}
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
package api
|
||||
|
||||
// ErrorResponse represents an error response
|
||||
// ErrorResponse represents a standard error response format for the API.
|
||||
// It contains a single ErrorDetail field.
|
||||
type ErrorResponse struct {
|
||||
Error ErrorDetail `json:"error"`
|
||||
}
|
||||
|
||||
// ErrorDetail represents error details
|
||||
// ErrorDetail provides specific information about an error that occurred.
|
||||
// It includes a human-readable message, an error type, and an optional error code.
|
||||
type ErrorDetail struct {
|
||||
// A human-readable message providing more details about the error.
|
||||
Message string `json:"message"`
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code,omitempty"`
|
||||
// The type of error that occurred (e.g., "invalid_request_error").
|
||||
Type string `json:"type"`
|
||||
// A short code identifying the error, if applicable.
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
@@ -11,7 +11,8 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Server represents the API server
|
||||
// Server represents the main API server.
|
||||
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
||||
type Server struct {
|
||||
engine *gin.Engine
|
||||
server *http.Server
|
||||
@@ -19,14 +20,18 @@ type Server struct {
|
||||
cfg *ServerConfig
|
||||
}
|
||||
|
||||
// ServerConfig contains configuration for the API server
|
||||
// ServerConfig contains the configuration for the API server.
|
||||
type ServerConfig struct {
|
||||
Port string
|
||||
Debug bool
|
||||
// Port is the port number the server will listen on.
|
||||
Port string
|
||||
// Debug enables or disables debug mode for the server and Gin.
|
||||
Debug bool
|
||||
// ApiKeys is a list of valid API keys for authentication.
|
||||
ApiKeys []string
|
||||
}
|
||||
|
||||
// NewServer creates a new API server instance
|
||||
// NewServer creates and initializes a new API server instance.
|
||||
// It sets up the Gin engine, middleware, routes, and handlers.
|
||||
func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
|
||||
// Set gin mode
|
||||
if !config.Debug {
|
||||
@@ -63,7 +68,8 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
|
||||
return s
|
||||
}
|
||||
|
||||
// setupRoutes configures the API routes
|
||||
// setupRoutes configures the API routes for the server.
|
||||
// It defines the endpoints and associates them with their respective handlers.
|
||||
func (s *Server) setupRoutes() {
|
||||
// OpenAI compatible API routes
|
||||
v1 := s.engine.Group("/v1")
|
||||
@@ -86,11 +92,12 @@ func (s *Server) setupRoutes() {
|
||||
})
|
||||
}
|
||||
|
||||
// Start starts the API server
|
||||
// Start begins listening for and serving HTTP requests.
|
||||
// It's a blocking call and will only return on an unrecoverable error.
|
||||
func (s *Server) Start() error {
|
||||
log.Debugf("Starting API server on %s", s.server.Addr)
|
||||
|
||||
// Start the HTTP server
|
||||
// Start the HTTP server.
|
||||
if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", err)
|
||||
}
|
||||
@@ -98,11 +105,12 @@ func (s *Server) Start() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Stop gracefully stops the API server
|
||||
// Stop gracefully shuts down the API server without interrupting any
|
||||
// active connections.
|
||||
func (s *Server) Stop(ctx context.Context) error {
|
||||
log.Debug("Stopping API server...")
|
||||
|
||||
// Shutdown the HTTP server
|
||||
// Shutdown the HTTP server.
|
||||
if err := s.server.Shutdown(ctx); err != nil {
|
||||
return fmt.Errorf("failed to shutdown HTTP server: %v", err)
|
||||
}
|
||||
@@ -111,7 +119,8 @@ func (s *Server) Stop(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// corsMiddleware adds CORS headers
|
||||
// corsMiddleware returns a Gin middleware handler that adds CORS headers
|
||||
// to every response, allowing cross-origin requests.
|
||||
func corsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
@@ -127,7 +136,8 @@ func corsMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// AuthMiddleware authenticates requests using API keys
|
||||
// AuthMiddleware returns a Gin middleware handler that authenticates requests
|
||||
// using API keys. If no API keys are configured, it allows all requests.
|
||||
func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if len(cfg.ApiKeys) == 0 {
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
package api
|
||||
package translator
|
||||
|
||||
// MimeTypes is a comprehensive map of file extensions to their corresponding MIME types.
|
||||
// This is used to identify the type of file being uploaded or processed.
|
||||
var MimeTypes = map[string]string{
|
||||
"ez": "application/andrew-inset",
|
||||
"aw": "application/applixware",
|
||||
163
internal/api/translator/request.go
Normal file
163
internal/api/translator/request.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package translator
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// PrepareRequest translates a raw JSON request from an OpenAI-compatible format
|
||||
// to the internal format expected by the backend client. It parses messages,
|
||||
// roles, content types (text, image, file), and tool calls.
|
||||
func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) {
|
||||
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
|
||||
modelName := "gemini-2.5-pro"
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
if modelResult.Type == gjson.String {
|
||||
modelName = modelResult.String()
|
||||
}
|
||||
|
||||
// Process the array of messages.
|
||||
contents := make([]client.Content, 0)
|
||||
messagesResult := gjson.GetBytes(rawJson, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messagesResults := messagesResult.Array()
|
||||
for i := 0; i < len(messagesResults); i++ {
|
||||
messageResult := messagesResults[i]
|
||||
roleResult := messageResult.Get("role")
|
||||
contentResult := messageResult.Get("content")
|
||||
if roleResult.Type != gjson.String {
|
||||
continue
|
||||
}
|
||||
|
||||
switch roleResult.String() {
|
||||
// System messages are converted to a user message followed by a model's acknowledgment.
|
||||
case "system":
|
||||
if contentResult.Type == gjson.String {
|
||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
|
||||
} else if contentResult.IsObject() {
|
||||
// Handle object-based system messages.
|
||||
if contentResult.Get("type").String() == "text" {
|
||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}})
|
||||
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
|
||||
}
|
||||
}
|
||||
// User messages can contain simple text or a multi-part body.
|
||||
case "user":
|
||||
if contentResult.Type == gjson.String {
|
||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||
} else if contentResult.IsArray() {
|
||||
// Handle multi-part user messages (text, images, files).
|
||||
contentItemResults := contentResult.Array()
|
||||
parts := make([]client.Part, 0)
|
||||
for j := 0; j < len(contentItemResults); j++ {
|
||||
contentItemResult := contentItemResults[j]
|
||||
contentTypeResult := contentItemResult.Get("type")
|
||||
switch contentTypeResult.String() {
|
||||
case "text":
|
||||
parts = append(parts, client.Part{Text: contentItemResult.Get("text").String()})
|
||||
case "image_url":
|
||||
// Parse data URI for images.
|
||||
imageURL := contentItemResult.Get("image_url.url").String()
|
||||
if len(imageURL) > 5 {
|
||||
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
|
||||
if len(imageURLs) == 2 && len(imageURLs[1]) > 7 {
|
||||
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
||||
MimeType: imageURLs[0],
|
||||
Data: imageURLs[1][7:],
|
||||
}})
|
||||
}
|
||||
}
|
||||
case "file":
|
||||
// Handle file attachments by determining MIME type from extension.
|
||||
filename := contentItemResult.Get("file.filename").String()
|
||||
fileData := contentItemResult.Get("file.file_data").String()
|
||||
ext := ""
|
||||
if split := strings.Split(filename, "."); len(split) > 1 {
|
||||
ext = split[len(split)-1]
|
||||
}
|
||||
if mimeType, ok := MimeTypes[ext]; ok {
|
||||
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
||||
MimeType: mimeType,
|
||||
Data: fileData,
|
||||
}})
|
||||
} else {
|
||||
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
contents = append(contents, client.Content{Role: "user", Parts: parts})
|
||||
}
|
||||
// Assistant messages can contain text or tool calls.
|
||||
case "assistant":
|
||||
if contentResult.Type == gjson.String {
|
||||
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
|
||||
// Handle tool calls made by the assistant.
|
||||
toolCallsResult := messageResult.Get("tool_calls")
|
||||
if toolCallsResult.IsArray() {
|
||||
tcsResult := toolCallsResult.Array()
|
||||
for j := 0; j < len(tcsResult); j++ {
|
||||
tcResult := tcsResult[j]
|
||||
functionName := tcResult.Get("function.name").String()
|
||||
functionArgs := tcResult.Get("function.arguments").String()
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||
contents = append(contents, client.Content{
|
||||
Role: "model", Parts: []client.Part{{
|
||||
FunctionCall: &client.FunctionCall{
|
||||
Name: functionName,
|
||||
Args: args,
|
||||
},
|
||||
}},
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Tool messages contain the output of a tool call.
|
||||
case "tool":
|
||||
toolCallID := messageResult.Get("tool_call_id").String()
|
||||
if toolCallID != "" {
|
||||
var responseData string
|
||||
if contentResult.Type == gjson.String {
|
||||
responseData = contentResult.String()
|
||||
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
|
||||
responseData = contentResult.Get("text").String()
|
||||
}
|
||||
functionResponse := client.FunctionResponse{Name: toolCallID, Response: map[string]interface{}{"result": responseData}}
|
||||
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Translate the tool declarations from the request.
|
||||
var tools []client.ToolDeclaration
|
||||
toolsResult := gjson.GetBytes(rawJson, "tools")
|
||||
if toolsResult.IsArray() {
|
||||
tools = make([]client.ToolDeclaration, 1)
|
||||
tools[0].FunctionDeclarations = make([]any, 0)
|
||||
toolsResults := toolsResult.Array()
|
||||
for i := 0; i < len(toolsResults); i++ {
|
||||
toolResult := toolsResults[i]
|
||||
if toolResult.Get("type").String() == "function" {
|
||||
functionTypeResult := toolResult.Get("function")
|
||||
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
|
||||
var functionDeclaration any
|
||||
if err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration); err == nil {
|
||||
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tools = make([]client.ToolDeclaration, 0)
|
||||
}
|
||||
|
||||
return modelName, contents, tools
|
||||
}
|
||||
169
internal/api/translator/response.go
Normal file
169
internal/api/translator/response.go
Normal file
@@ -0,0 +1,169 @@
|
||||
package translator
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// ConvertCliToOpenAI translates a single chunk of a streaming response from the
|
||||
// backend client format to the OpenAI Server-Sent Events (SSE) format.
|
||||
// It returns an empty string if the chunk contains no useful data.
|
||||
func ConvertCliToOpenAI(rawJson []byte) string {
|
||||
// Initialize the OpenAI SSE template.
|
||||
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||
|
||||
// Extract and set the model version.
|
||||
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the creation timestamp.
|
||||
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
|
||||
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
||||
unixTimestamp := time.Now().Unix()
|
||||
if err == nil {
|
||||
unixTimestamp = t.Unix()
|
||||
}
|
||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||
}
|
||||
|
||||
// Extract and set the response ID.
|
||||
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
|
||||
template, _ = sjson.Set(template, "id", responseIdResult.String())
|
||||
}
|
||||
|
||||
// Extract and set the finish reason.
|
||||
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
||||
}
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
|
||||
if partTextResult.Exists() {
|
||||
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
||||
if partResult.Get("thought").Bool() {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
|
||||
} else {
|
||||
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Handle function call content.
|
||||
functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]`
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcName)
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate)
|
||||
} else {
|
||||
// If no usable content is found, return an empty string.
|
||||
return ""
|
||||
}
|
||||
|
||||
return template
|
||||
}
|
||||
|
||||
// ConvertCliToOpenAINonStream aggregates response chunks from the backend client
|
||||
// into a single, non-streaming OpenAI-compatible JSON response.
|
||||
func ConvertCliToOpenAINonStream(template string, rawJson []byte) string {
|
||||
// Extract and set metadata fields that are typically set once per response.
|
||||
if gjson.Get(template, "id").String() == "" {
|
||||
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
|
||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||
}
|
||||
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
|
||||
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
||||
unixTimestamp := time.Now().Unix()
|
||||
if err == nil {
|
||||
unixTimestamp = t.Unix()
|
||||
}
|
||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||
}
|
||||
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
|
||||
template, _ = sjson.Set(template, "id", responseIdResult.String())
|
||||
}
|
||||
}
|
||||
|
||||
// Extract and set the finish reason.
|
||||
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
||||
}
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
|
||||
if partTextResult.Exists() {
|
||||
// Append text content, distinguishing between regular content and reasoning.
|
||||
if partResult.Get("thought").Bool() {
|
||||
currentContent := gjson.Get(template, "choices.0.message.reasoning_content").String()
|
||||
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", currentContent+partTextResult.String())
|
||||
} else {
|
||||
currentContent := gjson.Get(template, "choices.0.message.content").String()
|
||||
template, _ = sjson.Set(template, "choices.0.message.content", currentContent+partTextResult.String())
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
} else if functionCallResult.Exists() {
|
||||
// Append function call content to the tool_calls array.
|
||||
if !gjson.Get(template, "choices.0.message.tool_calls").Exists() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
|
||||
}
|
||||
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcName)
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
||||
}
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
|
||||
} else {
|
||||
// If no usable content is found, return an empty string.
|
||||
return ""
|
||||
}
|
||||
|
||||
return template
|
||||
}
|
||||
@@ -5,17 +5,18 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"github.com/tidwall/gjson"
|
||||
"golang.org/x/net/proxy"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/skratchdot/open-golang/open"
|
||||
"github.com/tidwall/gjson"
|
||||
"golang.org/x/net/proxy"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
@@ -33,76 +34,78 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
// GetAuthenticatedClient configures and returns an HTTP client with OAuth2 tokens.
|
||||
// It handles the entire flow: loading, refreshing, and fetching new tokens.
|
||||
// GetAuthenticatedClient configures and returns an HTTP client ready for making authenticated API calls.
|
||||
// It manages the entire OAuth2 flow, including handling proxies, loading existing tokens,
|
||||
// initiating a new web-based OAuth flow if necessary, and refreshing tokens.
|
||||
func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) {
|
||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||
proxyURL, err := url.Parse(cfg.ProxyUrl)
|
||||
if err == nil {
|
||||
var transport *http.Transport
|
||||
if proxyURL.Scheme == "socks5" {
|
||||
// Handle SOCKS5 proxy.
|
||||
username := proxyURL.User.Username()
|
||||
password, _ := proxyURL.User.Password()
|
||||
auth := &proxy.Auth{
|
||||
User: username,
|
||||
Password: password,
|
||||
}
|
||||
auth := &proxy.Auth{User: username, Password: password}
|
||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
||||
if errSOCKS5 != nil {
|
||||
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||
}
|
||||
|
||||
transport := &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
|
||||
transport = &http.Transport{
|
||||
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||
return dialer.Dial(network, addr)
|
||||
},
|
||||
}
|
||||
proxyClient := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||
transport := &http.Transport{
|
||||
Proxy: http.ProxyURL(proxyURL),
|
||||
}
|
||||
proxyClient := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
// Handle HTTP/HTTPS proxy.
|
||||
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||
}
|
||||
|
||||
if transport != nil {
|
||||
proxyClient := &http.Client{Transport: transport}
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
||||
}
|
||||
}
|
||||
|
||||
// Configure the OAuth2 client.
|
||||
conf := &oauth2.Config{
|
||||
ClientID: oauthClientID,
|
||||
ClientSecret: oauthClientSecret,
|
||||
RedirectURL: "http://localhost:8085/oauth2callback", // Placeholder, will be updated
|
||||
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
|
||||
Scopes: oauthScopes,
|
||||
Endpoint: google.Endpoint,
|
||||
}
|
||||
|
||||
var token *oauth2.Token
|
||||
|
||||
// If no token is found in storage, initiate the web-based OAuth flow.
|
||||
if ts.Token == nil {
|
||||
log.Info("Could not load token from file, starting OAuth flow.")
|
||||
token, err = getTokenFromWeb(ctx, conf)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token from web: %w", err)
|
||||
}
|
||||
newTs, errSaveTokenToFile := createTokenStorage(ctx, conf, token, ts.ProjectID)
|
||||
if errSaveTokenToFile != nil {
|
||||
log.Errorf("Warning: failed to save token to file: %v", err)
|
||||
return nil, errSaveTokenToFile
|
||||
// After getting a new token, create a new token storage object with user info.
|
||||
newTs, errCreateTokenStorage := createTokenStorage(ctx, conf, token, ts.ProjectID)
|
||||
if errCreateTokenStorage != nil {
|
||||
log.Errorf("Warning: failed to create token storage: %v", errCreateTokenStorage)
|
||||
return nil, errCreateTokenStorage
|
||||
}
|
||||
*ts = *newTs
|
||||
}
|
||||
|
||||
// Unmarshal the stored token into an oauth2.Token object.
|
||||
tsToken, _ := json.Marshal(ts.Token)
|
||||
if err = json.Unmarshal(tsToken, &token); err != nil {
|
||||
return nil, err
|
||||
return nil, fmt.Errorf("failed to unmarshal token: %w", err)
|
||||
}
|
||||
|
||||
// Return an HTTP client that automatically handles token refreshing.
|
||||
return conf.Client(ctx, token), nil
|
||||
}
|
||||
|
||||
// createTokenStorage creates a token storage.
|
||||
// createTokenStorage creates a new TokenStorage object. It fetches the user's email
|
||||
// using the provided token and populates the storage structure.
|
||||
func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*TokenStorage, error) {
|
||||
httpClient := config.Client(ctx, token)
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
@@ -117,7 +120,9 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth
|
||||
return nil, fmt.Errorf("failed to execute request: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
if err = resp.Body.Close(); err != nil {
|
||||
log.Printf("warn: failed to close response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
@@ -154,7 +159,10 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth
|
||||
return &ts, nil
|
||||
}
|
||||
|
||||
// getTokenFromWeb starts a local server to handle the OAuth2 flow.
|
||||
// getTokenFromWeb initiates the web-based OAuth2 authorization flow.
|
||||
// It starts a local HTTP server to listen for the callback from Google's auth server,
|
||||
// opens the user's browser to the authorization URL, and exchanges the received
|
||||
// authorization code for an access token.
|
||||
func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) {
|
||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||
codeChan := make(chan string)
|
||||
|
||||
@@ -1,9 +1,17 @@
|
||||
package auth
|
||||
|
||||
// TokenStorage defines the structure for storing OAuth2 token information,
|
||||
// along with associated user and project details. This data is typically
|
||||
// serialized to a JSON file for persistence.
|
||||
type TokenStorage struct {
|
||||
Token any `json:"token"`
|
||||
// Token holds the raw OAuth2 token data, including access and refresh tokens.
|
||||
Token any `json:"token"`
|
||||
// ProjectID is the Google Cloud Project ID associated with this token.
|
||||
ProjectID string `json:"project_id"`
|
||||
Email string `json:"email"`
|
||||
Auto bool `json:"auto"`
|
||||
Checked bool `json:"checked"`
|
||||
// Email is the email address of the authenticated user.
|
||||
Email string `json:"email"`
|
||||
// Auto indicates if the project ID was automatically selected.
|
||||
Auto bool `json:"auto"`
|
||||
// Checked indicates if the associated Cloud AI API has been verified as enabled.
|
||||
Checked bool `json:"checked"`
|
||||
}
|
||||
|
||||
@@ -6,12 +6,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"golang.org/x/oauth2"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
@@ -20,6 +14,13 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -194,7 +195,9 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
|
||||
return fmt.Errorf("failed to execute request: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
if err = resp.Body.Close(); err != nil {
|
||||
log.Printf("warn: failed to close response body: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
@@ -253,7 +256,9 @@ func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body int
|
||||
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
if err = resp.Body.Close(); err != nil {
|
||||
log.Printf("warn: failed to close response body: %v", err)
|
||||
}
|
||||
}()
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
|
||||
@@ -355,6 +360,9 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
// CheckCloudAPIIsEnabled sends a simple test request to the API to verify
|
||||
// that the Cloud AI API is enabled for the user's project. It provides
|
||||
// an activation URL if the API is disabled.
|
||||
func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer func() {
|
||||
@@ -363,79 +371,78 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
|
||||
}()
|
||||
c.RequestMutex.Lock()
|
||||
|
||||
requestBody := `{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`
|
||||
requestBody = fmt.Sprintf(requestBody, c.tokenStorage.ProjectID)
|
||||
// log.Debug(requestBody)
|
||||
// A simple request to test the API endpoint.
|
||||
requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID)
|
||||
|
||||
stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", []byte(requestBody))
|
||||
if err != nil {
|
||||
// If a 403 Forbidden error occurs, it likely means the API is not enabled.
|
||||
if err.StatusCode == 403 {
|
||||
errJson := err.Error.Error()
|
||||
codeResult := gjson.Get(errJson, "error.code")
|
||||
if codeResult.Exists() && codeResult.Type == gjson.Number {
|
||||
if codeResult.Int() == 403 {
|
||||
activationUrlResult := gjson.Get(errJson, "error.details.0.metadata.activationUrl")
|
||||
if activationUrlResult.Exists() {
|
||||
log.Warnf(
|
||||
"\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s",
|
||||
activationUrlResult.String(),
|
||||
os.Args[0],
|
||||
c.tokenStorage.ProjectID,
|
||||
)
|
||||
}
|
||||
// Check for a specific error code and extract the activation URL.
|
||||
if gjson.Get(errJson, "error.code").Int() == 403 {
|
||||
activationUrl := gjson.Get(errJson, "error.details.0.metadata.activationUrl").String()
|
||||
if activationUrl != "" {
|
||||
log.Warnf(
|
||||
"\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s",
|
||||
activationUrl,
|
||||
os.Args[0],
|
||||
c.tokenStorage.ProjectID,
|
||||
)
|
||||
}
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
return false, err.Error
|
||||
}
|
||||
defer func() {
|
||||
_ = stream.Close()
|
||||
}()
|
||||
|
||||
// We only need to know if the request was successful, so we can drain the stream.
|
||||
scanner := bufio.NewScanner(stream)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
if !strings.HasPrefix(line, "data: ") {
|
||||
continue
|
||||
}
|
||||
// Do nothing, just consume the stream.
|
||||
}
|
||||
|
||||
if scannerErr := scanner.Err(); scannerErr != nil {
|
||||
_ = stream.Close()
|
||||
} else {
|
||||
_ = stream.Close()
|
||||
}
|
||||
|
||||
return true, nil
|
||||
return scanner.Err() == nil, scanner.Err()
|
||||
}
|
||||
|
||||
// GetProjectList fetches a list of Google Cloud projects accessible by the user.
|
||||
func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) {
|
||||
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not get project list: %v", err)
|
||||
return nil, fmt.Errorf("could not create project list request: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to execute request: %w", err)
|
||||
return nil, fmt.Errorf("failed to execute project list request: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
return nil, fmt.Errorf("get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
return nil, fmt.Errorf("project list request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var project GCPProject
|
||||
err = json.Unmarshal(bodyBytes, &project)
|
||||
if err != nil {
|
||||
if err = json.NewDecoder(resp.Body).Decode(&project); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal project list: %w", err)
|
||||
}
|
||||
return &project, nil
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the client's current token storage to a JSON file.
|
||||
// The filename is constructed from the user's email and project ID.
|
||||
func (c *Client) SaveTokenToFile() error {
|
||||
if err := os.MkdirAll(c.cfg.AuthDir, 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
@@ -457,7 +464,8 @@ func (c *Client) SaveTokenToFile() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getClientMetadata returns metadata about the client environment.
|
||||
// getClientMetadata returns a map of metadata about the client environment,
|
||||
// such as IDE type, platform, and plugin version.
|
||||
func getClientMetadata() map[string]string {
|
||||
return map[string]string{
|
||||
"ideType": "IDE_UNSPECIFIED",
|
||||
@@ -467,7 +475,8 @@ func getClientMetadata() map[string]string {
|
||||
}
|
||||
}
|
||||
|
||||
// getClientMetadataString returns the metadata as a comma-separated string.
|
||||
// getClientMetadataString returns the client metadata as a single,
|
||||
// comma-separated string, which is required for the 'Client-Metadata' header.
|
||||
func getClientMetadataString() string {
|
||||
md := getClientMetadata()
|
||||
parts := make([]string, 0, len(md))
|
||||
@@ -477,11 +486,13 @@ func getClientMetadataString() string {
|
||||
return strings.Join(parts, ",")
|
||||
}
|
||||
|
||||
// getUserAgent constructs the User-Agent string for HTTP requests.
|
||||
func getUserAgent() string {
|
||||
return fmt.Sprintf(fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH))
|
||||
return fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
|
||||
// getPlatform returns the OS and architecture in the format expected by the API.
|
||||
// getPlatform determines the operating system and architecture and formats
|
||||
// it into a string expected by the backend API.
|
||||
func getPlatform() string {
|
||||
goOS := runtime.GOOS
|
||||
arch := runtime.GOARCH
|
||||
|
||||
@@ -2,17 +2,23 @@ package client
|
||||
|
||||
import "time"
|
||||
|
||||
// ErrorMessage encapsulates an error with an associated HTTP status code.
|
||||
type ErrorMessage struct {
|
||||
StatusCode int
|
||||
Error error
|
||||
}
|
||||
|
||||
// GCPProject represents the response structure for a Google Cloud project list request.
|
||||
type GCPProject struct {
|
||||
Projects []GCPProjectProjects `json:"projects"`
|
||||
}
|
||||
|
||||
// GCPProjectLabels defines the labels associated with a GCP project.
|
||||
type GCPProjectLabels struct {
|
||||
GenerativeLanguage string `json:"generative-language"`
|
||||
}
|
||||
|
||||
// GCPProjectProjects contains details about a single Google Cloud project.
|
||||
type GCPProjectProjects struct {
|
||||
ProjectNumber string `json:"projectNumber"`
|
||||
ProjectID string `json:"projectId"`
|
||||
@@ -22,12 +28,14 @@ type GCPProjectProjects struct {
|
||||
CreateTime time.Time `json:"createTime"`
|
||||
}
|
||||
|
||||
// Content represents a single message in a conversation, with a role and parts.
|
||||
type Content struct {
|
||||
Role string `json:"role"`
|
||||
Parts []Part `json:"parts"`
|
||||
}
|
||||
|
||||
// Part represents a single part of a message's content.
|
||||
// Part represents a distinct piece of content within a message, which can be
|
||||
// text, inline data (like an image), a function call, or a function response.
|
||||
type Part struct {
|
||||
Text string `json:"text,omitempty"`
|
||||
InlineData *InlineData `json:"inlineData,omitempty"`
|
||||
@@ -35,46 +43,48 @@ type Part struct {
|
||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||
}
|
||||
|
||||
// InlineData represents base64-encoded data with its MIME type.
|
||||
type InlineData struct {
|
||||
MimeType string `json:"mime_type,omitempty"`
|
||||
Data string `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
// FunctionCall represents a tool call requested by the model.
|
||||
// FunctionCall represents a tool call requested by the model, including the
|
||||
// function name and its arguments.
|
||||
type FunctionCall struct {
|
||||
Name string `json:"name"`
|
||||
Args map[string]interface{} `json:"args"`
|
||||
}
|
||||
|
||||
// FunctionResponse represents the result of a tool execution.
|
||||
// FunctionResponse represents the result of a tool execution, sent back to the model.
|
||||
type FunctionResponse struct {
|
||||
Name string `json:"name"`
|
||||
Response map[string]interface{} `json:"response"`
|
||||
}
|
||||
|
||||
// GenerateContentRequest is the request payload for the streamGenerateContent endpoint.
|
||||
// GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint.
|
||||
type GenerateContentRequest struct {
|
||||
Contents []Content `json:"contents"`
|
||||
Tools []ToolDeclaration `json:"tools,omitempty"`
|
||||
GenerationConfig `json:"generationConfig"`
|
||||
}
|
||||
|
||||
// GenerationConfig defines model generation parameters.
|
||||
// GenerationConfig defines parameters that control the model's generation behavior.
|
||||
type GenerationConfig struct {
|
||||
ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||
Temperature float64 `json:"temperature,omitempty"`
|
||||
TopP float64 `json:"topP,omitempty"`
|
||||
TopK float64 `json:"topK,omitempty"`
|
||||
// Temperature, TopP, TopK, etc. can be added here.
|
||||
}
|
||||
|
||||
// GenerationConfigThinkingConfig specifies configuration for the model's "thinking" process.
|
||||
type GenerationConfigThinkingConfig struct {
|
||||
// IncludeThoughts determines whether the model should output its reasoning process.
|
||||
IncludeThoughts bool `json:"include_thoughts,omitempty"`
|
||||
}
|
||||
|
||||
// ToolDeclaration is the structure for declaring tools to the API.
|
||||
// For now, we'll assume a simple structure. A more complete implementation
|
||||
// would mirror the OpenAPI schema definition.
|
||||
// ToolDeclaration defines the structure for declaring tools (like functions)
|
||||
// that the model can call.
|
||||
type ToolDeclaration struct {
|
||||
FunctionDeclarations []interface{} `json:"functionDeclarations"`
|
||||
}
|
||||
|
||||
@@ -9,6 +9,9 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// DoLogin handles the entire user login and setup process.
|
||||
// It authenticates the user, sets up the user's project, checks API enablement,
|
||||
// and saves the token for future use.
|
||||
func DoLogin(cfg *config.Config, projectID string) {
|
||||
var err error
|
||||
var ts auth.TokenStorage
|
||||
@@ -16,9 +19,8 @@ func DoLogin(cfg *config.Config, projectID string) {
|
||||
ts.ProjectID = projectID
|
||||
}
|
||||
|
||||
// 2. Initialize authenticated HTTP Client
|
||||
// Initialize an authenticated HTTP client. This will trigger the OAuth flow if necessary.
|
||||
clientCtx := context.Background()
|
||||
|
||||
log.Info("Initializing authentication...")
|
||||
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
||||
if errGetClient != nil {
|
||||
@@ -27,51 +29,57 @@ func DoLogin(cfg *config.Config, projectID string) {
|
||||
}
|
||||
log.Info("Authentication successful.")
|
||||
|
||||
// 3. Initialize CLI Client
|
||||
// Initialize the API client.
|
||||
cliClient := client.NewClient(httpClient, &ts, cfg)
|
||||
|
||||
// Perform the user setup process.
|
||||
err = cliClient.SetupUser(clientCtx, ts.Email, projectID)
|
||||
if err != nil {
|
||||
// Handle the specific case where a project ID is required but not provided.
|
||||
if err.Error() == "failed to start user onboarding, need define a project id" {
|
||||
log.Error("failed to start user onboarding")
|
||||
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||
// Fetch and display the user's available projects to help them choose one.
|
||||
project, errGetProjectList := cliClient.GetProjectList(clientCtx)
|
||||
if errGetProjectList != nil {
|
||||
log.Fatalf("failed to complete user setup: %v", err)
|
||||
log.Fatalf("Failed to get project list: %v", err)
|
||||
} else {
|
||||
log.Infof("Your account %s needs specify a project id.", ts.Email)
|
||||
log.Infof("Your account %s needs to specify a project ID.", ts.Email)
|
||||
log.Info("========================================================================")
|
||||
for i := 0; i < len(project.Projects); i++ {
|
||||
log.Infof("Project ID: %s", project.Projects[i].ProjectID)
|
||||
log.Infof("Project Name: %s", project.Projects[i].Name)
|
||||
log.Info("========================================================================")
|
||||
for _, p := range project.Projects {
|
||||
log.Infof("Project ID: %s", p.ProjectID)
|
||||
log.Infof("Project Name: %s", p.Name)
|
||||
log.Info("------------------------------------------------------------------------")
|
||||
}
|
||||
log.Infof("Please run this command to login again:\n\n%s --login --project_id <project_id>\n", os.Args[0])
|
||||
log.Infof("Please run this command to login again with a specific project:\n\n%s --login --project_id <project_id>\n", os.Args[0])
|
||||
}
|
||||
} else {
|
||||
// Log as a warning because in some cases, the CLI might still be usable
|
||||
// or the user might want to retry setup later.
|
||||
log.Fatalf("failed to complete user setup: %v", err)
|
||||
log.Fatalf("Failed to complete user setup: %v", err)
|
||||
}
|
||||
} else {
|
||||
auto := projectID == ""
|
||||
cliClient.SetIsAuto(auto)
|
||||
return // Exit after handling the error.
|
||||
}
|
||||
|
||||
if !cliClient.IsChecked() && !cliClient.IsAuto() {
|
||||
isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled()
|
||||
if checkErr != nil {
|
||||
log.Fatalf("failed to check cloud api is enabled: %v", checkErr)
|
||||
return
|
||||
}
|
||||
cliClient.SetIsChecked(isChecked)
|
||||
}
|
||||
// If setup is successful, proceed to check API status and save the token.
|
||||
auto := projectID == ""
|
||||
cliClient.SetIsAuto(auto)
|
||||
|
||||
if !cliClient.IsChecked() && !cliClient.IsAuto() {
|
||||
// If the project was not automatically selected, check if the Cloud AI API is enabled.
|
||||
if !cliClient.IsChecked() && !cliClient.IsAuto() {
|
||||
isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled()
|
||||
if checkErr != nil {
|
||||
log.Fatalf("Failed to check if Cloud AI API is enabled: %v", checkErr)
|
||||
return
|
||||
}
|
||||
|
||||
err = cliClient.SaveTokenToFile()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
cliClient.SetIsChecked(isChecked)
|
||||
// If the check fails (returns false), the CheckCloudAPIIsEnabled function
|
||||
// will have already printed instructions, so we can just exit.
|
||||
if !isChecked {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Save the successfully obtained and verified token to a file.
|
||||
err = cliClient.SaveTokenToFile()
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to save token to file: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,20 +18,25 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
// StartService initializes and starts the main API proxy service.
|
||||
// It loads all available authentication tokens, creates a pool of clients,
|
||||
// starts the API server, and handles graceful shutdown signals.
|
||||
func StartService(cfg *config.Config) {
|
||||
// Create API server configuration
|
||||
// Configure the API server based on the main application config.
|
||||
apiConfig := &api.ServerConfig{
|
||||
Port: fmt.Sprintf("%d", cfg.Port),
|
||||
Debug: cfg.Debug,
|
||||
ApiKeys: cfg.ApiKeys,
|
||||
}
|
||||
|
||||
// Create a pool of API clients, one for each token file found.
|
||||
cliClients := make([]*client.Client, 0)
|
||||
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Process only JSON files in the auth directory.
|
||||
if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") {
|
||||
log.Debugf("Loading token from: %s", path)
|
||||
f, errOpen := os.Open(path)
|
||||
@@ -42,58 +47,62 @@ func StartService(cfg *config.Config) {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Decode the token storage file.
|
||||
var ts auth.TokenStorage
|
||||
if err = json.NewDecoder(f).Decode(&ts); err == nil {
|
||||
// 2. Initialize authenticated HTTP Client
|
||||
// For each valid token, create an authenticated client.
|
||||
clientCtx := context.Background()
|
||||
|
||||
log.Info("Initializing authentication...")
|
||||
log.Info("Initializing authentication for token...")
|
||||
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
||||
if errGetClient != nil {
|
||||
log.Fatalf("failed to get authenticated client: %v", errGetClient)
|
||||
// Log fatal will exit, but we return the error for completeness.
|
||||
log.Fatalf("failed to get authenticated client for token %s: %v", path, errGetClient)
|
||||
return errGetClient
|
||||
}
|
||||
log.Info("Authentication successful.")
|
||||
|
||||
// 3. Initialize CLI Client
|
||||
// Add the new client to the pool.
|
||||
cliClient := client.NewClient(httpClient, &ts, cfg)
|
||||
cliClients = append(cliClients, cliClient)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Error walking auth directory: %v", err)
|
||||
}
|
||||
|
||||
// Create API server
|
||||
// Create and start the API server with the pool of clients.
|
||||
apiServer := api.NewServer(apiConfig, cliClients)
|
||||
log.Infof("Starting API server on port %s", apiConfig.Port)
|
||||
if err = apiServer.Start(); err != nil {
|
||||
log.Fatalf("API server failed to start: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Set up graceful shutdown
|
||||
// Set up a channel to listen for OS signals for graceful shutdown.
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
|
||||
// Main loop to wait for shutdown signal.
|
||||
for {
|
||||
select {
|
||||
case <-sigChan:
|
||||
log.Debugf("Received shutdown signal. Cleaning up...")
|
||||
|
||||
// Create shutdown context
|
||||
// Create a context with a timeout for the shutdown process.
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
_ = ctx // Mark ctx as used to avoid error, as apiServer.Stop(ctx) is commented out
|
||||
_ = cancel
|
||||
|
||||
// Stop API server
|
||||
// Stop the API server gracefully.
|
||||
if err = apiServer.Stop(ctx); err != nil {
|
||||
log.Debugf("Error stopping API server: %v", err)
|
||||
}
|
||||
cancel()
|
||||
|
||||
log.Debugf("Cleanup completed. Exiting...")
|
||||
os.Exit(0)
|
||||
case <-time.After(5 * time.Second):
|
||||
|
||||
// This case is currently empty and acts as a periodic check.
|
||||
// It could be used for periodic tasks in the future.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,33 +6,35 @@ import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// Config represents the application's configuration
|
||||
// Config represents the application's configuration, loaded from a YAML file.
|
||||
type Config struct {
|
||||
Port int `yaml:"port"`
|
||||
AuthDir string `yaml:"auth_dir"`
|
||||
Debug bool `yaml:"debug"`
|
||||
ProxyUrl string `yaml:"proxy-url"`
|
||||
ApiKeys []string `yaml:"api_keys"`
|
||||
// Port is the network port on which the API server will listen.
|
||||
Port int `yaml:"port"`
|
||||
// AuthDir is the directory where authentication token files are stored.
|
||||
AuthDir string `yaml:"auth_dir"`
|
||||
// Debug enables or disables debug-level logging and other debug features.
|
||||
Debug bool `yaml:"debug"`
|
||||
// ProxyUrl is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyUrl string `yaml:"proxy-url"`
|
||||
// ApiKeys is a list of keys for authenticating clients to this proxy server.
|
||||
ApiKeys []string `yaml:"api_keys"`
|
||||
}
|
||||
|
||||
// / LoadConfig loads the configuration from the specified file
|
||||
// LoadConfig reads a YAML configuration file from the given path,
|
||||
// unmarshals it into a Config struct, and returns it.
|
||||
func LoadConfig(configFile string) (*Config, error) {
|
||||
// Read the configuration file
|
||||
// Read the entire configuration file into memory.
|
||||
data, err := os.ReadFile(configFile)
|
||||
// If reading the file fails
|
||||
if err != nil {
|
||||
// Return an error
|
||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||
}
|
||||
|
||||
// Parse the YAML data
|
||||
// Unmarshal the YAML data into the Config struct.
|
||||
var config Config
|
||||
// If parsing the YAML data fails
|
||||
if err = yaml.Unmarshal(data, &config); err != nil {
|
||||
// Return an error
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
// Return the configuration
|
||||
// Return the populated configuration struct.
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user