mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
Numerous Comments Added and Extensive Optimization Performed using Roo-Code with CLIProxyAPI itself.
This commit is contained in:
@@ -12,9 +12,11 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// LogFormatter defines a custom log format for logrus.
|
||||||
type LogFormatter struct {
|
type LogFormatter struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Format renders a single log entry.
|
||||||
func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
||||||
var b *bytes.Buffer
|
var b *bytes.Buffer
|
||||||
if entry.Buffer != nil {
|
if entry.Buffer != nil {
|
||||||
@@ -25,33 +27,42 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
|||||||
|
|
||||||
timestamp := entry.Time.Format("2006-01-02 15:04:05")
|
timestamp := entry.Time.Format("2006-01-02 15:04:05")
|
||||||
var newLog string
|
var newLog string
|
||||||
|
// Customize the log format to include timestamp, level, caller file/line, and message.
|
||||||
newLog = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, path.Base(entry.Caller.File), entry.Caller.Line, entry.Message)
|
newLog = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, path.Base(entry.Caller.File), entry.Caller.Line, entry.Message)
|
||||||
|
|
||||||
b.WriteString(newLog)
|
b.WriteString(newLog)
|
||||||
return b.Bytes(), nil
|
return b.Bytes(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init initializes the logger configuration.
|
||||||
func init() {
|
func init() {
|
||||||
|
// Set logger output to standard output.
|
||||||
log.SetOutput(os.Stdout)
|
log.SetOutput(os.Stdout)
|
||||||
|
// Enable reporting the caller function's file and line number.
|
||||||
log.SetReportCaller(true)
|
log.SetReportCaller(true)
|
||||||
|
// Set the custom log formatter.
|
||||||
log.SetFormatter(&LogFormatter{})
|
log.SetFormatter(&LogFormatter{})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// main is the entry point of the application.
|
||||||
func main() {
|
func main() {
|
||||||
var login bool
|
var login bool
|
||||||
var projectID string
|
var projectID string
|
||||||
var configPath string
|
var configPath string
|
||||||
|
|
||||||
|
// Define command-line flags.
|
||||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||||
flag.StringVar(&projectID, "project_id", "", "Project ID")
|
flag.StringVar(&projectID, "project_id", "", "Project ID")
|
||||||
flag.StringVar(&configPath, "config", "", "Configure File Path")
|
flag.StringVar(&configPath, "config", "", "Configure File Path")
|
||||||
|
|
||||||
|
// Parse the command-line flags.
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
var cfg *config.Config
|
var cfg *config.Config
|
||||||
var wd string
|
var wd string
|
||||||
|
|
||||||
|
// Load configuration from the specified path or the default path.
|
||||||
if configPath != "" {
|
if configPath != "" {
|
||||||
cfg, err = config.LoadConfig(configPath)
|
cfg, err = config.LoadConfig(configPath)
|
||||||
} else {
|
} else {
|
||||||
@@ -65,12 +76,14 @@ func main() {
|
|||||||
log.Fatalf("failed to load config: %v", err)
|
log.Fatalf("failed to load config: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set the log level based on the configuration.
|
||||||
if cfg.Debug {
|
if cfg.Debug {
|
||||||
log.SetLevel(log.DebugLevel)
|
log.SetLevel(log.DebugLevel)
|
||||||
} else {
|
} else {
|
||||||
log.SetLevel(log.InfoLevel)
|
log.SetLevel(log.InfoLevel)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Expand the tilde (~) in the auth directory path to the user's home directory.
|
||||||
if strings.HasPrefix(cfg.AuthDir, "~") {
|
if strings.HasPrefix(cfg.AuthDir, "~") {
|
||||||
home, errUserHomeDir := os.UserHomeDir()
|
home, errUserHomeDir := os.UserHomeDir()
|
||||||
if errUserHomeDir != nil {
|
if errUserHomeDir != nil {
|
||||||
@@ -85,6 +98,7 @@ func main() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Either perform login or start the service based on the 'login' flag.
|
||||||
if login {
|
if login {
|
||||||
cmd.DoLogin(cfg, projectID)
|
cmd.DoLogin(cfg, projectID)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
port: 8317
|
port: 8317
|
||||||
auth_dir: "~/.cli-proxy-api"
|
auth_dir: "~/.cli-proxy-api"
|
||||||
debug: false
|
debug: true
|
||||||
proxy-url: ""
|
proxy-url: ""
|
||||||
api_keys:
|
api_keys:
|
||||||
- "12345"
|
- "12345"
|
||||||
|
|||||||
@@ -2,14 +2,12 @@ package api
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/api/translator"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/client"
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/tidwall/gjson"
|
"github.com/tidwall/gjson"
|
||||||
"github.com/tidwall/sjson"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -21,13 +19,15 @@ var (
|
|||||||
lastUsedClientIndex = 0
|
lastUsedClientIndex = 0
|
||||||
)
|
)
|
||||||
|
|
||||||
// APIHandlers contains the handlers for API endpoints
|
// APIHandlers contains the handlers for API endpoints.
|
||||||
|
// It holds a pool of clients to interact with the backend service.
|
||||||
type APIHandlers struct {
|
type APIHandlers struct {
|
||||||
cliClients []*client.Client
|
cliClients []*client.Client
|
||||||
debug bool
|
debug bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAPIHandlers creates a new API handlers instance
|
// NewAPIHandlers creates a new API handlers instance.
|
||||||
|
// It takes a slice of clients and a debug flag as input.
|
||||||
func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers {
|
func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers {
|
||||||
return &APIHandlers{
|
return &APIHandlers{
|
||||||
cliClients: cliClients,
|
cliClients: cliClients,
|
||||||
@@ -35,6 +35,8 @@ func NewAPIHandlers(cliClients []*client.Client, debug bool) *APIHandlers {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Models handles the /v1/models endpoint.
|
||||||
|
// It returns a hardcoded list of available AI models.
|
||||||
func (h *APIHandlers) Models(c *gin.Context) {
|
func (h *APIHandlers) Models(c *gin.Context) {
|
||||||
c.JSON(http.StatusOK, gin.H{
|
c.JSON(http.StatusOK, gin.H{
|
||||||
"data": []map[string]any{
|
"data": []map[string]any{
|
||||||
@@ -162,15 +164,23 @@ func (h *APIHandlers) Models(c *gin.Context) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// ChatCompletions handles the /v1/chat/completions endpoint
|
// ChatCompletions handles the /v1/chat/completions endpoint.
|
||||||
|
// It determines whether the request is for a streaming or non-streaming response
|
||||||
|
// and calls the appropriate handler.
|
||||||
func (h *APIHandlers) ChatCompletions(c *gin.Context) {
|
func (h *APIHandlers) ChatCompletions(c *gin.Context) {
|
||||||
rawJson, err := c.GetRawData()
|
rawJson, err := c.GetRawData()
|
||||||
// If data retrieval fails, return 400 error
|
// If data retrieval fails, return a 400 Bad Request error.
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("Invalid request: %v", err), "code": 400})
|
c.JSON(http.StatusBadRequest, ErrorResponse{
|
||||||
|
Error: ErrorDetail{
|
||||||
|
Message: fmt.Sprintf("Invalid request: %v", err),
|
||||||
|
Type: "invalid_request_error",
|
||||||
|
},
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if the client requested a streaming response.
|
||||||
streamResult := gjson.GetBytes(rawJson, "stream")
|
streamResult := gjson.GetBytes(rawJson, "stream")
|
||||||
if streamResult.Type == gjson.True {
|
if streamResult.Type == gjson.True {
|
||||||
h.handleStreamingResponse(c, rawJson)
|
h.handleStreamingResponse(c, rawJson)
|
||||||
@@ -179,184 +189,9 @@ func (h *APIHandlers) ChatCompletions(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *APIHandlers) prepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) {
|
// handleNonStreamingResponse handles non-streaming chat completion responses.
|
||||||
// log.Debug(string(rawJson))
|
// It selects a client from the pool, sends the request, and aggregates the response
|
||||||
modelName := "gemini-2.5-pro"
|
// before sending it back to the client.
|
||||||
modelResult := gjson.GetBytes(rawJson, "model")
|
|
||||||
if modelResult.Type == gjson.String {
|
|
||||||
modelName = modelResult.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
contents := make([]client.Content, 0)
|
|
||||||
messagesResult := gjson.GetBytes(rawJson, "messages")
|
|
||||||
if messagesResult.IsArray() {
|
|
||||||
messagesResults := messagesResult.Array()
|
|
||||||
for i := 0; i < len(messagesResults); i++ {
|
|
||||||
messageResult := messagesResults[i]
|
|
||||||
roleResult := messageResult.Get("role")
|
|
||||||
contentResult := messageResult.Get("content")
|
|
||||||
if roleResult.Type == gjson.String {
|
|
||||||
if roleResult.String() == "system" {
|
|
||||||
if contentResult.Type == gjson.String {
|
|
||||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
|
|
||||||
} else if contentResult.IsObject() {
|
|
||||||
contentTypeResult := contentResult.Get("type")
|
|
||||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
|
||||||
contentTextResult := contentResult.Get("text")
|
|
||||||
if contentTextResult.Type == gjson.String {
|
|
||||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
|
|
||||||
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if roleResult.String() == "user" {
|
|
||||||
if contentResult.Type == gjson.String {
|
|
||||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
|
|
||||||
} else if contentResult.IsObject() {
|
|
||||||
contentTypeResult := contentResult.Get("type")
|
|
||||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
|
||||||
contentTextResult := contentResult.Get("text")
|
|
||||||
if contentTextResult.Type == gjson.String {
|
|
||||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if contentResult.IsArray() {
|
|
||||||
contentItemResults := contentResult.Array()
|
|
||||||
parts := make([]client.Part, 0)
|
|
||||||
for j := 0; j < len(contentItemResults); j++ {
|
|
||||||
contentItemResult := contentItemResults[j]
|
|
||||||
contentTypeResult := contentItemResult.Get("type")
|
|
||||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
|
||||||
contentTextResult := contentItemResult.Get("text")
|
|
||||||
if contentTextResult.Type == gjson.String {
|
|
||||||
parts = append(parts, client.Part{Text: contentTextResult.String()})
|
|
||||||
}
|
|
||||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image_url" {
|
|
||||||
imageURLResult := contentItemResult.Get("image_url.url")
|
|
||||||
if imageURLResult.Type == gjson.String {
|
|
||||||
imageURL := imageURLResult.String()
|
|
||||||
if len(imageURL) > 5 {
|
|
||||||
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
|
|
||||||
if len(imageURLs) == 2 {
|
|
||||||
if len(imageURLs[1]) > 7 {
|
|
||||||
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
|
||||||
MimeType: imageURLs[0],
|
|
||||||
Data: imageURLs[1][7:],
|
|
||||||
}})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "file" {
|
|
||||||
filenameResult := contentItemResult.Get("file.filename")
|
|
||||||
fileDataResult := contentItemResult.Get("file.file_data")
|
|
||||||
if filenameResult.Type == gjson.String && fileDataResult.Type == gjson.String {
|
|
||||||
filename := filenameResult.String()
|
|
||||||
splitFilename := strings.Split(filename, ".")
|
|
||||||
ext := splitFilename[len(splitFilename)-1]
|
|
||||||
|
|
||||||
mimeType, ok := MimeTypes[ext]
|
|
||||||
if !ok {
|
|
||||||
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
|
||||||
MimeType: mimeType,
|
|
||||||
Data: fileDataResult.String(),
|
|
||||||
}})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
contents = append(contents, client.Content{Role: "user", Parts: parts})
|
|
||||||
}
|
|
||||||
} else if roleResult.String() == "assistant" {
|
|
||||||
if contentResult.Type == gjson.String {
|
|
||||||
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
|
|
||||||
} else if contentResult.IsObject() {
|
|
||||||
contentTypeResult := contentResult.Get("type")
|
|
||||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
|
||||||
contentTextResult := contentResult.Get("text")
|
|
||||||
if contentTextResult.Type == gjson.String {
|
|
||||||
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentTextResult.String()}}})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
|
|
||||||
toolCallsResult := messageResult.Get("tool_calls")
|
|
||||||
if toolCallsResult.IsArray() {
|
|
||||||
tcsResult := toolCallsResult.Array()
|
|
||||||
for j := 0; j < len(tcsResult); j++ {
|
|
||||||
tcResult := tcsResult[j]
|
|
||||||
functionNameResult := tcResult.Get("function.name")
|
|
||||||
functionArguments := tcResult.Get("function.arguments")
|
|
||||||
if functionNameResult.Exists() && functionNameResult.Type == gjson.String && functionArguments.Exists() && functionArguments.Type == gjson.String {
|
|
||||||
var args map[string]any
|
|
||||||
err := json.Unmarshal([]byte(functionArguments.String()), &args)
|
|
||||||
if err == nil {
|
|
||||||
contents = append(contents, client.Content{
|
|
||||||
Role: "model", Parts: []client.Part{
|
|
||||||
{
|
|
||||||
FunctionCall: &client.FunctionCall{
|
|
||||||
Name: functionNameResult.String(),
|
|
||||||
Args: args,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else if roleResult.String() == "tool" {
|
|
||||||
toolCallIDResult := messageResult.Get("tool_call_id")
|
|
||||||
if toolCallIDResult.Exists() && toolCallIDResult.Type == gjson.String {
|
|
||||||
if contentResult.Type == gjson.String {
|
|
||||||
functionResponse := client.FunctionResponse{Name: toolCallIDResult.String(), Response: map[string]interface{}{"result": contentResult.String()}}
|
|
||||||
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
|
|
||||||
} else if contentResult.IsObject() {
|
|
||||||
contentTypeResult := contentResult.Get("type")
|
|
||||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
|
||||||
contentTextResult := contentResult.Get("text")
|
|
||||||
if contentTextResult.Type == gjson.String {
|
|
||||||
functionResponse := client.FunctionResponse{Name: toolCallIDResult.String(), Response: map[string]interface{}{"result": contentResult.String()}}
|
|
||||||
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var tools []client.ToolDeclaration
|
|
||||||
toolsResult := gjson.GetBytes(rawJson, "tools")
|
|
||||||
if toolsResult.IsArray() {
|
|
||||||
tools = make([]client.ToolDeclaration, 1)
|
|
||||||
tools[0].FunctionDeclarations = make([]any, 0)
|
|
||||||
toolsResults := toolsResult.Array()
|
|
||||||
for i := 0; i < len(toolsResults); i++ {
|
|
||||||
toolTypeResult := toolsResults[i].Get("type")
|
|
||||||
if toolTypeResult.Type != gjson.String || toolTypeResult.String() != "function" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
functionTypeResult := toolsResults[i].Get("function")
|
|
||||||
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
|
|
||||||
var functionDeclaration any
|
|
||||||
err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration)
|
|
||||||
if err == nil {
|
|
||||||
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
tools = make([]client.ToolDeclaration, 0)
|
|
||||||
}
|
|
||||||
return modelName, contents, tools
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleNonStreamingResponse handles non-streaming responses
|
|
||||||
func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) {
|
func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) {
|
||||||
c.Header("Content-Type", "application/json")
|
c.Header("Content-Type", "application/json")
|
||||||
|
|
||||||
@@ -372,7 +207,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
modelName, contents, tools := h.prepareRequest(rawJson)
|
modelName, contents, tools := translator.PrepareRequest(rawJson)
|
||||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
var cliClient *client.Client
|
var cliClient *client.Client
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -425,19 +260,13 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
|
|||||||
cliCancel()
|
cliCancel()
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
jsonTemplate = h.convertCliToOpenAINonStream(jsonTemplate, chunk)
|
jsonTemplate = translator.ConvertCliToOpenAINonStream(jsonTemplate, chunk)
|
||||||
}
|
}
|
||||||
case err, okError := <-errChan:
|
case err, okError := <-errChan:
|
||||||
if okError {
|
if okError {
|
||||||
c.Status(err.StatusCode)
|
c.Status(err.StatusCode)
|
||||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
// c.JSON(http.StatusInternalServerError, ErrorResponse{
|
|
||||||
// Error: ErrorDetail{
|
|
||||||
// Message: err.Error(),
|
|
||||||
// Type: "server_error",
|
|
||||||
// },
|
|
||||||
// })
|
|
||||||
cliCancel()
|
cliCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -455,7 +284,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
|
|||||||
c.Header("Connection", "keep-alive")
|
c.Header("Connection", "keep-alive")
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
|
|
||||||
// Handle streaming manually
|
// Get the http.Flusher interface to manually flush the response.
|
||||||
flusher, ok := c.Writer.(http.Flusher)
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
if !ok {
|
if !ok {
|
||||||
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
c.JSON(http.StatusInternalServerError, ErrorResponse{
|
||||||
@@ -466,28 +295,33 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
|
|||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
modelName, contents, tools := h.prepareRequest(rawJson)
|
|
||||||
|
// Prepare the request for the backend client.
|
||||||
|
modelName, contents, tools := translator.PrepareRequest(rawJson)
|
||||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||||
var cliClient *client.Client
|
var cliClient *client.Client
|
||||||
defer func() {
|
defer func() {
|
||||||
|
// Ensure the client's mutex is unlocked on function exit.
|
||||||
if cliClient != nil {
|
if cliClient != nil {
|
||||||
cliClient.RequestMutex.Unlock()
|
cliClient.RequestMutex.Unlock()
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Lock the mutex to update the last used page index
|
// Use a round-robin approach to select the next available client.
|
||||||
|
// This distributes the load among the available clients.
|
||||||
mutex.Lock()
|
mutex.Lock()
|
||||||
startIndex := lastUsedClientIndex
|
startIndex := lastUsedClientIndex
|
||||||
currentIndex := (startIndex + 1) % len(h.cliClients)
|
currentIndex := (startIndex + 1) % len(h.cliClients)
|
||||||
lastUsedClientIndex = currentIndex
|
lastUsedClientIndex = currentIndex
|
||||||
mutex.Unlock()
|
mutex.Unlock()
|
||||||
|
|
||||||
// Reorder the pages to start from the last used index
|
// Reorder the clients to start from the next client in the rotation.
|
||||||
reorderedPages := make([]*client.Client, len(h.cliClients))
|
reorderedPages := make([]*client.Client, len(h.cliClients))
|
||||||
for i := 0; i < len(h.cliClients); i++ {
|
for i := 0; i < len(h.cliClients); i++ {
|
||||||
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
reorderedPages[i] = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Attempt to lock a client for the request.
|
||||||
locked := false
|
locked := false
|
||||||
for i := 0; i < len(reorderedPages); i++ {
|
for i := 0; i < len(reorderedPages); i++ {
|
||||||
cliClient = reorderedPages[i]
|
cliClient = reorderedPages[i]
|
||||||
@@ -496,235 +330,52 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// If no client is available, block and wait for the first client.
|
||||||
if !locked {
|
if !locked {
|
||||||
cliClient = h.cliClients[0]
|
cliClient = h.cliClients[0]
|
||||||
cliClient.RequestMutex.Lock()
|
cliClient.RequestMutex.Lock()
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||||
|
// Send the message and receive response chunks and errors via channels.
|
||||||
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
|
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
// Handle client disconnection.
|
||||||
case <-c.Request.Context().Done():
|
case <-c.Request.Context().Done():
|
||||||
if c.Request.Context().Err().Error() == "context canceled" {
|
if c.Request.Context().Err().Error() == "context canceled" {
|
||||||
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
|
||||||
cliCancel()
|
cliCancel() // Cancel the backend request.
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Process incoming response chunks.
|
||||||
case chunk, okStream := <-respChan:
|
case chunk, okStream := <-respChan:
|
||||||
if !okStream {
|
if !okStream {
|
||||||
|
// Stream is closed, send the final [DONE] message.
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
cliCancel()
|
cliCancel()
|
||||||
return
|
return
|
||||||
} else {
|
} else {
|
||||||
openAIFormat := h.convertCliToOpenAI(chunk)
|
// Convert the chunk to OpenAI format and send it to the client.
|
||||||
|
openAIFormat := translator.ConvertCliToOpenAI(chunk)
|
||||||
if openAIFormat != "" {
|
if openAIFormat != "" {
|
||||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
|
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Handle errors from the backend.
|
||||||
case err, okError := <-errChan:
|
case err, okError := <-errChan:
|
||||||
if okError {
|
if okError {
|
||||||
c.Status(err.StatusCode)
|
c.Status(err.StatusCode)
|
||||||
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
// c.JSON(http.StatusInternalServerError, ErrorResponse{
|
|
||||||
// Error: ErrorDetail{
|
|
||||||
// Message: err.Error(),
|
|
||||||
// Type: "server_error",
|
|
||||||
// },
|
|
||||||
// })
|
|
||||||
cliCancel()
|
cliCancel()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
// Send a keep-alive signal to the client.
|
||||||
case <-time.After(500 * time.Millisecond):
|
case <-time.After(500 * time.Millisecond):
|
||||||
_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
|
_, _ = c.Writer.Write([]byte(": CLI-PROXY-API PROCESSING\n\n"))
|
||||||
flusher.Flush()
|
flusher.Flush()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *APIHandlers) convertCliToOpenAI(rawJson []byte) string {
|
|
||||||
// log.Debugf(string(rawJson))
|
|
||||||
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
|
||||||
|
|
||||||
modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion")
|
|
||||||
if modelVersionResult.Exists() && modelVersionResult.Type == gjson.String {
|
|
||||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
createTimeResult := gjson.GetBytes(rawJson, "response.createTime")
|
|
||||||
if createTimeResult.Exists() && createTimeResult.Type == gjson.String {
|
|
||||||
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
|
||||||
var unixTimestamp int64
|
|
||||||
if err == nil {
|
|
||||||
unixTimestamp = t.Unix()
|
|
||||||
} else {
|
|
||||||
unixTimestamp = time.Now().Unix()
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
|
||||||
}
|
|
||||||
|
|
||||||
responseIdResult := gjson.GetBytes(rawJson, "response.responseId")
|
|
||||||
if responseIdResult.Exists() && responseIdResult.Type == gjson.String {
|
|
||||||
template, _ = sjson.Set(template, "id", responseIdResult.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason")
|
|
||||||
if finishReasonResult.Exists() && finishReasonResult.Type == gjson.String {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
|
|
||||||
candidatesTokenCountResult := usageResult.Get("candidatesTokenCount")
|
|
||||||
if candidatesTokenCountResult.Exists() && candidatesTokenCountResult.Type == gjson.Number {
|
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
|
||||||
}
|
|
||||||
totalTokenCountResult := usageResult.Get("totalTokenCount")
|
|
||||||
if totalTokenCountResult.Exists() && totalTokenCountResult.Type == gjson.Number {
|
|
||||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
|
||||||
}
|
|
||||||
thoughtsTokenCountResult := usageResult.Get("thoughtsTokenCount")
|
|
||||||
promptTokenCountResult := usageResult.Get("promptTokenCount")
|
|
||||||
if promptTokenCountResult.Exists() && promptTokenCountResult.Type == gjson.Number {
|
|
||||||
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
|
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int()+thoughtsTokenCountResult.Int())
|
|
||||||
} else {
|
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
|
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCountResult.Int())
|
|
||||||
}
|
|
||||||
|
|
||||||
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
|
|
||||||
partTextResult := partResult.Get("text")
|
|
||||||
functionCallResult := partResult.Get("functionCall")
|
|
||||||
|
|
||||||
if partTextResult.Exists() && partTextResult.Type == gjson.String {
|
|
||||||
partThoughtResult := partResult.Get("thought")
|
|
||||||
if partThoughtResult.Exists() && partThoughtResult.Type == gjson.True {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
|
|
||||||
} else {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
|
||||||
} else if functionCallResult.Exists() {
|
|
||||||
functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]`
|
|
||||||
fcNameResult := functionCallResult.Get("name")
|
|
||||||
if fcNameResult.Exists() && fcNameResult.Type == gjson.String {
|
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcNameResult.String())
|
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcNameResult.String())
|
|
||||||
}
|
|
||||||
fcArgsResult := functionCallResult.Get("args")
|
|
||||||
if fcArgsResult.Exists() && fcArgsResult.IsObject() {
|
|
||||||
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw)
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate)
|
|
||||||
} else {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return template
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *APIHandlers) convertCliToOpenAINonStream(template string, rawJson []byte) string {
|
|
||||||
modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion")
|
|
||||||
if modelVersionResult.Exists() && modelVersionResult.Type == gjson.String {
|
|
||||||
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
createTimeResult := gjson.GetBytes(rawJson, "response.createTime")
|
|
||||||
if createTimeResult.Exists() && createTimeResult.Type == gjson.String {
|
|
||||||
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
|
||||||
var unixTimestamp int64
|
|
||||||
if err == nil {
|
|
||||||
unixTimestamp = t.Unix()
|
|
||||||
} else {
|
|
||||||
unixTimestamp = time.Now().Unix()
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "created", unixTimestamp)
|
|
||||||
}
|
|
||||||
|
|
||||||
responseIdResult := gjson.GetBytes(rawJson, "response.responseId")
|
|
||||||
if responseIdResult.Exists() && responseIdResult.Type == gjson.String {
|
|
||||||
template, _ = sjson.Set(template, "id", responseIdResult.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason")
|
|
||||||
if finishReasonResult.Exists() && finishReasonResult.Type == gjson.String {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
usageResult := gjson.GetBytes(rawJson, "response.usageMetadata")
|
|
||||||
candidatesTokenCountResult := usageResult.Get("candidatesTokenCount")
|
|
||||||
if candidatesTokenCountResult.Exists() && candidatesTokenCountResult.Type == gjson.Number {
|
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
|
||||||
}
|
|
||||||
totalTokenCountResult := usageResult.Get("totalTokenCount")
|
|
||||||
if totalTokenCountResult.Exists() && totalTokenCountResult.Type == gjson.Number {
|
|
||||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
|
||||||
}
|
|
||||||
thoughtsTokenCountResult := usageResult.Get("thoughtsTokenCount")
|
|
||||||
promptTokenCountResult := usageResult.Get("promptTokenCount")
|
|
||||||
if promptTokenCountResult.Exists() && promptTokenCountResult.Type == gjson.Number {
|
|
||||||
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
|
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int()+thoughtsTokenCountResult.Int())
|
|
||||||
} else {
|
|
||||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCountResult.Int())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if thoughtsTokenCountResult.Exists() && thoughtsTokenCountResult.Type == gjson.Number {
|
|
||||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCountResult.Int())
|
|
||||||
}
|
|
||||||
|
|
||||||
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
|
|
||||||
partTextResult := partResult.Get("text")
|
|
||||||
functionCallResult := partResult.Get("functionCall")
|
|
||||||
|
|
||||||
if partTextResult.Exists() && partTextResult.Type == gjson.String {
|
|
||||||
partThoughtResult := partResult.Get("thought")
|
|
||||||
if partThoughtResult.Exists() && partThoughtResult.Type == gjson.True {
|
|
||||||
reasoningContentResult := gjson.Get(template, "choices.0.message.reasoning_content")
|
|
||||||
if reasoningContentResult.Type == gjson.String {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", reasoningContentResult.String()+partTextResult.String())
|
|
||||||
} else {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", partTextResult.String())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
reasoningContentResult := gjson.Get(template, "choices.0.message.content")
|
|
||||||
if reasoningContentResult.Type == gjson.String {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.message.content", reasoningContentResult.String()+partTextResult.String())
|
|
||||||
} else {
|
|
||||||
template, _ = sjson.Set(template, "choices.0.message.content", partTextResult.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
|
||||||
} else if functionCallResult.Exists() {
|
|
||||||
toolCallsResult := gjson.Get(template, "choices.0.message.tool_calls")
|
|
||||||
if !toolCallsResult.Exists() || toolCallsResult.Type == gjson.Null {
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
|
|
||||||
}
|
|
||||||
|
|
||||||
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
|
||||||
fcNameResult := functionCallResult.Get("name")
|
|
||||||
if fcNameResult.Exists() && fcNameResult.Type == gjson.String {
|
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcNameResult.String())
|
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcNameResult.String())
|
|
||||||
}
|
|
||||||
fcArgsResult := functionCallResult.Get("args")
|
|
||||||
if fcArgsResult.Exists() && fcArgsResult.IsObject() {
|
|
||||||
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
|
||||||
}
|
|
||||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
|
||||||
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
|
|
||||||
} else {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return template
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,13 +1,18 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
// ErrorResponse represents an error response
|
// ErrorResponse represents a standard error response format for the API.
|
||||||
|
// It contains a single ErrorDetail field.
|
||||||
type ErrorResponse struct {
|
type ErrorResponse struct {
|
||||||
Error ErrorDetail `json:"error"`
|
Error ErrorDetail `json:"error"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ErrorDetail represents error details
|
// ErrorDetail provides specific information about an error that occurred.
|
||||||
|
// It includes a human-readable message, an error type, and an optional error code.
|
||||||
type ErrorDetail struct {
|
type ErrorDetail struct {
|
||||||
|
// A human-readable message providing more details about the error.
|
||||||
Message string `json:"message"`
|
Message string `json:"message"`
|
||||||
Type string `json:"type"`
|
// The type of error that occurred (e.g., "invalid_request_error").
|
||||||
Code string `json:"code,omitempty"`
|
Type string `json:"type"`
|
||||||
|
// A short code identifying the error, if applicable.
|
||||||
|
Code string `json:"code,omitempty"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server represents the API server
|
// Server represents the main API server.
|
||||||
|
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
engine *gin.Engine
|
engine *gin.Engine
|
||||||
server *http.Server
|
server *http.Server
|
||||||
@@ -19,14 +20,18 @@ type Server struct {
|
|||||||
cfg *ServerConfig
|
cfg *ServerConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServerConfig contains configuration for the API server
|
// ServerConfig contains the configuration for the API server.
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
Port string
|
// Port is the port number the server will listen on.
|
||||||
Debug bool
|
Port string
|
||||||
|
// Debug enables or disables debug mode for the server and Gin.
|
||||||
|
Debug bool
|
||||||
|
// ApiKeys is a list of valid API keys for authentication.
|
||||||
ApiKeys []string
|
ApiKeys []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewServer creates a new API server instance
|
// NewServer creates and initializes a new API server instance.
|
||||||
|
// It sets up the Gin engine, middleware, routes, and handlers.
|
||||||
func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
|
func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
|
||||||
// Set gin mode
|
// Set gin mode
|
||||||
if !config.Debug {
|
if !config.Debug {
|
||||||
@@ -63,7 +68,8 @@ func NewServer(config *ServerConfig, cliClients []*client.Client) *Server {
|
|||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupRoutes configures the API routes
|
// setupRoutes configures the API routes for the server.
|
||||||
|
// It defines the endpoints and associates them with their respective handlers.
|
||||||
func (s *Server) setupRoutes() {
|
func (s *Server) setupRoutes() {
|
||||||
// OpenAI compatible API routes
|
// OpenAI compatible API routes
|
||||||
v1 := s.engine.Group("/v1")
|
v1 := s.engine.Group("/v1")
|
||||||
@@ -86,11 +92,12 @@ func (s *Server) setupRoutes() {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start starts the API server
|
// Start begins listening for and serving HTTP requests.
|
||||||
|
// It's a blocking call and will only return on an unrecoverable error.
|
||||||
func (s *Server) Start() error {
|
func (s *Server) Start() error {
|
||||||
log.Debugf("Starting API server on %s", s.server.Addr)
|
log.Debugf("Starting API server on %s", s.server.Addr)
|
||||||
|
|
||||||
// Start the HTTP server
|
// Start the HTTP server.
|
||||||
if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
return fmt.Errorf("failed to start HTTP server: %v", err)
|
return fmt.Errorf("failed to start HTTP server: %v", err)
|
||||||
}
|
}
|
||||||
@@ -98,11 +105,12 @@ func (s *Server) Start() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop gracefully stops the API server
|
// Stop gracefully shuts down the API server without interrupting any
|
||||||
|
// active connections.
|
||||||
func (s *Server) Stop(ctx context.Context) error {
|
func (s *Server) Stop(ctx context.Context) error {
|
||||||
log.Debug("Stopping API server...")
|
log.Debug("Stopping API server...")
|
||||||
|
|
||||||
// Shutdown the HTTP server
|
// Shutdown the HTTP server.
|
||||||
if err := s.server.Shutdown(ctx); err != nil {
|
if err := s.server.Shutdown(ctx); err != nil {
|
||||||
return fmt.Errorf("failed to shutdown HTTP server: %v", err)
|
return fmt.Errorf("failed to shutdown HTTP server: %v", err)
|
||||||
}
|
}
|
||||||
@@ -111,7 +119,8 @@ func (s *Server) Stop(ctx context.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// corsMiddleware adds CORS headers
|
// corsMiddleware returns a Gin middleware handler that adds CORS headers
|
||||||
|
// to every response, allowing cross-origin requests.
|
||||||
func corsMiddleware() gin.HandlerFunc {
|
func corsMiddleware() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
c.Header("Access-Control-Allow-Origin", "*")
|
c.Header("Access-Control-Allow-Origin", "*")
|
||||||
@@ -127,7 +136,8 @@ func corsMiddleware() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthMiddleware authenticates requests using API keys
|
// AuthMiddleware returns a Gin middleware handler that authenticates requests
|
||||||
|
// using API keys. If no API keys are configured, it allows all requests.
|
||||||
func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc {
|
func AuthMiddleware(cfg *ServerConfig) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
if len(cfg.ApiKeys) == 0 {
|
if len(cfg.ApiKeys) == 0 {
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
package api
|
package translator
|
||||||
|
|
||||||
|
// MimeTypes is a comprehensive map of file extensions to their corresponding MIME types.
|
||||||
|
// This is used to identify the type of file being uploaded or processed.
|
||||||
var MimeTypes = map[string]string{
|
var MimeTypes = map[string]string{
|
||||||
"ez": "application/andrew-inset",
|
"ez": "application/andrew-inset",
|
||||||
"aw": "application/applixware",
|
"aw": "application/applixware",
|
||||||
163
internal/api/translator/request.go
Normal file
163
internal/api/translator/request.go
Normal file
@@ -0,0 +1,163 @@
|
|||||||
|
package translator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/client"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// PrepareRequest translates a raw JSON request from an OpenAI-compatible format
|
||||||
|
// to the internal format expected by the backend client. It parses messages,
|
||||||
|
// roles, content types (text, image, file), and tool calls.
|
||||||
|
func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) {
|
||||||
|
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
|
||||||
|
modelName := "gemini-2.5-pro"
|
||||||
|
modelResult := gjson.GetBytes(rawJson, "model")
|
||||||
|
if modelResult.Type == gjson.String {
|
||||||
|
modelName = modelResult.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the array of messages.
|
||||||
|
contents := make([]client.Content, 0)
|
||||||
|
messagesResult := gjson.GetBytes(rawJson, "messages")
|
||||||
|
if messagesResult.IsArray() {
|
||||||
|
messagesResults := messagesResult.Array()
|
||||||
|
for i := 0; i < len(messagesResults); i++ {
|
||||||
|
messageResult := messagesResults[i]
|
||||||
|
roleResult := messageResult.Get("role")
|
||||||
|
contentResult := messageResult.Get("content")
|
||||||
|
if roleResult.Type != gjson.String {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch roleResult.String() {
|
||||||
|
// System messages are converted to a user message followed by a model's acknowledgment.
|
||||||
|
case "system":
|
||||||
|
if contentResult.Type == gjson.String {
|
||||||
|
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||||
|
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
|
||||||
|
} else if contentResult.IsObject() {
|
||||||
|
// Handle object-based system messages.
|
||||||
|
if contentResult.Get("type").String() == "text" {
|
||||||
|
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}})
|
||||||
|
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// User messages can contain simple text or a multi-part body.
|
||||||
|
case "user":
|
||||||
|
if contentResult.Type == gjson.String {
|
||||||
|
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||||
|
} else if contentResult.IsArray() {
|
||||||
|
// Handle multi-part user messages (text, images, files).
|
||||||
|
contentItemResults := contentResult.Array()
|
||||||
|
parts := make([]client.Part, 0)
|
||||||
|
for j := 0; j < len(contentItemResults); j++ {
|
||||||
|
contentItemResult := contentItemResults[j]
|
||||||
|
contentTypeResult := contentItemResult.Get("type")
|
||||||
|
switch contentTypeResult.String() {
|
||||||
|
case "text":
|
||||||
|
parts = append(parts, client.Part{Text: contentItemResult.Get("text").String()})
|
||||||
|
case "image_url":
|
||||||
|
// Parse data URI for images.
|
||||||
|
imageURL := contentItemResult.Get("image_url.url").String()
|
||||||
|
if len(imageURL) > 5 {
|
||||||
|
imageURLs := strings.SplitN(imageURL[5:], ";", 2)
|
||||||
|
if len(imageURLs) == 2 && len(imageURLs[1]) > 7 {
|
||||||
|
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
||||||
|
MimeType: imageURLs[0],
|
||||||
|
Data: imageURLs[1][7:],
|
||||||
|
}})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case "file":
|
||||||
|
// Handle file attachments by determining MIME type from extension.
|
||||||
|
filename := contentItemResult.Get("file.filename").String()
|
||||||
|
fileData := contentItemResult.Get("file.file_data").String()
|
||||||
|
ext := ""
|
||||||
|
if split := strings.Split(filename, "."); len(split) > 1 {
|
||||||
|
ext = split[len(split)-1]
|
||||||
|
}
|
||||||
|
if mimeType, ok := MimeTypes[ext]; ok {
|
||||||
|
parts = append(parts, client.Part{InlineData: &client.InlineData{
|
||||||
|
MimeType: mimeType,
|
||||||
|
Data: fileData,
|
||||||
|
}})
|
||||||
|
} else {
|
||||||
|
log.Warnf("Unknown file name extension '%s' at index %d, skipping file", ext, j)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
contents = append(contents, client.Content{Role: "user", Parts: parts})
|
||||||
|
}
|
||||||
|
// Assistant messages can contain text or tool calls.
|
||||||
|
case "assistant":
|
||||||
|
if contentResult.Type == gjson.String {
|
||||||
|
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
|
||||||
|
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
|
||||||
|
// Handle tool calls made by the assistant.
|
||||||
|
toolCallsResult := messageResult.Get("tool_calls")
|
||||||
|
if toolCallsResult.IsArray() {
|
||||||
|
tcsResult := toolCallsResult.Array()
|
||||||
|
for j := 0; j < len(tcsResult); j++ {
|
||||||
|
tcResult := tcsResult[j]
|
||||||
|
functionName := tcResult.Get("function.name").String()
|
||||||
|
functionArgs := tcResult.Get("function.arguments").String()
|
||||||
|
var args map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||||
|
contents = append(contents, client.Content{
|
||||||
|
Role: "model", Parts: []client.Part{{
|
||||||
|
FunctionCall: &client.FunctionCall{
|
||||||
|
Name: functionName,
|
||||||
|
Args: args,
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Tool messages contain the output of a tool call.
|
||||||
|
case "tool":
|
||||||
|
toolCallID := messageResult.Get("tool_call_id").String()
|
||||||
|
if toolCallID != "" {
|
||||||
|
var responseData string
|
||||||
|
if contentResult.Type == gjson.String {
|
||||||
|
responseData = contentResult.String()
|
||||||
|
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
|
||||||
|
responseData = contentResult.Get("text").String()
|
||||||
|
}
|
||||||
|
functionResponse := client.FunctionResponse{Name: toolCallID, Response: map[string]interface{}{"result": responseData}}
|
||||||
|
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Translate the tool declarations from the request.
|
||||||
|
var tools []client.ToolDeclaration
|
||||||
|
toolsResult := gjson.GetBytes(rawJson, "tools")
|
||||||
|
if toolsResult.IsArray() {
|
||||||
|
tools = make([]client.ToolDeclaration, 1)
|
||||||
|
tools[0].FunctionDeclarations = make([]any, 0)
|
||||||
|
toolsResults := toolsResult.Array()
|
||||||
|
for i := 0; i < len(toolsResults); i++ {
|
||||||
|
toolResult := toolsResults[i]
|
||||||
|
if toolResult.Get("type").String() == "function" {
|
||||||
|
functionTypeResult := toolResult.Get("function")
|
||||||
|
if functionTypeResult.Exists() && functionTypeResult.IsObject() {
|
||||||
|
var functionDeclaration any
|
||||||
|
if err := json.Unmarshal([]byte(functionTypeResult.Raw), &functionDeclaration); err == nil {
|
||||||
|
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, functionDeclaration)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
tools = make([]client.ToolDeclaration, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return modelName, contents, tools
|
||||||
|
}
|
||||||
169
internal/api/translator/response.go
Normal file
169
internal/api/translator/response.go
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
package translator
|
||||||
|
|
||||||
|
import (
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConvertCliToOpenAI translates a single chunk of a streaming response from the
|
||||||
|
// backend client format to the OpenAI Server-Sent Events (SSE) format.
|
||||||
|
// It returns an empty string if the chunk contains no useful data.
|
||||||
|
func ConvertCliToOpenAI(rawJson []byte) string {
|
||||||
|
// Initialize the OpenAI SSE template.
|
||||||
|
template := `{"id":"","object":"chat.completion.chunk","created":12345,"model":"model","choices":[{"index":0,"delta":{"role":null,"content":null,"reasoning_content":null,"tool_calls":null},"finish_reason":null,"native_finish_reason":null}]}`
|
||||||
|
|
||||||
|
// Extract and set the model version.
|
||||||
|
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and set the creation timestamp.
|
||||||
|
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
|
||||||
|
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
||||||
|
unixTimestamp := time.Now().Unix()
|
||||||
|
if err == nil {
|
||||||
|
unixTimestamp = t.Unix()
|
||||||
|
}
|
||||||
|
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and set the response ID.
|
||||||
|
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "id", responseIdResult.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and set the finish reason.
|
||||||
|
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
||||||
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and set usage metadata (token counts).
|
||||||
|
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
|
||||||
|
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||||
|
}
|
||||||
|
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||||
|
}
|
||||||
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||||
|
if thoughtsTokenCount > 0 {
|
||||||
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the main content part of the response.
|
||||||
|
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
|
||||||
|
partTextResult := partResult.Get("text")
|
||||||
|
functionCallResult := partResult.Get("functionCall")
|
||||||
|
|
||||||
|
if partTextResult.Exists() {
|
||||||
|
// Handle text content, distinguishing between regular content and reasoning/thoughts.
|
||||||
|
if partResult.Get("thought").Bool() {
|
||||||
|
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
|
||||||
|
} else {
|
||||||
|
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
|
||||||
|
}
|
||||||
|
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||||
|
} else if functionCallResult.Exists() {
|
||||||
|
// Handle function call content.
|
||||||
|
functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]`
|
||||||
|
fcName := functionCallResult.Get("name").String()
|
||||||
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcName)
|
||||||
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcName)
|
||||||
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
|
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw)
|
||||||
|
}
|
||||||
|
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate)
|
||||||
|
} else {
|
||||||
|
// If no usable content is found, return an empty string.
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return template
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConvertCliToOpenAINonStream aggregates response chunks from the backend client
|
||||||
|
// into a single, non-streaming OpenAI-compatible JSON response.
|
||||||
|
func ConvertCliToOpenAINonStream(template string, rawJson []byte) string {
|
||||||
|
// Extract and set metadata fields that are typically set once per response.
|
||||||
|
if gjson.Get(template, "id").String() == "" {
|
||||||
|
if modelVersionResult := gjson.GetBytes(rawJson, "response.modelVersion"); modelVersionResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "model", modelVersionResult.String())
|
||||||
|
}
|
||||||
|
if createTimeResult := gjson.GetBytes(rawJson, "response.createTime"); createTimeResult.Exists() {
|
||||||
|
t, err := time.Parse(time.RFC3339Nano, createTimeResult.String())
|
||||||
|
unixTimestamp := time.Now().Unix()
|
||||||
|
if err == nil {
|
||||||
|
unixTimestamp = t.Unix()
|
||||||
|
}
|
||||||
|
template, _ = sjson.Set(template, "created", unixTimestamp)
|
||||||
|
}
|
||||||
|
if responseIdResult := gjson.GetBytes(rawJson, "response.responseId"); responseIdResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "id", responseIdResult.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and set the finish reason.
|
||||||
|
if finishReasonResult := gjson.GetBytes(rawJson, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
|
||||||
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract and set usage metadata (token counts).
|
||||||
|
if usageResult := gjson.GetBytes(rawJson, "response.usageMetadata"); usageResult.Exists() {
|
||||||
|
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||||
|
}
|
||||||
|
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||||
|
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||||
|
}
|
||||||
|
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||||
|
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||||
|
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||||
|
if thoughtsTokenCount > 0 {
|
||||||
|
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process the main content part of the response.
|
||||||
|
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
|
||||||
|
partTextResult := partResult.Get("text")
|
||||||
|
functionCallResult := partResult.Get("functionCall")
|
||||||
|
|
||||||
|
if partTextResult.Exists() {
|
||||||
|
// Append text content, distinguishing between regular content and reasoning.
|
||||||
|
if partResult.Get("thought").Bool() {
|
||||||
|
currentContent := gjson.Get(template, "choices.0.message.reasoning_content").String()
|
||||||
|
template, _ = sjson.Set(template, "choices.0.message.reasoning_content", currentContent+partTextResult.String())
|
||||||
|
} else {
|
||||||
|
currentContent := gjson.Get(template, "choices.0.message.content").String()
|
||||||
|
template, _ = sjson.Set(template, "choices.0.message.content", currentContent+partTextResult.String())
|
||||||
|
}
|
||||||
|
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||||
|
} else if functionCallResult.Exists() {
|
||||||
|
// Append function call content to the tool_calls array.
|
||||||
|
if !gjson.Get(template, "choices.0.message.tool_calls").Exists() {
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls", `[]`)
|
||||||
|
}
|
||||||
|
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
|
||||||
|
fcName := functionCallResult.Get("name").String()
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcName)
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
|
||||||
|
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||||
|
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)
|
||||||
|
}
|
||||||
|
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||||
|
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallItemTemplate)
|
||||||
|
} else {
|
||||||
|
// If no usable content is found, return an empty string.
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return template
|
||||||
|
}
|
||||||
@@ -5,17 +5,18 @@ import (
|
|||||||
"encoding/json"
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/skratchdot/open-golang/open"
|
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
"golang.org/x/net/proxy"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/skratchdot/open-golang/open"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"golang.org/x/net/proxy"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"golang.org/x/oauth2/google"
|
"golang.org/x/oauth2/google"
|
||||||
)
|
)
|
||||||
@@ -33,76 +34,78 @@ var (
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
// GetAuthenticatedClient configures and returns an HTTP client with OAuth2 tokens.
|
// GetAuthenticatedClient configures and returns an HTTP client ready for making authenticated API calls.
|
||||||
// It handles the entire flow: loading, refreshing, and fetching new tokens.
|
// It manages the entire OAuth2 flow, including handling proxies, loading existing tokens,
|
||||||
|
// initiating a new web-based OAuth flow if necessary, and refreshing tokens.
|
||||||
func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) {
|
func GetAuthenticatedClient(ctx context.Context, ts *TokenStorage, cfg *config.Config) (*http.Client, error) {
|
||||||
|
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||||
proxyURL, err := url.Parse(cfg.ProxyUrl)
|
proxyURL, err := url.Parse(cfg.ProxyUrl)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
var transport *http.Transport
|
||||||
if proxyURL.Scheme == "socks5" {
|
if proxyURL.Scheme == "socks5" {
|
||||||
|
// Handle SOCKS5 proxy.
|
||||||
username := proxyURL.User.Username()
|
username := proxyURL.User.Username()
|
||||||
password, _ := proxyURL.User.Password()
|
password, _ := proxyURL.User.Password()
|
||||||
auth := &proxy.Auth{
|
auth := &proxy.Auth{User: username, Password: password}
|
||||||
User: username,
|
|
||||||
Password: password,
|
|
||||||
}
|
|
||||||
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
|
||||||
if errSOCKS5 != nil {
|
if errSOCKS5 != nil {
|
||||||
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
|
||||||
}
|
}
|
||||||
|
transport = &http.Transport{
|
||||||
transport := &http.Transport{
|
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
|
||||||
DialContext: func(ctx context.Context, network, addr string) (c net.Conn, err error) {
|
|
||||||
return dialer.Dial(network, addr)
|
return dialer.Dial(network, addr)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
proxyClient := &http.Client{
|
|
||||||
Transport: transport,
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
|
||||||
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
|
||||||
transport := &http.Transport{
|
// Handle HTTP/HTTPS proxy.
|
||||||
Proxy: http.ProxyURL(proxyURL),
|
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
|
||||||
}
|
}
|
||||||
proxyClient := &http.Client{
|
|
||||||
Transport: transport,
|
if transport != nil {
|
||||||
}
|
proxyClient := &http.Client{Transport: transport}
|
||||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyClient)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Configure the OAuth2 client.
|
||||||
conf := &oauth2.Config{
|
conf := &oauth2.Config{
|
||||||
ClientID: oauthClientID,
|
ClientID: oauthClientID,
|
||||||
ClientSecret: oauthClientSecret,
|
ClientSecret: oauthClientSecret,
|
||||||
RedirectURL: "http://localhost:8085/oauth2callback", // Placeholder, will be updated
|
RedirectURL: "http://localhost:8085/oauth2callback", // This will be used by the local server.
|
||||||
Scopes: oauthScopes,
|
Scopes: oauthScopes,
|
||||||
Endpoint: google.Endpoint,
|
Endpoint: google.Endpoint,
|
||||||
}
|
}
|
||||||
|
|
||||||
var token *oauth2.Token
|
var token *oauth2.Token
|
||||||
|
|
||||||
|
// If no token is found in storage, initiate the web-based OAuth flow.
|
||||||
if ts.Token == nil {
|
if ts.Token == nil {
|
||||||
log.Info("Could not load token from file, starting OAuth flow.")
|
log.Info("Could not load token from file, starting OAuth flow.")
|
||||||
token, err = getTokenFromWeb(ctx, conf)
|
token, err = getTokenFromWeb(ctx, conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to get token from web: %w", err)
|
return nil, fmt.Errorf("failed to get token from web: %w", err)
|
||||||
}
|
}
|
||||||
newTs, errSaveTokenToFile := createTokenStorage(ctx, conf, token, ts.ProjectID)
|
// After getting a new token, create a new token storage object with user info.
|
||||||
if errSaveTokenToFile != nil {
|
newTs, errCreateTokenStorage := createTokenStorage(ctx, conf, token, ts.ProjectID)
|
||||||
log.Errorf("Warning: failed to save token to file: %v", err)
|
if errCreateTokenStorage != nil {
|
||||||
return nil, errSaveTokenToFile
|
log.Errorf("Warning: failed to create token storage: %v", errCreateTokenStorage)
|
||||||
|
return nil, errCreateTokenStorage
|
||||||
}
|
}
|
||||||
*ts = *newTs
|
*ts = *newTs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Unmarshal the stored token into an oauth2.Token object.
|
||||||
tsToken, _ := json.Marshal(ts.Token)
|
tsToken, _ := json.Marshal(ts.Token)
|
||||||
if err = json.Unmarshal(tsToken, &token); err != nil {
|
if err = json.Unmarshal(tsToken, &token); err != nil {
|
||||||
return nil, err
|
return nil, fmt.Errorf("failed to unmarshal token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return an HTTP client that automatically handles token refreshing.
|
||||||
return conf.Client(ctx, token), nil
|
return conf.Client(ctx, token), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createTokenStorage creates a token storage.
|
// createTokenStorage creates a new TokenStorage object. It fetches the user's email
|
||||||
|
// using the provided token and populates the storage structure.
|
||||||
func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*TokenStorage, error) {
|
func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth2.Token, projectID string) (*TokenStorage, error) {
|
||||||
httpClient := config.Client(ctx, token)
|
httpClient := config.Client(ctx, token)
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||||
@@ -117,7 +120,9 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth
|
|||||||
return nil, fmt.Errorf("failed to execute request: %w", err)
|
return nil, fmt.Errorf("failed to execute request: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = resp.Body.Close()
|
if err = resp.Body.Close(); err != nil {
|
||||||
|
log.Printf("warn: failed to close response body: %v", err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
@@ -154,7 +159,10 @@ func createTokenStorage(ctx context.Context, config *oauth2.Config, token *oauth
|
|||||||
return &ts, nil
|
return &ts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getTokenFromWeb starts a local server to handle the OAuth2 flow.
|
// getTokenFromWeb initiates the web-based OAuth2 authorization flow.
|
||||||
|
// It starts a local HTTP server to listen for the callback from Google's auth server,
|
||||||
|
// opens the user's browser to the authorization URL, and exchanges the received
|
||||||
|
// authorization code for an access token.
|
||||||
func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) {
|
func getTokenFromWeb(ctx context.Context, config *oauth2.Config) (*oauth2.Token, error) {
|
||||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||||
codeChan := make(chan string)
|
codeChan := make(chan string)
|
||||||
|
|||||||
@@ -1,9 +1,17 @@
|
|||||||
package auth
|
package auth
|
||||||
|
|
||||||
|
// TokenStorage defines the structure for storing OAuth2 token information,
|
||||||
|
// along with associated user and project details. This data is typically
|
||||||
|
// serialized to a JSON file for persistence.
|
||||||
type TokenStorage struct {
|
type TokenStorage struct {
|
||||||
Token any `json:"token"`
|
// Token holds the raw OAuth2 token data, including access and refresh tokens.
|
||||||
|
Token any `json:"token"`
|
||||||
|
// ProjectID is the Google Cloud Project ID associated with this token.
|
||||||
ProjectID string `json:"project_id"`
|
ProjectID string `json:"project_id"`
|
||||||
Email string `json:"email"`
|
// Email is the email address of the authenticated user.
|
||||||
Auto bool `json:"auto"`
|
Email string `json:"email"`
|
||||||
Checked bool `json:"checked"`
|
// Auto indicates if the project ID was automatically selected.
|
||||||
|
Auto bool `json:"auto"`
|
||||||
|
// Checked indicates if the associated Cloud AI API has been verified as enabled.
|
||||||
|
Checked bool `json:"checked"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,12 +6,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"github.com/luispater/CLIProxyAPI/internal/auth"
|
|
||||||
"github.com/luispater/CLIProxyAPI/internal/config"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/tidwall/gjson"
|
|
||||||
"github.com/tidwall/sjson"
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"os"
|
||||||
@@ -20,6 +14,13 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/auth"
|
||||||
|
"github.com/luispater/CLIProxyAPI/internal/config"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -194,7 +195,9 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
|
|||||||
return fmt.Errorf("failed to execute request: %w", err)
|
return fmt.Errorf("failed to execute request: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = resp.Body.Close()
|
if err = resp.Body.Close(); err != nil {
|
||||||
|
log.Printf("warn: failed to close response body: %v", err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
@@ -253,7 +256,9 @@ func (c *Client) StreamAPIRequest(ctx context.Context, endpoint string, body int
|
|||||||
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = resp.Body.Close()
|
if err = resp.Body.Close(); err != nil {
|
||||||
|
log.Printf("warn: failed to close response body: %v", err)
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
|
||||||
@@ -355,6 +360,9 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
|
|||||||
return dataChan, errChan
|
return dataChan, errChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CheckCloudAPIIsEnabled sends a simple test request to the API to verify
|
||||||
|
// that the Cloud AI API is enabled for the user's project. It provides
|
||||||
|
// an activation URL if the API is disabled.
|
||||||
func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
|
func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -363,79 +371,78 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
|
|||||||
}()
|
}()
|
||||||
c.RequestMutex.Lock()
|
c.RequestMutex.Lock()
|
||||||
|
|
||||||
requestBody := `{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`
|
// A simple request to test the API endpoint.
|
||||||
requestBody = fmt.Sprintf(requestBody, c.tokenStorage.ProjectID)
|
requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID)
|
||||||
// log.Debug(requestBody)
|
|
||||||
stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", []byte(requestBody))
|
stream, err := c.StreamAPIRequest(ctx, "streamGenerateContent", []byte(requestBody))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// If a 403 Forbidden error occurs, it likely means the API is not enabled.
|
||||||
if err.StatusCode == 403 {
|
if err.StatusCode == 403 {
|
||||||
errJson := err.Error.Error()
|
errJson := err.Error.Error()
|
||||||
codeResult := gjson.Get(errJson, "error.code")
|
// Check for a specific error code and extract the activation URL.
|
||||||
if codeResult.Exists() && codeResult.Type == gjson.Number {
|
if gjson.Get(errJson, "error.code").Int() == 403 {
|
||||||
if codeResult.Int() == 403 {
|
activationUrl := gjson.Get(errJson, "error.details.0.metadata.activationUrl").String()
|
||||||
activationUrlResult := gjson.Get(errJson, "error.details.0.metadata.activationUrl")
|
if activationUrl != "" {
|
||||||
if activationUrlResult.Exists() {
|
log.Warnf(
|
||||||
log.Warnf(
|
"\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s",
|
||||||
"\n\nPlease activate your account with this url:\n\n%s\n And execute this command again:\n%s --login --project_id %s",
|
activationUrl,
|
||||||
activationUrlResult.String(),
|
os.Args[0],
|
||||||
os.Args[0],
|
c.tokenStorage.ProjectID,
|
||||||
c.tokenStorage.ProjectID,
|
)
|
||||||
)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
return false, err.Error
|
return false, err.Error
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
_ = stream.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
// We only need to know if the request was successful, so we can drain the stream.
|
||||||
scanner := bufio.NewScanner(stream)
|
scanner := bufio.NewScanner(stream)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
line := scanner.Text()
|
// Do nothing, just consume the stream.
|
||||||
if !strings.HasPrefix(line, "data: ") {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if scannerErr := scanner.Err(); scannerErr != nil {
|
return scanner.Err() == nil, scanner.Err()
|
||||||
_ = stream.Close()
|
|
||||||
} else {
|
|
||||||
_ = stream.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
return true, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetProjectList fetches a list of Google Cloud projects accessible by the user.
|
||||||
func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) {
|
func (c *Client) GetProjectList(ctx context.Context) (*GCPProject, error) {
|
||||||
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
|
token, err := c.httpClient.Transport.(*oauth2.Transport).Source.Token()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to get token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
|
req, err := http.NewRequestWithContext(ctx, "GET", "https://cloudresourcemanager.googleapis.com/v1/projects", nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("could not get project list: %v", err)
|
return nil, fmt.Errorf("could not create project list request: %v", err)
|
||||||
}
|
}
|
||||||
req.Header.Set("Content-Type", "application/json")
|
|
||||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token.AccessToken))
|
||||||
|
|
||||||
resp, err := c.httpClient.Do(req)
|
resp, err := c.httpClient.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to execute request: %w", err)
|
return nil, fmt.Errorf("failed to execute project list request: %w", err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
|
||||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||||
return nil, fmt.Errorf("get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||||
|
return nil, fmt.Errorf("project list request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
var project GCPProject
|
var project GCPProject
|
||||||
err = json.Unmarshal(bodyBytes, &project)
|
if err = json.NewDecoder(resp.Body).Decode(&project); err != nil {
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to unmarshal project list: %w", err)
|
return nil, fmt.Errorf("failed to unmarshal project list: %w", err)
|
||||||
}
|
}
|
||||||
return &project, nil
|
return &project, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SaveTokenToFile serializes the client's current token storage to a JSON file.
|
||||||
|
// The filename is constructed from the user's email and project ID.
|
||||||
func (c *Client) SaveTokenToFile() error {
|
func (c *Client) SaveTokenToFile() error {
|
||||||
if err := os.MkdirAll(c.cfg.AuthDir, 0700); err != nil {
|
if err := os.MkdirAll(c.cfg.AuthDir, 0700); err != nil {
|
||||||
return fmt.Errorf("failed to create directory: %v", err)
|
return fmt.Errorf("failed to create directory: %v", err)
|
||||||
@@ -457,7 +464,8 @@ func (c *Client) SaveTokenToFile() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// getClientMetadata returns metadata about the client environment.
|
// getClientMetadata returns a map of metadata about the client environment,
|
||||||
|
// such as IDE type, platform, and plugin version.
|
||||||
func getClientMetadata() map[string]string {
|
func getClientMetadata() map[string]string {
|
||||||
return map[string]string{
|
return map[string]string{
|
||||||
"ideType": "IDE_UNSPECIFIED",
|
"ideType": "IDE_UNSPECIFIED",
|
||||||
@@ -467,7 +475,8 @@ func getClientMetadata() map[string]string {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getClientMetadataString returns the metadata as a comma-separated string.
|
// getClientMetadataString returns the client metadata as a single,
|
||||||
|
// comma-separated string, which is required for the 'Client-Metadata' header.
|
||||||
func getClientMetadataString() string {
|
func getClientMetadataString() string {
|
||||||
md := getClientMetadata()
|
md := getClientMetadata()
|
||||||
parts := make([]string, 0, len(md))
|
parts := make([]string, 0, len(md))
|
||||||
@@ -477,11 +486,13 @@ func getClientMetadataString() string {
|
|||||||
return strings.Join(parts, ",")
|
return strings.Join(parts, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getUserAgent constructs the User-Agent string for HTTP requests.
|
||||||
func getUserAgent() string {
|
func getUserAgent() string {
|
||||||
return fmt.Sprintf(fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH))
|
return fmt.Sprintf("GeminiCLI/%s (%s; %s)", pluginVersion, runtime.GOOS, runtime.GOARCH)
|
||||||
}
|
}
|
||||||
|
|
||||||
// getPlatform returns the OS and architecture in the format expected by the API.
|
// getPlatform determines the operating system and architecture and formats
|
||||||
|
// it into a string expected by the backend API.
|
||||||
func getPlatform() string {
|
func getPlatform() string {
|
||||||
goOS := runtime.GOOS
|
goOS := runtime.GOOS
|
||||||
arch := runtime.GOARCH
|
arch := runtime.GOARCH
|
||||||
|
|||||||
@@ -2,17 +2,23 @@ package client
|
|||||||
|
|
||||||
import "time"
|
import "time"
|
||||||
|
|
||||||
|
// ErrorMessage encapsulates an error with an associated HTTP status code.
|
||||||
type ErrorMessage struct {
|
type ErrorMessage struct {
|
||||||
StatusCode int
|
StatusCode int
|
||||||
Error error
|
Error error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GCPProject represents the response structure for a Google Cloud project list request.
|
||||||
type GCPProject struct {
|
type GCPProject struct {
|
||||||
Projects []GCPProjectProjects `json:"projects"`
|
Projects []GCPProjectProjects `json:"projects"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GCPProjectLabels defines the labels associated with a GCP project.
|
||||||
type GCPProjectLabels struct {
|
type GCPProjectLabels struct {
|
||||||
GenerativeLanguage string `json:"generative-language"`
|
GenerativeLanguage string `json:"generative-language"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GCPProjectProjects contains details about a single Google Cloud project.
|
||||||
type GCPProjectProjects struct {
|
type GCPProjectProjects struct {
|
||||||
ProjectNumber string `json:"projectNumber"`
|
ProjectNumber string `json:"projectNumber"`
|
||||||
ProjectID string `json:"projectId"`
|
ProjectID string `json:"projectId"`
|
||||||
@@ -22,12 +28,14 @@ type GCPProjectProjects struct {
|
|||||||
CreateTime time.Time `json:"createTime"`
|
CreateTime time.Time `json:"createTime"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Content represents a single message in a conversation, with a role and parts.
|
||||||
type Content struct {
|
type Content struct {
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
Parts []Part `json:"parts"`
|
Parts []Part `json:"parts"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Part represents a single part of a message's content.
|
// Part represents a distinct piece of content within a message, which can be
|
||||||
|
// text, inline data (like an image), a function call, or a function response.
|
||||||
type Part struct {
|
type Part struct {
|
||||||
Text string `json:"text,omitempty"`
|
Text string `json:"text,omitempty"`
|
||||||
InlineData *InlineData `json:"inlineData,omitempty"`
|
InlineData *InlineData `json:"inlineData,omitempty"`
|
||||||
@@ -35,46 +43,48 @@ type Part struct {
|
|||||||
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
FunctionResponse *FunctionResponse `json:"functionResponse,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// InlineData represents base64-encoded data with its MIME type.
|
||||||
type InlineData struct {
|
type InlineData struct {
|
||||||
MimeType string `json:"mime_type,omitempty"`
|
MimeType string `json:"mime_type,omitempty"`
|
||||||
Data string `json:"data,omitempty"`
|
Data string `json:"data,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// FunctionCall represents a tool call requested by the model.
|
// FunctionCall represents a tool call requested by the model, including the
|
||||||
|
// function name and its arguments.
|
||||||
type FunctionCall struct {
|
type FunctionCall struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Args map[string]interface{} `json:"args"`
|
Args map[string]interface{} `json:"args"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// FunctionResponse represents the result of a tool execution.
|
// FunctionResponse represents the result of a tool execution, sent back to the model.
|
||||||
type FunctionResponse struct {
|
type FunctionResponse struct {
|
||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Response map[string]interface{} `json:"response"`
|
Response map[string]interface{} `json:"response"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateContentRequest is the request payload for the streamGenerateContent endpoint.
|
// GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint.
|
||||||
type GenerateContentRequest struct {
|
type GenerateContentRequest struct {
|
||||||
Contents []Content `json:"contents"`
|
Contents []Content `json:"contents"`
|
||||||
Tools []ToolDeclaration `json:"tools,omitempty"`
|
Tools []ToolDeclaration `json:"tools,omitempty"`
|
||||||
GenerationConfig `json:"generationConfig"`
|
GenerationConfig `json:"generationConfig"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerationConfig defines model generation parameters.
|
// GenerationConfig defines parameters that control the model's generation behavior.
|
||||||
type GenerationConfig struct {
|
type GenerationConfig struct {
|
||||||
ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"`
|
ThinkingConfig GenerationConfigThinkingConfig `json:"thinkingConfig,omitempty"`
|
||||||
Temperature float64 `json:"temperature,omitempty"`
|
Temperature float64 `json:"temperature,omitempty"`
|
||||||
TopP float64 `json:"topP,omitempty"`
|
TopP float64 `json:"topP,omitempty"`
|
||||||
TopK float64 `json:"topK,omitempty"`
|
TopK float64 `json:"topK,omitempty"`
|
||||||
// Temperature, TopP, TopK, etc. can be added here.
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GenerationConfigThinkingConfig specifies configuration for the model's "thinking" process.
|
||||||
type GenerationConfigThinkingConfig struct {
|
type GenerationConfigThinkingConfig struct {
|
||||||
|
// IncludeThoughts determines whether the model should output its reasoning process.
|
||||||
IncludeThoughts bool `json:"include_thoughts,omitempty"`
|
IncludeThoughts bool `json:"include_thoughts,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// ToolDeclaration is the structure for declaring tools to the API.
|
// ToolDeclaration defines the structure for declaring tools (like functions)
|
||||||
// For now, we'll assume a simple structure. A more complete implementation
|
// that the model can call.
|
||||||
// would mirror the OpenAPI schema definition.
|
|
||||||
type ToolDeclaration struct {
|
type ToolDeclaration struct {
|
||||||
FunctionDeclarations []interface{} `json:"functionDeclarations"`
|
FunctionDeclarations []interface{} `json:"functionDeclarations"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// DoLogin handles the entire user login and setup process.
|
||||||
|
// It authenticates the user, sets up the user's project, checks API enablement,
|
||||||
|
// and saves the token for future use.
|
||||||
func DoLogin(cfg *config.Config, projectID string) {
|
func DoLogin(cfg *config.Config, projectID string) {
|
||||||
var err error
|
var err error
|
||||||
var ts auth.TokenStorage
|
var ts auth.TokenStorage
|
||||||
@@ -16,9 +19,8 @@ func DoLogin(cfg *config.Config, projectID string) {
|
|||||||
ts.ProjectID = projectID
|
ts.ProjectID = projectID
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. Initialize authenticated HTTP Client
|
// Initialize an authenticated HTTP client. This will trigger the OAuth flow if necessary.
|
||||||
clientCtx := context.Background()
|
clientCtx := context.Background()
|
||||||
|
|
||||||
log.Info("Initializing authentication...")
|
log.Info("Initializing authentication...")
|
||||||
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
||||||
if errGetClient != nil {
|
if errGetClient != nil {
|
||||||
@@ -27,51 +29,57 @@ func DoLogin(cfg *config.Config, projectID string) {
|
|||||||
}
|
}
|
||||||
log.Info("Authentication successful.")
|
log.Info("Authentication successful.")
|
||||||
|
|
||||||
// 3. Initialize CLI Client
|
// Initialize the API client.
|
||||||
cliClient := client.NewClient(httpClient, &ts, cfg)
|
cliClient := client.NewClient(httpClient, &ts, cfg)
|
||||||
|
|
||||||
|
// Perform the user setup process.
|
||||||
err = cliClient.SetupUser(clientCtx, ts.Email, projectID)
|
err = cliClient.SetupUser(clientCtx, ts.Email, projectID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
// Handle the specific case where a project ID is required but not provided.
|
||||||
if err.Error() == "failed to start user onboarding, need define a project id" {
|
if err.Error() == "failed to start user onboarding, need define a project id" {
|
||||||
log.Error("failed to start user onboarding")
|
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||||
|
// Fetch and display the user's available projects to help them choose one.
|
||||||
project, errGetProjectList := cliClient.GetProjectList(clientCtx)
|
project, errGetProjectList := cliClient.GetProjectList(clientCtx)
|
||||||
if errGetProjectList != nil {
|
if errGetProjectList != nil {
|
||||||
log.Fatalf("failed to complete user setup: %v", err)
|
log.Fatalf("Failed to get project list: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Infof("Your account %s needs specify a project id.", ts.Email)
|
log.Infof("Your account %s needs to specify a project ID.", ts.Email)
|
||||||
log.Info("========================================================================")
|
log.Info("========================================================================")
|
||||||
for i := 0; i < len(project.Projects); i++ {
|
for _, p := range project.Projects {
|
||||||
log.Infof("Project ID: %s", project.Projects[i].ProjectID)
|
log.Infof("Project ID: %s", p.ProjectID)
|
||||||
log.Infof("Project Name: %s", project.Projects[i].Name)
|
log.Infof("Project Name: %s", p.Name)
|
||||||
log.Info("========================================================================")
|
log.Info("------------------------------------------------------------------------")
|
||||||
}
|
}
|
||||||
log.Infof("Please run this command to login again:\n\n%s --login --project_id <project_id>\n", os.Args[0])
|
log.Infof("Please run this command to login again with a specific project:\n\n%s --login --project_id <project_id>\n", os.Args[0])
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Log as a warning because in some cases, the CLI might still be usable
|
log.Fatalf("Failed to complete user setup: %v", err)
|
||||||
// or the user might want to retry setup later.
|
|
||||||
log.Fatalf("failed to complete user setup: %v", err)
|
|
||||||
}
|
}
|
||||||
} else {
|
return // Exit after handling the error.
|
||||||
auto := projectID == ""
|
}
|
||||||
cliClient.SetIsAuto(auto)
|
|
||||||
|
|
||||||
if !cliClient.IsChecked() && !cliClient.IsAuto() {
|
// If setup is successful, proceed to check API status and save the token.
|
||||||
isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled()
|
auto := projectID == ""
|
||||||
if checkErr != nil {
|
cliClient.SetIsAuto(auto)
|
||||||
log.Fatalf("failed to check cloud api is enabled: %v", checkErr)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cliClient.SetIsChecked(isChecked)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !cliClient.IsChecked() && !cliClient.IsAuto() {
|
// If the project was not automatically selected, check if the Cloud AI API is enabled.
|
||||||
|
if !cliClient.IsChecked() && !cliClient.IsAuto() {
|
||||||
|
isChecked, checkErr := cliClient.CheckCloudAPIIsEnabled()
|
||||||
|
if checkErr != nil {
|
||||||
|
log.Fatalf("Failed to check if Cloud AI API is enabled: %v", checkErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
cliClient.SetIsChecked(isChecked)
|
||||||
err = cliClient.SaveTokenToFile()
|
// If the check fails (returns false), the CheckCloudAPIIsEnabled function
|
||||||
if err != nil {
|
// will have already printed instructions, so we can just exit.
|
||||||
log.Fatal(err)
|
if !isChecked {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Save the successfully obtained and verified token to a file.
|
||||||
|
err = cliClient.SaveTokenToFile()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Failed to save token to file: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,20 +18,25 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// StartService initializes and starts the main API proxy service.
|
||||||
|
// It loads all available authentication tokens, creates a pool of clients,
|
||||||
|
// starts the API server, and handles graceful shutdown signals.
|
||||||
func StartService(cfg *config.Config) {
|
func StartService(cfg *config.Config) {
|
||||||
// Create API server configuration
|
// Configure the API server based on the main application config.
|
||||||
apiConfig := &api.ServerConfig{
|
apiConfig := &api.ServerConfig{
|
||||||
Port: fmt.Sprintf("%d", cfg.Port),
|
Port: fmt.Sprintf("%d", cfg.Port),
|
||||||
Debug: cfg.Debug,
|
Debug: cfg.Debug,
|
||||||
ApiKeys: cfg.ApiKeys,
|
ApiKeys: cfg.ApiKeys,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a pool of API clients, one for each token file found.
|
||||||
cliClients := make([]*client.Client, 0)
|
cliClients := make([]*client.Client, 0)
|
||||||
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
|
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Process only JSON files in the auth directory.
|
||||||
if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") {
|
if !info.IsDir() && strings.HasSuffix(info.Name(), ".json") {
|
||||||
log.Debugf("Loading token from: %s", path)
|
log.Debugf("Loading token from: %s", path)
|
||||||
f, errOpen := os.Open(path)
|
f, errOpen := os.Open(path)
|
||||||
@@ -42,58 +47,62 @@ func StartService(cfg *config.Config) {
|
|||||||
_ = f.Close()
|
_ = f.Close()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Decode the token storage file.
|
||||||
var ts auth.TokenStorage
|
var ts auth.TokenStorage
|
||||||
if err = json.NewDecoder(f).Decode(&ts); err == nil {
|
if err = json.NewDecoder(f).Decode(&ts); err == nil {
|
||||||
// 2. Initialize authenticated HTTP Client
|
// For each valid token, create an authenticated client.
|
||||||
clientCtx := context.Background()
|
clientCtx := context.Background()
|
||||||
|
log.Info("Initializing authentication for token...")
|
||||||
log.Info("Initializing authentication...")
|
|
||||||
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
httpClient, errGetClient := auth.GetAuthenticatedClient(clientCtx, &ts, cfg)
|
||||||
if errGetClient != nil {
|
if errGetClient != nil {
|
||||||
log.Fatalf("failed to get authenticated client: %v", errGetClient)
|
// Log fatal will exit, but we return the error for completeness.
|
||||||
|
log.Fatalf("failed to get authenticated client for token %s: %v", path, errGetClient)
|
||||||
return errGetClient
|
return errGetClient
|
||||||
}
|
}
|
||||||
log.Info("Authentication successful.")
|
log.Info("Authentication successful.")
|
||||||
|
|
||||||
// 3. Initialize CLI Client
|
// Add the new client to the pool.
|
||||||
cliClient := client.NewClient(httpClient, &ts, cfg)
|
cliClient := client.NewClient(httpClient, &ts, cfg)
|
||||||
cliClients = append(cliClients, cliClient)
|
cliClients = append(cliClients, cliClient)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("Error walking auth directory: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Create API server
|
// Create and start the API server with the pool of clients.
|
||||||
apiServer := api.NewServer(apiConfig, cliClients)
|
apiServer := api.NewServer(apiConfig, cliClients)
|
||||||
log.Infof("Starting API server on port %s", apiConfig.Port)
|
log.Infof("Starting API server on port %s", apiConfig.Port)
|
||||||
if err = apiServer.Start(); err != nil {
|
if err = apiServer.Start(); err != nil {
|
||||||
log.Fatalf("API server failed to start: %v", err)
|
log.Fatalf("API server failed to start: %v", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set up graceful shutdown
|
// Set up a channel to listen for OS signals for graceful shutdown.
|
||||||
sigChan := make(chan os.Signal, 1)
|
sigChan := make(chan os.Signal, 1)
|
||||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
|
||||||
|
// Main loop to wait for shutdown signal.
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-sigChan:
|
case <-sigChan:
|
||||||
log.Debugf("Received shutdown signal. Cleaning up...")
|
log.Debugf("Received shutdown signal. Cleaning up...")
|
||||||
|
|
||||||
// Create shutdown context
|
// Create a context with a timeout for the shutdown process.
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
_ = ctx // Mark ctx as used to avoid error, as apiServer.Stop(ctx) is commented out
|
_ = cancel
|
||||||
|
|
||||||
// Stop API server
|
// Stop the API server gracefully.
|
||||||
if err = apiServer.Stop(ctx); err != nil {
|
if err = apiServer.Stop(ctx); err != nil {
|
||||||
log.Debugf("Error stopping API server: %v", err)
|
log.Debugf("Error stopping API server: %v", err)
|
||||||
}
|
}
|
||||||
cancel()
|
|
||||||
|
|
||||||
log.Debugf("Cleanup completed. Exiting...")
|
log.Debugf("Cleanup completed. Exiting...")
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
case <-time.After(5 * time.Second):
|
case <-time.After(5 * time.Second):
|
||||||
|
// This case is currently empty and acts as a periodic check.
|
||||||
|
// It could be used for periodic tasks in the future.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,33 +6,35 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Config represents the application's configuration
|
// Config represents the application's configuration, loaded from a YAML file.
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Port int `yaml:"port"`
|
// Port is the network port on which the API server will listen.
|
||||||
AuthDir string `yaml:"auth_dir"`
|
Port int `yaml:"port"`
|
||||||
Debug bool `yaml:"debug"`
|
// AuthDir is the directory where authentication token files are stored.
|
||||||
ProxyUrl string `yaml:"proxy-url"`
|
AuthDir string `yaml:"auth_dir"`
|
||||||
ApiKeys []string `yaml:"api_keys"`
|
// Debug enables or disables debug-level logging and other debug features.
|
||||||
|
Debug bool `yaml:"debug"`
|
||||||
|
// ProxyUrl is the URL of an optional proxy server to use for outbound requests.
|
||||||
|
ProxyUrl string `yaml:"proxy-url"`
|
||||||
|
// ApiKeys is a list of keys for authenticating clients to this proxy server.
|
||||||
|
ApiKeys []string `yaml:"api_keys"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// / LoadConfig loads the configuration from the specified file
|
// LoadConfig reads a YAML configuration file from the given path,
|
||||||
|
// unmarshals it into a Config struct, and returns it.
|
||||||
func LoadConfig(configFile string) (*Config, error) {
|
func LoadConfig(configFile string) (*Config, error) {
|
||||||
// Read the configuration file
|
// Read the entire configuration file into memory.
|
||||||
data, err := os.ReadFile(configFile)
|
data, err := os.ReadFile(configFile)
|
||||||
// If reading the file fails
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Return an error
|
|
||||||
return nil, fmt.Errorf("failed to read config file: %w", err)
|
return nil, fmt.Errorf("failed to read config file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse the YAML data
|
// Unmarshal the YAML data into the Config struct.
|
||||||
var config Config
|
var config Config
|
||||||
// If parsing the YAML data fails
|
|
||||||
if err = yaml.Unmarshal(data, &config); err != nil {
|
if err = yaml.Unmarshal(data, &config); err != nil {
|
||||||
// Return an error
|
|
||||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return the configuration
|
// Return the populated configuration struct.
|
||||||
return &config, nil
|
return &config, nil
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user