Numerous Comments Added and Extensive Optimization Performed using Roo-Code with CLIProxyAPI itself.

This commit is contained in:
Luis Pater
2025-07-04 18:44:55 +08:00
parent 8dd7f8e82f
commit 5ec6450c50
15 changed files with 629 additions and 559 deletions

View File

@@ -12,9 +12,11 @@ import (
"strings" "strings"
) )
// LogFormatter defines a custom log format for logrus.
type LogFormatter struct { type LogFormatter struct {
} }
// Format renders a single log entry.
func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) { func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
var b *bytes.Buffer var b *bytes.Buffer
if entry.Buffer != nil { if entry.Buffer != nil {
@@ -25,33 +27,42 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
timestamp := entry.Time.Format("2006-01-02 15:04:05") timestamp := entry.Time.Format("2006-01-02 15:04:05")
var newLog string var newLog string
// Customize the log format to include timestamp, level, caller file/line, and message.
newLog = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, path.Base(entry.Caller.File), entry.Caller.Line, entry.Message) newLog = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, path.Base(entry.Caller.File), entry.Caller.Line, entry.Message)
b.WriteString(newLog) b.WriteString(newLog)
return b.Bytes(), nil return b.Bytes(), nil
} }
// init initializes the logger configuration.
func init() { func init() {
// Set logger output to standard output.
log.SetOutput(os.Stdout) log.SetOutput(os.Stdout)
// Enable reporting the caller function's file and line number.
log.SetReportCaller(true) log.SetReportCaller(true)
// Set the custom log formatter.
log.SetFormatter(&LogFormatter{}) log.SetFormatter(&LogFormatter{})
} }
// main is the entry point of the application.
func main() { func main() {
var login bool var login bool
var projectID string var projectID string
var configPath string var configPath string
// Define command-line flags.
flag.BoolVar(&login, "login", false, "Login Google Account") flag.BoolVar(&login, "login", false, "Login Google Account")
flag.StringVar(&projectID, "project_id", "", "Project ID") flag.StringVar(&projectID, "project_id", "", "Project ID")
flag.StringVar(&configPath, "config", "", "Configure File Path") flag.StringVar(&configPath, "config", "", "Configure File Path")
// Parse the command-line flags.
flag.Parse() flag.Parse()
var err error var err error
var cfg *config.Config var cfg *config.Config
var wd string var wd string
// Load configuration from the specified path or the default path.
if configPath != "" { if configPath != "" {
cfg, err = config.LoadConfig(configPath) cfg, err = config.LoadConfig(configPath)
} else { } else {
@@ -65,12 +76,14 @@ func main() {
log.Fatalf("failed to load config: %v", err) log.Fatalf("failed to load config: %v", err)
} }
// Set the log level based on the configuration.
if cfg.Debug { if cfg.Debug {
log.SetLevel(log.DebugLevel) log.SetLevel(log.DebugLevel)
} else { } else {
log.SetLevel(log.InfoLevel) log.SetLevel(log.InfoLevel)
} }
// Expand the tilde (~) in the auth directory path to the user's home directory.
if strings.HasPrefix(cfg.AuthDir, "~") { if strings.HasPrefix(cfg.AuthDir, "~") {
home, errUserHomeDir := os.UserHomeDir() home, errUserHomeDir := os.UserHomeDir()
if errUserHomeDir != nil { if errUserHomeDir != nil {
@@ -85,6 +98,7 @@ func main() {
} }
} }
// Either perform login or start the service based on the 'login' flag.
if login { if login {
cmd.DoLogin(cfg, projectID) cmd.DoLogin(cfg, projectID)
} else { } else {

View File

@@ -1,6 +1,6 @@
port: 8317 port: 8317
auth_dir: "~/.cli-proxy-api" auth_dir: "~/.cli-proxy-api"
debug: false debug: true
proxy-url: "" proxy-url: ""
api_keys: api_keys:
- "12345" - "12345"

View File

@@ -2,14 +2,12 @@ package api
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"github.com/luispater/CLIProxyAPI/internal/api/translator"
"github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"net/http" "net/http"
"strings"
"sync" "sync"
"time" "time"
@@ -21,13 +19,15 @@ var (
lastUsedClientIndex = 0 lastUsedClientIndex = 0
) )
// APIHandlers contains the handlers for API endpoints // APIHandlers contains the handlers for API endpoints.
// It holds a pool of clients to interact with the backend service.
type APIHandlers struct { type APIHandlers struct {
cliClients []*client.Client cliClients []*client.Client
debug bool debug bool
} }
// NewAPIHandlers creates a new API handlers instance // NewAPIHandlers creates a new API handlers instance.
// It takes a slice of clients and a debug flag as input.
func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers { func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers {
return &APIHandlers{ return &APIHandlers{
cliClients: cliClients, cliClients: cliClients,
@@ -35,6 +35,8 @@ func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers {
} }
} }
// Models handles the /v1/models endpoint.
// It returns a hardcoded list of available AI models.
func (h *APIHandlers) Models(c *gin.Context) { func (h *APIHandlers) Models(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"data": []map[string]any{ "data": []map[string]any{
@@ -162,15 +164,23 @@ func (h *APIHandlers) Models(c *gin.Context) {
}) })
} }
// ChatCompletions handles the /v1/chat/completions endpoint // ChatCompletions handles the /v1/chat/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler.
func (h *APIHandlers) ChatCompletions(c *gin.Context) { func (h *APIHandlers) ChatCompletions(c *gin.Context) {
rawJson, err := c.GetRawData() rawJson, err := c.GetRawData()
// If data retrieval fails, return 400 error // If data retrieval fails, return a 400 Bad Request error.
if err != nil { if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request: %v", err), "code": 400}) c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return return
} }
// Check if the client requested a streaming response.
streamResult := gjson.GetBytes(rawJson, "stream") streamResult := gjson.GetBytes(rawJson, "stream")
if streamResult.Type == gjson.True { if streamResult.Type == gjson.True {
h.handleStreamingResponse(c, rawJson) h.handleStreamingResponse(c, rawJson)
@@ -179,184 +189,9 @@ func (h *APIHandlers) ChatCompletions(c *gin.Context) {
} }
} }
func (h *APIHandlers) prepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) { // handleNonStreamingResponse handles non-streaming chat completion responses.
// log.Debug(string(rawJson)) // It selects a client from the pool, sends the request, and aggregates the response
modelName := "gemini-2.5-pro" // before sending it back to the client.
modelResult := gjson.GetBytes(rawJson, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}
contents := make([]client.Content, 0)
messagesResult := gjson.GetBytes(rawJson, "messages")
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
contentResult := messageResult.Get("content")
if roleResult.Type == gjson.String {
if roleResult.String() == "system" {
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsObject() {
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentResult.Get("text")
if contentTextResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
}
}
}
} else if roleResult.String() == "user" {
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsObject() {
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentResult.Get("text")
if contentTextResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
}
}
} else if contentResult.IsArray() {
contentItemResults := contentResult.Array()
parts := make([]client.Part, 0)
for j := 0; j < len(contentItemResults); j++ {
contentItemResult := contentItemResults[j]
contentTypeResult := contentItemResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentItemResult.Get("text")
if contentTextResult.Type == gjson.String {
parts = append(parts, client.Part{Text: contentTextResult.String()})
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image_url" {
imageURLResult := contentItemResult.Get("image_url.url")
if imageURLResult.Type == gjson.String {
imageURL := imageURLResult.String()
if len(imageURL) > 5 {
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
if len(imageURLs) == 2 {
if len(imageURLs[1]) > 7 {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: imageURLs[0],
Data: imageURLs[1][7:],
}})
}
}
}
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "file" {
filenameResult := contentItemResult.Get("file.filename")
fileDataResult := contentItemResult.Get("file.file_data")
if filenameResult.Type == gjson.String && fileDataResult.Type == gjson.String {
filename := filenameResult.String()
splitFilename := strings.Split(filename, ".")
ext := splitFilename[len(splitFilename)-1]
mimeType, ok := MimeTypes[ext]
if !ok {
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
continue
}
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: mimeType,
Data: fileDataResult.String(),
}})
}
}
}
contents = append(contents, client.Content{Role: "user", Parts: parts})
}
} else if roleResult.String() == "assistant" {
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsObject() {
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentResult.Get("text")
if contentTextResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
}
}
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
toolCallsResult := messageResult.Get("tool_calls")
if toolCallsResult.IsArray() {
tcsResult := toolCallsResult.Array()
for j := 0; j < len(tcsResult); j++ {
tcResult := tcsResult[j]
functionNameResult := tcResult.Get("function.name")
functionArguments := tcResult.Get("function.arguments")
if functionNameResult.Exists() && functionNameResult.Type == gjson.String && functionArguments.Exists() && functionArguments.Type == gjson.String {
var args map[string]any
err := json.Unmarshal([]byte(functionArguments.String()), &args)
if err == nil {
contents = append(contents, client.Content{
Role: "model", Parts: []client.Part{
{
FunctionCall: &client.FunctionCall{
Name: functionNameResult.String(),
Args: args,
},
},
},
})
}
}
}
}
}
} else if roleResult.String() == "tool" {
toolCallIDResult := messageResult.Get("tool_call_id")
if toolCallIDResult.Exists() && toolCallIDResult.Type == gjson.String {
if contentResult.Type == gjson.String {
functionResponse := client.FunctionResponse{Name: toolCallIDResult.String(), Response: map[string]interface{}{"result": contentResult.String()}}
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
} else if contentResult.IsObject() {
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
contentTextResult := contentResult.Get("text")
if contentTextResult.Type == gjson.String {
functionResponse := client.FunctionResponse{Name: toolCallIDResult.String(), Response: map[string]interface{}{"result": contentResult.String()}}
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
}
}
}
}
}
}
}
}
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJson, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolTypeResult := toolsResults[i].Get("type")
if toolTypeResult.Type != gjson.String || toolTypeResult.String() != "function" {
continue
}
functionTypeResult := toolsResults[i].Get("function")
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
var functionDeclaration any
err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration)
if err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
return modelName, contents, tools
}
// handleNonStreamingResponse handles non-streaming responses
func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) { func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "application/json") c.Header("Content-Type", "application/json")
@@ -372,7 +207,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
return return
} }
modelName, contents, tools := h.prepareRequest(rawJson) modelName, contents, tools := translator.PrepareRequest(rawJson)
cliCtx, cliCancel := context.WithCancel(context.Background()) cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client var cliClient *client.Client
defer func() { defer func() {
@@ -425,19 +260,13 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
cliCancel() cliCancel()
return return
} else { } else {
jsonTemplate = h.convertCliToOpenAINonStream(jsonTemplate, chunk) jsonTemplate = translator.ConvertCliToOpenAINonStream(jsonTemplate, chunk)
} }
case err, okError := <-errChan: case err, okError := <-errChan:
if okError { if okError {
c.Status(err.StatusCode) c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error()) _, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush() flusher.Flush()
// c.JSON(http.StatusInternalServerError, ErrorResponse{
// Error: ErrorDetail{
// Message: err.Error(),
// Type: "server_error",
// },
// })
cliCancel() cliCancel()
return return
} }
@@ -455,7 +284,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
c.Header("Connection", "keep-alive") c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Origin", "*")
// Handle streaming manually // Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher) flusher, ok := c.Writer.(http.Flusher)
if !ok { if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{ c.JSON(http.StatusInternalServerError, ErrorResponse{
@@ -466,28 +295,33 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
}) })
return return
} }
modelName, contents, tools := h.prepareRequest(rawJson)
// Prepare the request for the backend client.
modelName, contents, tools := translator.PrepareRequest(rawJson)
cliCtx, cliCancel := context.WithCancel(context.Background()) cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client var cliClient *client.Client
defer func() { defer func() {
// Ensure the client's mutex is unlocked on function exit.
if cliClient != nil { if cliClient != nil {
cliClient.RequestMutex.Unlock() cliClient.RequestMutex.Unlock()
} }
}() }()
// Lock the mutex to update the last used page index // Use a round-robin approach to select the next available client.
// This distributes the load among the available clients.
mutex.Lock() mutex.Lock()
startIndex := lastUsedClientIndex startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients) currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex lastUsedClientIndex = currentIndex
mutex.Unlock() mutex.Unlock()
// Reorder the pages to start from the last used index // Reorder the clients to start from the next client in the rotation.
reorderedPages := make([]*client.Client, len(h.cliClients)) reorderedPages := make([]*client.Client, len(h.cliClients))
for i := 0; i < len(h.cliClients); i++ { for i := 0; i < len(h.cliClients); i++ {
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)] reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
} }
// Attempt to lock a client for the request.
locked := false locked := false
for i := 0; i < len(reorderedPages); i++ { for i := 0; i < len(reorderedPages); i++ {
cliClient = reorderedPages[i] cliClient = reorderedPages[i]
@@ -496,235 +330,52 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
break break
} }
} }
// If no client is available, block and wait for the first client.
if !locked { if !locked {
cliClient = h.cliClients[0] cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock() cliClient.RequestMutex.Lock()
} }
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools) respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
for { for {
select { select {
// Handle client disconnection.
case <-c.Request.Context().Done(): case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" { if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err()) log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() cliCancel() // Cancel the backend request.
return return
} }
// Process incoming response chunks.
case chunk, okStream := <-respChan: case chunk, okStream := <-respChan:
if !okStream { if !okStream {
// Stream is closed, send the final [DONE] message.
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n") _, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush() flusher.Flush()
cliCancel() cliCancel()
return return
} else { } else {
openAIFormat := h.convertCliToOpenAI(chunk) // Convert the chunk to OpenAI format and send it to the client.
openAIFormat := translator.ConvertCliToOpenAI(chunk)
if openAIFormat != "" { if openAIFormat != "" {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat) _, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
flusher.Flush() flusher.Flush()
} }
} }
// Handle errors from the backend.
case err, okError := <-errChan: case err, okError := <-errChan:
if okError { if okError {
c.Status(err.StatusCode) c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error()) _, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush() flusher.Flush()
// c.JSON(http.StatusInternalServerError, ErrorResponse{
// Error: ErrorDetail{
// Message: err.Error(),
// Type: "server_error",
// },
// })
cliCancel() cliCancel()
return return
} }
// Send a keep-alive signal to the client.
case <-time.After(500 * time.Millisecond): case <-time.After(500 * time.Millisecond):
_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n")) _, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
flusher.Flush() flusher.Flush()
} }
} }
} }
func (h *APIHandlers) convertCliToOpenAI(rawJson []byte) string {
// log.Debugf(string(rawJson))
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion")
if modelVersionResult.Exists() && modelVersionResult.Type == gjson.String {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
createTimeResult := gjson.GetBytes(rawJson, "response.createTime")
if createTimeResult.Exists() && createTimeResult.Type == gjson.String {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
var unixTimestamp int64
if err == nil {
unixTimestamp = t.Unix()
} else {
unixTimestamp = time.Now().Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
}
responseIdResult := gjson.GetBytes(rawJson, "response.responseId")
if responseIdResult.Exists() && responseIdResult.Type == gjson.String {
template, _ = sjson.Set(template, "id", responseIdResult.String())
}
finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason")
if finishReasonResult.Exists() && finishReasonResult.Type == gjson.String {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
candidatesTokenCountResult := usageResult.Get("candidatesTokenCount")
if candidatesTokenCountResult.Exists() && candidatesTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
totalTokenCountResult := usageResult.Get("totalTokenCount")
if totalTokenCountResult.Exists() && totalTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
thoughtsTokenCountResult := usageResult.Get("thoughtsTokenCount")
promptTokenCountResult := usageResult.Get("promptTokenCount")
if promptTokenCountResult.Exists() && promptTokenCountResult.Type == gjson.Number {
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int()+thoughtsTokenCountResult.Int())
} else {
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int())
}
}
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCountResult.Int())
}
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() && partTextResult.Type == gjson.String {
partThoughtResult := partResult.Get("thought")
if partThoughtResult.Exists() && partThoughtResult.Type == gjson.True {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]`
fcNameResult := functionCallResult.Get("name")
if fcNameResult.Exists() && fcNameResult.Type == gjson.String {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcNameResult.String())
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcNameResult.String())
}
fcArgsResult := functionCallResult.Get("args")
if fcArgsResult.Exists() && fcArgsResult.IsObject() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate)
} else {
return ""
}
return template
}
func (h *APIHandlers) convertCliToOpenAINonStream(template string, rawJson []byte) string {
modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion")
if modelVersionResult.Exists() && modelVersionResult.Type == gjson.String {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
createTimeResult := gjson.GetBytes(rawJson, "response.createTime")
if createTimeResult.Exists() && createTimeResult.Type == gjson.String {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
var unixTimestamp int64
if err == nil {
unixTimestamp = t.Unix()
} else {
unixTimestamp = time.Now().Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
}
responseIdResult := gjson.GetBytes(rawJson, "response.responseId")
if responseIdResult.Exists() && responseIdResult.Type == gjson.String {
template, _ = sjson.Set(template, "id", responseIdResult.String())
}
finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason")
if finishReasonResult.Exists() && finishReasonResult.Type == gjson.String {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
candidatesTokenCountResult := usageResult.Get("candidatesTokenCount")
if candidatesTokenCountResult.Exists() && candidatesTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
totalTokenCountResult := usageResult.Get("totalTokenCount")
if totalTokenCountResult.Exists() && totalTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
thoughtsTokenCountResult := usageResult.Get("thoughtsTokenCount")
promptTokenCountResult := usageResult.Get("promptTokenCount")
if promptTokenCountResult.Exists() && promptTokenCountResult.Type == gjson.Number {
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int()+thoughtsTokenCountResult.Int())
} else {
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int())
}
}
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCountResult.Int())
}
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() && partTextResult.Type == gjson.String {
partThoughtResult := partResult.Get("thought")
if partThoughtResult.Exists() && partThoughtResult.Type == gjson.True {
reasoningContentResult := gjson.Get(template, "choices.0.message.reasoning_content")
if reasoningContentResult.Type == gjson.String {
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningContentResult.String()+partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String())
}
} else {
reasoningContentResult := gjson.Get(template, "choices.0.message.content")
if reasoningContentResult.Type == gjson.String {
template, _ = sjson.Set(template, "choices.0.message.content", reasoningContentResult.String()+partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String())
}
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
} else if functionCallResult.Exists() {
toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls")
if !toolCallsResult.Exists() || toolCallsResult.Type == gjson.Null {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
}
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcNameResult := functionCallResult.Get("name")
if fcNameResult.Exists() && fcNameResult.Type == gjson.String {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcNameResult.String())
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcNameResult.String())
}
fcArgsResult := functionCallResult.Get("args")
if fcArgsResult.Exists() && fcArgsResult.IsObject() {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
} else {
return ""
}
return template
}

View File

@@ -1,13 +1,18 @@
package api package api
// ErrorResponse represents an error response // ErrorResponse represents a standard error response format for the API.
// It contains a single ErrorDetail field.
type ErrorResponse struct { type ErrorResponse struct {
Error ErrorDetail `json:"error"` Error ErrorDetail `json:"error"`
} }
// ErrorDetail represents error details // ErrorDetail provides specific information about an error that occurred.
// It includes a human-readable message, an error type, and an optional error code.
type ErrorDetail struct { type ErrorDetail struct {
// A human-readable message providing more details about the error.
Message string `json:"message"` Message string `json:"message"`
Type string `json:"type"` // The type of error that occurred (e.g., "invalid_request_error").
Code string `json:"code,omitempty"` Type string `json:"type"`
// A short code identifying the error, if applicable.
Code string `json:"code,omitempty"`
} }

View File

@@ -11,7 +11,8 @@ import (
"strings" "strings"
) )
// Server represents the API server // Server represents the main API server.
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
type Server struct { type Server struct {
engine *gin.Engine engine *gin.Engine
server *http.Server server *http.Server
@@ -19,14 +20,18 @@ type Server struct {
cfg *ServerConfig cfg *ServerConfig
} }
// ServerConfig contains configuration for the API server // ServerConfig contains the configuration for the API server.
type ServerConfig struct { type ServerConfig struct {
Port string // Port is the port number the server will listen on.
Debug bool Port string
// Debug enables or disables debug mode for the server and Gin.
Debug bool
// ApiKeys is a list of valid API keys for authentication.
ApiKeys []string ApiKeys []string
} }
// NewServer creates a new API server instance // NewServer creates and initializes a new API server instance.
// It sets up the Gin engine, middleware, routes, and handlers.
func NewServer(config *ServerConfig, cliClients []*client.Client) *Server { func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
// Set gin mode // Set gin mode
if !config.Debug { if !config.Debug {
@@ -63,7 +68,8 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
return s return s
} }
// setupRoutes configures the API routes // setupRoutes configures the API routes for the server.
// It defines the endpoints and associates them with their respective handlers.
func (s *Server) setupRoutes() { func (s *Server) setupRoutes() {
// OpenAI compatible API routes // OpenAI compatible API routes
v1 := s.engine.Group("/v1") v1 := s.engine.Group("/v1")
@@ -86,11 +92,12 @@ func (s *Server) setupRoutes() {
}) })
} }
// Start starts the API server // Start begins listening for and serving HTTP requests.
// It's a blocking call and will only return on an unrecoverable error.
func (s *Server) Start() error { func (s *Server) Start() error {
log.Debugf("Starting API server on %s", s.server.Addr) log.Debugf("Starting API server on %s", s.server.Addr)
// Start the HTTP server // Start the HTTP server.
if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to start HTTP server: %v", err) return fmt.Errorf("failed to start HTTP server: %v", err)
} }
@@ -98,11 +105,12 @@ func (s *Server) Start() error {
return nil return nil
} }
// Stop gracefully stops the API server // Stop gracefully shuts down the API server without interrupting any
// active connections.
func (s *Server) Stop(ctx context.Context) error { func (s *Server) Stop(ctx context.Context) error {
log.Debug("Stopping API server...") log.Debug("Stopping API server...")
// Shutdown the HTTP server // Shutdown the HTTP server.
if err := s.server.Shutdown(ctx); err != nil { if err := s.server.Shutdown(ctx); err != nil {
return fmt.Errorf("failed to shutdown HTTP server: %v", err) return fmt.Errorf("failed to shutdown HTTP server: %v", err)
} }
@@ -111,7 +119,8 @@ func (s *Server) Stop(ctx context.Context) error {
return nil return nil
} }
// corsMiddleware adds CORS headers // corsMiddleware returns a Gin middleware handler that adds CORS headers
// to every response, allowing cross-origin requests.
func corsMiddleware() gin.HandlerFunc { func corsMiddleware() gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
c.Header("Access-Control-Allow-Origin", "*") c.Header("Access-Control-Allow-Origin", "*")
@@ -127,7 +136,8 @@ func corsMiddleware() gin.HandlerFunc {
} }
} }
// AuthMiddleware authenticates requests using API keys // AuthMiddleware returns a Gin middleware handler that authenticates requests
// using API keys. If no API keys are configured, it allows all requests.
func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc { func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc {
return func(c *gin.Context) { return func(c *gin.Context) {
if len(cfg.ApiKeys) == 0 { if len(cfg.ApiKeys) == 0 {

View File

@@ -1,5 +1,7 @@
package api package translator
// MimeTypes is a comprehensive map of file extensions to their corresponding MIME types.
// This is used to identify the type of file being uploaded or processed.
var MimeTypes = map[string]string{ var MimeTypes = map[string]string{
"ez": "application/andrew-inset", "ez": "application/andrew-inset",
"aw": "application/applixware", "aw": "application/applixware",

View File

@@ -0,0 +1,163 @@
package translator
import (
"encoding/json"
"strings"
"github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
)
// PrepareRequest translates a raw JSON request from an OpenAI-compatible format
// to the internal format expected by the backend client. It parses messages,
// roles, content types (text, image, file), and tool calls.
func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) {
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
modelName := "gemini-2.5-pro"
modelResult := gjson.GetBytes(rawJson, "model")
if modelResult.Type == gjson.String {
modelName = modelResult.String()
}
// Process the array of messages.
contents := make([]client.Content, 0)
messagesResult := gjson.GetBytes(rawJson, "messages")
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
contentResult := messageResult.Get("content")
if roleResult.Type != gjson.String {
continue
}
switch roleResult.String() {
// System messages are converted to a user message followed by a model's acknowledgment.
case "system":
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
} else if contentResult.IsObject() {
// Handle object-based system messages.
if contentResult.Get("type").String() == "text" {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}})
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
}
}
// User messages can contain simple text or a multi-part body.
case "user":
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
} else if contentResult.IsArray() {
// Handle multi-part user messages (text, images, files).
contentItemResults := contentResult.Array()
parts := make([]client.Part, 0)
for j := 0; j < len(contentItemResults); j++ {
contentItemResult := contentItemResults[j]
contentTypeResult := contentItemResult.Get("type")
switch contentTypeResult.String() {
case "text":
parts = append(parts, client.Part{Text: contentItemResult.Get("text").String()})
case "image_url":
// Parse data URI for images.
imageURL := contentItemResult.Get("image_url.url").String()
if len(imageURL) > 5 {
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
if len(imageURLs) == 2 && len(imageURLs[1]) > 7 {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: imageURLs[0],
Data: imageURLs[1][7:],
}})
}
}
case "file":
// Handle file attachments by determining MIME type from extension.
filename := contentItemResult.Get("file.filename").String()
fileData := contentItemResult.Get("file.file_data").String()
ext := ""
if split := strings.Split(filename, "."); len(split) > 1 {
ext = split[len(split)-1]
}
if mimeType, ok := MimeTypes[ext]; ok {
parts = append(parts, client.Part{InlineData: &client.InlineData{
MimeType: mimeType,
Data: fileData,
}})
} else {
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
}
}
}
contents = append(contents, client.Content{Role: "user", Parts: parts})
}
// Assistant messages can contain text or tool calls.
case "assistant":
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
// Handle tool calls made by the assistant.
toolCallsResult := messageResult.Get("tool_calls")
if toolCallsResult.IsArray() {
tcsResult := toolCallsResult.Array()
for j := 0; j < len(tcsResult); j++ {
tcResult := tcsResult[j]
functionName := tcResult.Get("function.name").String()
functionArgs := tcResult.Get("function.arguments").String()
var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
contents = append(contents, client.Content{
Role: "model", Parts: []client.Part{{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
}},
})
}
}
}
}
// Tool messages contain the output of a tool call.
case "tool":
toolCallID := messageResult.Get("tool_call_id").String()
if toolCallID != "" {
var responseData string
if contentResult.Type == gjson.String {
responseData = contentResult.String()
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
responseData = contentResult.Get("text").String()
}
functionResponse := client.FunctionResponse{Name: toolCallID, Response: map[string]interface{}{"result": responseData}}
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
}
}
}
}
// Translate the tool declarations from the request.
var tools []client.ToolDeclaration
toolsResult := gjson.GetBytes(rawJson, "tools")
if toolsResult.IsArray() {
tools = make([]client.ToolDeclaration, 1)
tools[0].FunctionDeclarations = make([]any, 0)
toolsResults := toolsResult.Array()
for i := 0; i < len(toolsResults); i++ {
toolResult := toolsResults[i]
if toolResult.Get("type").String() == "function" {
functionTypeResult := toolResult.Get("function")
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
var functionDeclaration any
if err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
}
}
}
}
} else {
tools = make([]client.ToolDeclaration, 0)
}
return modelName, contents, tools
}

View File

@@ -0,0 +1,169 @@
package translator
import (
"time"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// ConvertCliToOpenAI translates a single chunk of a streaming response from the
// backend client format to the OpenAI Server-Sent Events (SSE) format.
// It returns an empty string if the chunk contains no useful data.
func ConvertCliToOpenAI(rawJson []byte) string {
// Initialize the OpenAI SSE template.
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
// Extract and set the model version.
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
// Extract and set the creation timestamp.
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
unixTimestamp := time.Now().Unix()
if err == nil {
unixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
}
// Extract and set the response ID.
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
template, _ = sjson.Set(template, "id", responseIdResult.String())
}
// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
}
// Process the main content part of the response.
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() {
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcName)
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate)
} else {
// If no usable content is found, return an empty string.
return ""
}
return template
}
// ConvertCliToOpenAINonStream aggregates response chunks from the backend client
// into a single, non-streaming OpenAI-compatible JSON response.
func ConvertCliToOpenAINonStream(template string, rawJson []byte) string {
// Extract and set metadata fields that are typically set once per response.
if gjson.Get(template, "id").String() == "" {
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
template, _ = sjson.Set(template, "model", modelVersionResult.String())
}
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
unixTimestamp := time.Now().Unix()
if err == nil {
unixTimestamp = t.Unix()
}
template, _ = sjson.Set(template, "created", unixTimestamp)
}
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
template, _ = sjson.Set(template, "id", responseIdResult.String())
}
}
// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
}
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
}
// Process the main content part of the response.
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() {
// Append text content, distinguishing between regular content and reasoning.
if partResult.Get("thought").Bool() {
currentContent := gjson.Get(template, "choices.0.message.reasoning_content").String()
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", currentContent+partTextResult.String())
} else {
currentContent := gjson.Get(template, "choices.0.message.content").String()
template, _ = sjson.Set(template, "choices.0.message.content", currentContent+partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
} else if functionCallResult.Exists() {
// Append function call content to the tool_calls array.
if !gjson.Get(template, "choices.0.message.tool_calls").Exists() {
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
}
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcName)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
} else {
// If no usable content is found, return an empty string.
return ""
}
return template
}

View File

@@ -5,17 +5,18 @@ import (
"encoding/json" "encoding/json"
"errors" "errors"
"fmt" "fmt"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"github.com/tidwall/gjson"
"golang.org/x/net/proxy"
"io" "io"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"time" "time"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/skratchdot/open-golang/open"
"github.com/tidwall/gjson"
"golang.org/x/net/proxy"
"golang.org/x/oauth2" "golang.org/x/oauth2"
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
) )
@@ -33,76 +34,78 @@ var (
} }
) )
// GetAuthenticatedClient configures and returns an HTTP client with OAuth2 tokens. // GetAuthenticatedClient configures and returns an HTTP client ready for making authenticated API calls.
// It handles the entire flow: loading, refreshing, and fetching new tokens. // It manages the entire OAuth2 flow, including handling proxies, loading existing tokens,
// initiating a new web-based OAuth flow if necessary, and refreshing tokens.
func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) { func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) {
// Configure proxy settings for the HTTP client if a proxy URL is provided.
proxyURL, err := url.Parse(cfg.ProxyUrl) proxyURL, err := url.Parse(cfg.ProxyUrl)
if err == nil { if err == nil {
var transport *http.Transport
if proxyURL.Scheme == "socks5" { if proxyURL.Scheme == "socks5" {
// Handle SOCKS5 proxy.
username := proxyURL.User.Username() username := proxyURL.User.Username()
password, _ := proxyURL.User.Password() password, _ := proxyURL.User.Password()
auth := &proxy.Auth{ auth := &proxy.Auth{User: username, Password: password}
User: username,
Password: password,
}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct) dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
if errSOCKS5 != nil { if errSOCKS5 != nil {
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5) log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
} }
transport = &http.Transport{
transport := &http.Transport{ DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
return dialer.Dial(network, addr) return dialer.Dial(network, addr)
}, },
} }
proxyClient := &http.Client{
Transport: transport,
}
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
transport := &http.Transport{ // Handle HTTP/HTTPS proxy.
Proxy: http.ProxyURL(proxyURL), transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
} }
proxyClient := &http.Client{
Transport: transport, if transport != nil {
} proxyClient := &http.Client{Transport: transport}
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient) ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
} }
} }
// Configure the OAuth2 client.
conf := &oauth2.Config{ conf := &oauth2.Config{
ClientID: oauthClientID, ClientID: oauthClientID,
ClientSecret: oauthClientSecret, ClientSecret: oauthClientSecret,
RedirectURL: "http://localhost:8085/oauth2callback", // Placeholder, will be updated RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
Scopes: oauthScopes, Scopes: oauthScopes,
Endpoint: google.Endpoint, Endpoint: google.Endpoint,
} }
var token *oauth2.Token var token *oauth2.Token
// If no token is found in storage, initiate the web-based OAuth flow.
if ts.Token == nil { if ts.Token == nil {
log.Info("Could not load token from file, starting OAuth flow.") log.Info("Could not load token from file, starting OAuth flow.")
token, err = getTokenFromWeb(ctx, conf) token, err = getTokenFromWeb(ctx, conf)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get token from web: %w", err) return nil, fmt.Errorf("failed to get token from web: %w", err)
} }
newTs, errSaveTokenToFile := createTokenStorage(ctx, conf, token, ts.ProjectID) // After getting a new token, create a new token storage object with user info.
if errSaveTokenToFile != nil { newTs, errCreateTokenStorage := createTokenStorage(ctx, conf, token, ts.ProjectID)
log.Errorf("Warning: failed to save token to file: %v", err) if errCreateTokenStorage != nil {
return nil, errSaveTokenToFile log.Errorf("Warning: failed to create token storage: %v", errCreateTokenStorage)
return nil, errCreateTokenStorage
} }
*ts = *newTs *ts = *newTs
} }
// Unmarshal the stored token into an oauth2.Token object.
tsToken, _ := json.Marshal(ts.Token) tsToken, _ := json.Marshal(ts.Token)
if err = json.Unmarshal(tsToken, &token); err != nil { if err = json.Unmarshal(tsToken, &token); err != nil {
return nil, err return nil, fmt.Errorf("failed to unmarshal token: %w", err)
} }
// Return an HTTP client that automatically handles token refreshing.
return conf.Client(ctx, token), nil return conf.Client(ctx, token), nil
} }
// createTokenStorage creates a token storage. // createTokenStorage creates a new TokenStorage object. It fetches the user's email
// using the provided token and populates the storage structure.
func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*TokenStorage, error) { func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*TokenStorage, error) {
httpClient := config.Client(ctx, token) httpClient := config.Client(ctx, token)
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
@@ -117,7 +120,9 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth
return nil, fmt.Errorf("failed to execute request: %w", err) return nil, fmt.Errorf("failed to execute request: %w", err)
} }
defer func() { defer func() {
_ = resp.Body.Close() if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}() }()
bodyBytes, _ := io.ReadAll(resp.Body) bodyBytes, _ := io.ReadAll(resp.Body)
@@ -154,7 +159,10 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth
return &ts, nil return &ts, nil
} }
// getTokenFromWeb starts a local server to handle the OAuth2 flow. // getTokenFromWeb initiates the web-based OAuth2 authorization flow.
// It starts a local HTTP server to listen for the callback from Google's auth server,
// opens the user's browser to the authorization URL, and exchanges the received
// authorization code for an access token.
func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) { func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) {
// Use a channel to pass the authorization code from the HTTP handler to the main function. // Use a channel to pass the authorization code from the HTTP handler to the main function.
codeChan := make(chan string) codeChan := make(chan string)

View File

@@ -1,9 +1,17 @@
package auth package auth
// TokenStorage defines the structure for storing OAuth2 token information,
// along with associated user and project details. This data is typically
// serialized to a JSON file for persistence.
type TokenStorage struct { type TokenStorage struct {
Token any `json:"token"` // Token holds the raw OAuth2 token data, including access and refresh tokens.
Token any `json:"token"`
// ProjectID is the Google Cloud Project ID associated with this token.
ProjectID string `json:"project_id"` ProjectID string `json:"project_id"`
Email string `json:"email"` // Email is the email address of the authenticated user.
Auto bool `json:"auto"` Email string `json:"email"`
Checked bool `json:"checked"` // Auto indicates if the project ID was automatically selected.
Auto bool `json:"auto"`
// Checked indicates if the associated Cloud AI API has been verified as enabled.
Checked bool `json:"checked"`
} }

View File

@@ -6,12 +6,6 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/oauth2"
"io" "io"
"net/http" "net/http"
"os" "os"
@@ -20,6 +14,13 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/oauth2"
) )
const ( const (
@@ -194,7 +195,9 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
return fmt.Errorf("failed to execute request: %w", err) return fmt.Errorf("failed to execute request: %w", err)
} }
defer func() { defer func() {
_ = resp.Body.Close() if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}() }()
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
@@ -253,7 +256,9 @@ func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body int
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
defer func() { defer func() {
_ = resp.Body.Close() if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}() }()
bodyBytes, _ := io.ReadAll(resp.Body) bodyBytes, _ := io.ReadAll(resp.Body)
@@ -355,6 +360,9 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
return dataChan, errChan return dataChan, errChan
} }
// CheckCloudAPIIsEnabled sends a simple test request to the API to verify
// that the Cloud AI API is enabled for the user's project. It provides
// an activation URL if the API is disabled.
func (c *Client) CheckCloudAPIIsEnabled() (bool, error) { func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer func() { defer func() {
@@ -363,79 +371,78 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
}() }()
c.RequestMutex.Lock() c.RequestMutex.Lock()
requestBody := `{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}` // A simple request to test the API endpoint.
requestBody = fmt.Sprintf(requestBody, c.tokenStorage.ProjectID) requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID)
// log.Debug(requestBody)
stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", []byte(requestBody)) stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", []byte(requestBody))
if err != nil { if err != nil {
// If a 403 Forbidden error occurs, it likely means the API is not enabled.
if err.StatusCode == 403 { if err.StatusCode == 403 {
errJson := err.Error.Error() errJson := err.Error.Error()
codeResult := gjson.Get(errJson, "error.code") // Check for a specific error code and extract the activation URL.
if codeResult.Exists() && codeResult.Type == gjson.Number { if gjson.Get(errJson, "error.code").Int() == 403 {
if codeResult.Int() == 403 { activationUrl := gjson.Get(errJson, "error.details.0.metadata.activationUrl").String()
activationUrlResult := gjson.Get(errJson, "error.details.0.metadata.activationUrl") if activationUrl != "" {
if activationUrlResult.Exists() { log.Warnf(
log.Warnf( "\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s",
"\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s", activationUrl,
activationUrlResult.String(), os.Args[0],
os.Args[0], c.tokenStorage.ProjectID,
c.tokenStorage.ProjectID, )
)
}
} }
} }
return false, nil return false, nil
} }
return false, err.Error return false, err.Error
} }
defer func() {
_ = stream.Close()
}()
// We only need to know if the request was successful, so we can drain the stream.
scanner := bufio.NewScanner(stream) scanner := bufio.NewScanner(stream)
for scanner.Scan() { for scanner.Scan() {
line := scanner.Text() // Do nothing, just consume the stream.
if !strings.HasPrefix(line, "data: ") {
continue
}
} }
if scannerErr := scanner.Err(); scannerErr != nil { return scanner.Err() == nil, scanner.Err()
_ = stream.Close()
} else {
_ = stream.Close()
}
return true, nil
} }
// GetProjectList fetches a list of Google Cloud projects accessible by the user.
func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) { func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) {
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token() token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
if err != nil {
return nil, fmt.Errorf("failed to get token: %w", err)
}
req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil) req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get project list: %v", err) return nil, fmt.Errorf("could not create project list request: %v", err)
} }
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken)) req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
resp, err := c.httpClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to execute request: %w", err) return nil, fmt.Errorf("failed to execute project list request: %w", err)
} }
defer func() { defer func() {
_ = resp.Body.Close() _ = resp.Body.Close()
}() }()
bodyBytes, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
return nil, fmt.Errorf("get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) bodyBytes, _ := io.ReadAll(resp.Body)
return nil, fmt.Errorf("project list request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
} }
var project GCPProject var project GCPProject
err = json.Unmarshal(bodyBytes, &project) if err = json.NewDecoder(resp.Body).Decode(&project); err != nil {
if err != nil {
return nil, fmt.Errorf("failed to unmarshal project list: %w", err) return nil, fmt.Errorf("failed to unmarshal project list: %w", err)
} }
return &project, nil return &project, nil
} }
// SaveTokenToFile serializes the client's current token storage to a JSON file.
// The filename is constructed from the user's email and project ID.
func (c *Client) SaveTokenToFile() error { func (c *Client) SaveTokenToFile() error {
if err := os.MkdirAll(c.cfg.AuthDir, 0700); err != nil { if err := os.MkdirAll(c.cfg.AuthDir, 0700); err != nil {
return fmt.Errorf("failed to create directory: %v", err) return fmt.Errorf("failed to create directory: %v", err)
@@ -457,7 +464,8 @@ func (c *Client) SaveTokenToFile() error {
return nil return nil
} }
// getClientMetadata returns metadata about the client environment. // getClientMetadata returns a map of metadata about the client environment,
// such as IDE type, platform, and plugin version.
func getClientMetadata() map[string]string { func getClientMetadata() map[string]string {
return map[string]string{ return map[string]string{
"ideType": "IDE_UNSPECIFIED", "ideType": "IDE_UNSPECIFIED",
@@ -467,7 +475,8 @@ func getClientMetadata() map[string]string {
} }
} }
// getClientMetadataString returns the metadata as a comma-separated string. // getClientMetadataString returns the client metadata as a single,
// comma-separated string, which is required for the 'Client-Metadata' header.
func getClientMetadataString() string { func getClientMetadataString() string {
md := getClientMetadata() md := getClientMetadata()
parts := make([]string, 0, len(md)) parts := make([]string, 0, len(md))
@@ -477,11 +486,13 @@ func getClientMetadataString() string {
return strings.Join(parts, ",") return strings.Join(parts, ",")
} }
// getUserAgent constructs the User-Agent string for HTTP requests.
func getUserAgent() string { func getUserAgent() string {
return fmt.Sprintf(fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH)) return fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH)
} }
// getPlatform returns the OS and architecture in the format expected by the API. // getPlatform determines the operating system and architecture and formats
// it into a string expected by the backend API.
func getPlatform() string { func getPlatform() string {
goOS := runtime.GOOS goOS := runtime.GOOS
arch := runtime.GOARCH arch := runtime.GOARCH

View File

@@ -2,17 +2,23 @@ package client
import "time" import "time"
// ErrorMessage encapsulates an error with an associated HTTP status code.
type ErrorMessage struct { type ErrorMessage struct {
StatusCode int StatusCode int
Error error Error error
} }
// GCPProject represents the response structure for a Google Cloud project list request.
type GCPProject struct { type GCPProject struct {
Projects []GCPProjectProjects `json:"projects"` Projects []GCPProjectProjects `json:"projects"`
} }
// GCPProjectLabels defines the labels associated with a GCP project.
type GCPProjectLabels struct { type GCPProjectLabels struct {
GenerativeLanguage string `json:"generative-language"` GenerativeLanguage string `json:"generative-language"`
} }
// GCPProjectProjects contains details about a single Google Cloud project.
type GCPProjectProjects struct { type GCPProjectProjects struct {
ProjectNumber string `json:"projectNumber"` ProjectNumber string `json:"projectNumber"`
ProjectID string `json:"projectId"` ProjectID string `json:"projectId"`
@@ -22,12 +28,14 @@ type GCPProjectProjects struct {
CreateTime time.Time `json:"createTime"` CreateTime time.Time `json:"createTime"`
} }
// Content represents a single message in a conversation, with a role and parts.
type Content struct { type Content struct {
Role string `json:"role"` Role string `json:"role"`
Parts []Part `json:"parts"` Parts []Part `json:"parts"`
} }
// Part represents a single part of a message's content. // Part represents a distinct piece of content within a message, which can be
// text, inline data (like an image), a function call, or a function response.
type Part struct { type Part struct {
Text string `json:"text,omitempty"` Text string `json:"text,omitempty"`
InlineData *InlineData `json:"inlineData,omitempty"` InlineData *InlineData `json:"inlineData,omitempty"`
@@ -35,46 +43,48 @@ type Part struct {
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"` FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
} }
// InlineData represents base64-encoded data with its MIME type.
type InlineData struct { type InlineData struct {
MimeType string `json:"mime_type,omitempty"` MimeType string `json:"mime_type,omitempty"`
Data string `json:"data,omitempty"` Data string `json:"data,omitempty"`
} }
// FunctionCall represents a tool call requested by the model. // FunctionCall represents a tool call requested by the model, including the
// function name and its arguments.
type FunctionCall struct { type FunctionCall struct {
Name string `json:"name"` Name string `json:"name"`
Args map[string]interface{} `json:"args"` Args map[string]interface{} `json:"args"`
} }
// FunctionResponse represents the result of a tool execution. // FunctionResponse represents the result of a tool execution, sent back to the model.
type FunctionResponse struct { type FunctionResponse struct {
Name string `json:"name"` Name string `json:"name"`
Response map[string]interface{} `json:"response"` Response map[string]interface{} `json:"response"`
} }
// GenerateContentRequest is the request payload for the streamGenerateContent endpoint. // GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint.
type GenerateContentRequest struct { type GenerateContentRequest struct {
Contents []Content `json:"contents"` Contents []Content `json:"contents"`
Tools []ToolDeclaration `json:"tools,omitempty"` Tools []ToolDeclaration `json:"tools,omitempty"`
GenerationConfig `json:"generationConfig"` GenerationConfig `json:"generationConfig"`
} }
// GenerationConfig defines model generation parameters. // GenerationConfig defines parameters that control the model's generation behavior.
type GenerationConfig struct { type GenerationConfig struct {
ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"` ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"`
Temperature float64 `json:"temperature,omitempty"` Temperature float64 `json:"temperature,omitempty"`
TopP float64 `json:"topP,omitempty"` TopP float64 `json:"topP,omitempty"`
TopK float64 `json:"topK,omitempty"` TopK float64 `json:"topK,omitempty"`
// Temperature, TopP, TopK, etc. can be added here.
} }
// GenerationConfigThinkingConfig specifies configuration for the model's "thinking" process.
type GenerationConfigThinkingConfig struct { type GenerationConfigThinkingConfig struct {
// IncludeThoughts determines whether the model should output its reasoning process.
IncludeThoughts bool `json:"include_thoughts,omitempty"` IncludeThoughts bool `json:"include_thoughts,omitempty"`
} }
// ToolDeclaration is the structure for declaring tools to the API. // ToolDeclaration defines the structure for declaring tools (like functions)
// For now, we'll assume a simple structure. A more complete implementation // that the model can call.
// would mirror the OpenAPI schema definition.
type ToolDeclaration struct { type ToolDeclaration struct {
FunctionDeclarations []interface{} `json:"functionDeclarations"` FunctionDeclarations []interface{} `json:"functionDeclarations"`
} }

View File

@@ -9,6 +9,9 @@ import (
"os" "os"
) )
// DoLogin handles the entire user login and setup process.
// It authenticates the user, sets up the user's project, checks API enablement,
// and saves the token for future use.
func DoLogin(cfg *config.Config, projectID string) { func DoLogin(cfg *config.Config, projectID string) {
var err error var err error
var ts auth.TokenStorage var ts auth.TokenStorage
@@ -16,9 +19,8 @@ func DoLogin(cfg *config.Config, projectID string) {
ts.ProjectID = projectID ts.ProjectID = projectID
} }
// 2. Initialize authenticated HTTP Client // Initialize an authenticated HTTP client. This will trigger the OAuth flow if necessary.
clientCtx := context.Background() clientCtx := context.Background()
log.Info("Initializing authentication...") log.Info("Initializing authentication...")
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg) httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
if errGetClient != nil { if errGetClient != nil {
@@ -27,51 +29,57 @@ func DoLogin(cfg *config.Config, projectID string) {
} }
log.Info("Authentication successful.") log.Info("Authentication successful.")
// 3. Initialize CLI Client // Initialize the API client.
cliClient := client.NewClient(httpClient, &ts, cfg) cliClient := client.NewClient(httpClient, &ts, cfg)
// Perform the user setup process.
err = cliClient.SetupUser(clientCtx, ts.Email, projectID) err = cliClient.SetupUser(clientCtx, ts.Email, projectID)
if err != nil { if err != nil {
// Handle the specific case where a project ID is required but not provided.
if err.Error() == "failed to start user onboarding, need define a project id" { if err.Error() == "failed to start user onboarding, need define a project id" {
log.Error("failed to start user onboarding") log.Error("Failed to start user onboarding: A project ID is required.")
// Fetch and display the user's available projects to help them choose one.
project, errGetProjectList := cliClient.GetProjectList(clientCtx) project, errGetProjectList := cliClient.GetProjectList(clientCtx)
if errGetProjectList != nil { if errGetProjectList != nil {
log.Fatalf("failed to complete user setup: %v", err) log.Fatalf("Failed to get project list: %v", err)
} else { } else {
log.Infof("Your account %s needs specify a project id.", ts.Email) log.Infof("Your account %s needs to specify a project ID.", ts.Email)
log.Info("========================================================================") log.Info("========================================================================")
for i := 0; i < len(project.Projects); i++ { for _, p := range project.Projects {
log.Infof("Project ID: %s", project.Projects[i].ProjectID) log.Infof("Project ID: %s", p.ProjectID)
log.Infof("Project Name: %s", project.Projects[i].Name) log.Infof("Project Name: %s", p.Name)
log.Info("========================================================================") log.Info("------------------------------------------------------------------------")
} }
log.Infof("Please run this command to login again:\n\n%s --login --project_id <project_id>\n", os.Args[0]) log.Infof("Please run this command to login again with a specific project:\n\n%s --login --project_id <project_id>\n", os.Args[0])
} }
} else { } else {
// Log as a warning because in some cases, the CLI might still be usable log.Fatalf("Failed to complete user setup: %v", err)
// or the user might want to retry setup later.
log.Fatalf("failed to complete user setup: %v", err)
} }
} else { return // Exit after handling the error.
auto := projectID == "" }
cliClient.SetIsAuto(auto)
if !cliClient.IsChecked() && !cliClient.IsAuto() { // If setup is successful, proceed to check API status and save the token.
isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled() auto := projectID == ""
if checkErr != nil { cliClient.SetIsAuto(auto)
log.Fatalf("failed to check cloud api is enabled: %v", checkErr)
return
}
cliClient.SetIsChecked(isChecked)
}
if !cliClient.IsChecked() && !cliClient.IsAuto() { // If the project was not automatically selected, check if the Cloud AI API is enabled.
if !cliClient.IsChecked() && !cliClient.IsAuto() {
isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled()
if checkErr != nil {
log.Fatalf("Failed to check if Cloud AI API is enabled: %v", checkErr)
return return
} }
cliClient.SetIsChecked(isChecked)
err = cliClient.SaveTokenToFile() // If the check fails (returns false), the CheckCloudAPIIsEnabled function
if err != nil { // will have already printed instructions, so we can just exit.
log.Fatal(err) if !isChecked {
return return
} }
} }
// Save the successfully obtained and verified token to a file.
err = cliClient.SaveTokenToFile()
if err != nil {
log.Fatalf("Failed to save token to file: %v", err)
}
} }

View File

@@ -18,20 +18,25 @@ import (
"time" "time"
) )
// StartService initializes and starts the main API proxy service.
// It loads all available authentication tokens, creates a pool of clients,
// starts the API server, and handles graceful shutdown signals.
func StartService(cfg *config.Config) { func StartService(cfg *config.Config) {
// Create API server configuration // Configure the API server based on the main application config.
apiConfig := &api.ServerConfig{ apiConfig := &api.ServerConfig{
Port: fmt.Sprintf("%d", cfg.Port), Port: fmt.Sprintf("%d", cfg.Port),
Debug: cfg.Debug, Debug: cfg.Debug,
ApiKeys: cfg.ApiKeys, ApiKeys: cfg.ApiKeys,
} }
// Create a pool of API clients, one for each token file found.
cliClients := make([]*client.Client, 0) cliClients := make([]*client.Client, 0)
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error { err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
if err != nil { if err != nil {
return err return err
} }
// Process only JSON files in the auth directory.
if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") { if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") {
log.Debugf("Loading token from: %s", path) log.Debugf("Loading token from: %s", path)
f, errOpen := os.Open(path) f, errOpen := os.Open(path)
@@ -42,58 +47,62 @@ func StartService(cfg *config.Config) {
_ = f.Close() _ = f.Close()
}() }()
// Decode the token storage file.
var ts auth.TokenStorage var ts auth.TokenStorage
if err = json.NewDecoder(f).Decode(&ts); err == nil { if err = json.NewDecoder(f).Decode(&ts); err == nil {
// 2. Initialize authenticated HTTP Client // For each valid token, create an authenticated client.
clientCtx := context.Background() clientCtx := context.Background()
log.Info("Initializing authentication for token...")
log.Info("Initializing authentication...")
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg) httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
if errGetClient != nil { if errGetClient != nil {
log.Fatalf("failed to get authenticated client: %v", errGetClient) // Log fatal will exit, but we return the error for completeness.
log.Fatalf("failed to get authenticated client for token %s: %v", path, errGetClient)
return errGetClient return errGetClient
} }
log.Info("Authentication successful.") log.Info("Authentication successful.")
// 3. Initialize CLI Client // Add the new client to the pool.
cliClient := client.NewClient(httpClient, &ts, cfg) cliClient := client.NewClient(httpClient, &ts, cfg)
cliClients = append(cliClients, cliClient) cliClients = append(cliClients, cliClient)
} }
} }
return nil return nil
}) })
if err != nil {
log.Fatalf("Error walking auth directory: %v", err)
}
// Create API server // Create and start the API server with the pool of clients.
apiServer := api.NewServer(apiConfig, cliClients) apiServer := api.NewServer(apiConfig, cliClients)
log.Infof("Starting API server on port %s", apiConfig.Port) log.Infof("Starting API server on port %s", apiConfig.Port)
if err = apiServer.Start(); err != nil { if err = apiServer.Start(); err != nil {
log.Fatalf("API server failed to start: %v", err) log.Fatalf("API server failed to start: %v", err)
return
} }
// Set up graceful shutdown // Set up a channel to listen for OS signals for graceful shutdown.
sigChan := make(chan os.Signal, 1) sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
// Main loop to wait for shutdown signal.
for { for {
select { select {
case <-sigChan: case <-sigChan:
log.Debugf("Received shutdown signal. Cleaning up...") log.Debugf("Received shutdown signal. Cleaning up...")
// Create shutdown context // Create a context with a timeout for the shutdown process.
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
_ = ctx // Mark ctx as used to avoid error, as apiServer.Stop(ctx) is commented out _ = cancel
// Stop API server // Stop the API server gracefully.
if err = apiServer.Stop(ctx); err != nil { if err = apiServer.Stop(ctx); err != nil {
log.Debugf("Error stopping API server: %v", err) log.Debugf("Error stopping API server: %v", err)
} }
cancel()
log.Debugf("Cleanup completed. Exiting...") log.Debugf("Cleanup completed. Exiting...")
os.Exit(0) os.Exit(0)
case <-time.After(5 * time.Second): case <-time.After(5 * time.Second):
// This case is currently empty and acts as a periodic check.
// It could be used for periodic tasks in the future.
} }
} }
} }

View File

@@ -6,33 +6,35 @@ import (
"os" "os"
) )
// Config represents the application's configuration // Config represents the application's configuration, loaded from a YAML file.
type Config struct { type Config struct {
Port int `yaml:"port"` // Port is the network port on which the API server will listen.
AuthDir string `yaml:"auth_dir"` Port int `yaml:"port"`
Debug bool `yaml:"debug"` // AuthDir is the directory where authentication token files are stored.
ProxyUrl string `yaml:"proxy-url"` AuthDir string `yaml:"auth_dir"`
ApiKeys []string `yaml:"api_keys"` // Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug"`
// ProxyUrl is the URL of an optional proxy server to use for outbound requests.
ProxyUrl string `yaml:"proxy-url"`
// ApiKeys is a list of keys for authenticating clients to this proxy server.
ApiKeys []string `yaml:"api_keys"`
} }
// / LoadConfig loads the configuration from the specified file // LoadConfig reads a YAML configuration file from the given path,
// unmarshals it into a Config struct, and returns it.
func LoadConfig(configFile string) (*Config, error) { func LoadConfig(configFile string) (*Config, error) {
// Read the configuration file // Read the entire configuration file into memory.
data, err := os.ReadFile(configFile) data, err := os.ReadFile(configFile)
// If reading the file fails
if err != nil { if err != nil {
// Return an error
return nil, fmt.Errorf("failed to read config file: %w", err) return nil, fmt.Errorf("failed to read config file: %w", err)
} }
// Parse the YAML data // Unmarshal the YAML data into the Config struct.
var config Config var config Config
// If parsing the YAML data fails
if err = yaml.Unmarshal(data, &config); err != nil { if err = yaml.Unmarshal(data, &config); err != nil {
// Return an error
return nil, fmt.Errorf("failed to parse config file: %w", err) return nil, fmt.Errorf("failed to parse config file: %w", err)
} }
// Return the configuration // Return the populated configuration struct.
return &config, nil return &config, nil
} }