Compare commits

...

10 Commits

Author SHA1 Message Date
Luis Pater
3c4dc07980 Add file watcher for dynamic configuration and client reloading
- Introduced `Watcher` for monitoring updates to the configuration file and authentication directory.
- Integrated file watching into `StartService` to handle dynamic changes without restarting.
- Enhanced API server and handlers to support client and configuration updates.
- Updated `.gitignore` to include `docs/` directory.
- Modified go dependencies to include `fsnotify` for the file watcher.
2025-08-02 16:15:56 +08:00
Luis Pater
3b4634e2dc Improve getClient logic with optional content generation flag
- Added `isGenerateContent` optional parameter to `getClient` for conditional client selection.
- Updated `gemini-handlers` to utilize the new parameter for enhanced control.
2025-07-27 02:30:08 +08:00
Luis Pater
00bd6a3e46 Update .goreleaser.yml to include config.example.yaml instead of config.yaml in release assets 2025-07-26 22:19:33 +08:00
Luis Pater
5812229d9b Add .gitignore and ignore config.yaml 2025-07-26 22:10:07 +08:00
Luis Pater
0b026933a7 Update example configuration file (config.example.yaml) 2025-07-26 22:08:25 +08:00
Luis Pater
3b2ab0d7bd Fix SSE headers initialization for geminiStreamGenerateContent and internalStreamGenerateContent
- Added conditional logic to properly initialize SSE headers only when `alt` is empty.
- Ensured headers like `Content-Type`, `Cache-Control`, and `Access-Control-Allow-Origin` are set for better compatibility.
2025-07-26 17:16:55 +08:00
Luis Pater
e64fa48823 Enhance Gemini request handling with fallback support for contents
- Added conditional logic to support `contents` as a fallback to `generateContentRequest`.
- Improved template construction and ensured proper cleanup of request fields.
- Introduced debug logging for troubleshooting request generation.
2025-07-26 17:04:14 +08:00
Luis Pater
beff9282f6 Fix alt parameter handling in URL construction
- Ensured `alt` parameter is only appended when non-empty.
- Added debug logging for constructed URLs.
2025-07-26 15:51:04 +08:00
Luis Pater
31a9e2d11f Add GeminiGetHandler, enhance Gemini functionality, and enable token counting
- Added `GeminiGetHandler` for handling GET requests with extended Gemini model support.
- Introduced `geminiCountTokens` function to calculate token usage.
- Refactored `APIRequest` and related methods to support `alt` parameter for enhanced flexibility.
- Updated routes and request processing to integrate new handler and functions.
2025-07-26 06:51:49 +08:00
Luis Pater
423faae3da Add GeminiModels handler and enhance API key validation
- Introduced `GeminiModels` handler to serve Gemini model information under `/v1beta/models`.
- Updated `AuthMiddleware` to validate API keys from query parameters for improved flexibility.
- Adjusted route to use the new handler for model retrieval.
2025-07-26 04:41:55 +08:00
13 changed files with 648 additions and 59 deletions

2
.gitignore vendored Normal file
View File

@@ -0,0 +1,2 @@
config.yaml
docs/

View File

@@ -15,4 +15,4 @@ archives:
- LICENSE
- README.md
- README_CN.md
- config.yaml
- config.example.yaml

View File

@@ -63,14 +63,17 @@ func main() {
var wd string
// Load configuration from the specified path or the default path.
var configFilePath string
if configPath != "" {
configFilePath = configPath
cfg, err = config.LoadConfig(configPath)
} else {
wd, err = os.Getwd()
if err != nil {
log.Fatalf("failed to get working directory: %v", err)
}
cfg, err = config.LoadConfig(path.Join(wd, "config.yaml"))
configFilePath = path.Join(wd, "config.yaml")
cfg, err = config.LoadConfig(configFilePath)
}
if err != nil {
log.Fatalf("failed to load config: %v", err)
@@ -102,6 +105,6 @@ func main() {
if login {
cmd.DoLogin(cfg, projectID)
} else {
cmd.StartService(cfg)
cmd.StartService(cfg, configFilePath)
}
}

3
go.mod
View File

@@ -8,6 +8,7 @@ require (
github.com/skratchdot/open-golang v0.0.0-20200116055534-eef842397966
github.com/tidwall/gjson v1.18.0
github.com/tidwall/sjson v1.2.5
golang.org/x/net v0.37.1-0.20250305215238-2914f4677317
golang.org/x/oauth2 v0.30.0
gopkg.in/yaml.v3 v3.0.1
)
@@ -18,6 +19,7 @@ require (
github.com/bytedance/sonic/loader v0.1.1 // indirect
github.com/cloudwego/base64x v0.1.4 // indirect
github.com/cloudwego/iasm v0.2.0 // indirect
github.com/fsnotify/fsnotify v1.9.0 // indirect
github.com/gabriel-vasile/mimetype v1.4.3 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect
github.com/go-playground/locales v0.14.1 // indirect
@@ -37,7 +39,6 @@ require (
github.com/ugorji/go/codec v1.2.12 // indirect
golang.org/x/arch v0.8.0 // indirect
golang.org/x/crypto v0.36.0 // indirect
golang.org/x/net v0.37.1-0.20250305215238-2914f4677317 // indirect
golang.org/x/sys v0.31.0 // indirect
golang.org/x/text v0.23.0 // indirect
google.golang.org/protobuf v1.34.1 // indirect

2
go.sum
View File

@@ -11,6 +11,8 @@ github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQ
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k=
github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk=
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=

View File

@@ -99,6 +99,15 @@ func (h *APIHandlers) CLIHandler(c *gin.Context) {
}
func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) {
alt := h.getAlt(c)
if alt == "" {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
@@ -141,7 +150,7 @@ outLoop:
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson)
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, "")
hasFirstResponse := false
for {
select {
@@ -220,7 +229,7 @@ func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendRawMessage(cliCtx, rawJson)
resp, err := cliClient.SendRawMessage(cliCtx, rawJson, "")
if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue

View File

@@ -14,11 +14,27 @@ import (
"time"
)
func (h *APIHandlers) GeminiHandler(c *gin.Context) {
var person struct {
// GeminiModels serves a fixed, hard-coded JSON listing of the supported
// Gemini models (gemini-2.5-flash and gemini-2.5-pro). The payload is kept as
// a sequence of string fragments that are written to the response in order.
func (h *APIHandlers) GeminiModels(c *gin.Context) {
	// Fragments of the static response body; concatenated they form one JSON document.
	fragments := []string{
		`{"models":[{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini `,
		`2.5 Flash","description":"Stable version of Gemini 2.5 Flash, our mid-size multimod`,
		`al model that supports up to 1 million tokens, released in June of 2025.","inputTok`,
		`enLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateCo`,
		`ntent","countTokens","createCachedContent","batchGenerateContent"],"temperature":1,`,
		`"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true},{"name":"models/gemini-2.`,
		`5-pro","version":"2.5","displayName":"Gemini 2.5 Pro","description":"Stable release`,
		` (June 17th, 2025) of Gemini 2.5 Pro","inputTokenLimit":1048576,"outputTokenLimit":`,
		`65536,"supportedGenerationMethods":["generateContent","countTokens","createCachedCo`,
		`ntent","batchGenerateContent"],"temperature":1,"topP":0.95,"topK":64,"maxTemperatur`,
		`e":2,"thinking":true}],"nextPageToken":""}`,
	}

	c.Status(http.StatusOK)
	c.Header("Content-Type", "application/json; charset=UTF-8")
	for _, fragment := range fragments {
		// Write errors are deliberately ignored: the payload is static and
		// nothing can be done once the client has gone away.
		_, _ = c.Writer.Write([]byte(fragment))
	}
}
func (h *APIHandlers) GeminiGetHandler(c *gin.Context) {
var request struct {
Action string `uri:"action" binding:"required"`
}
if err := c.ShouldBindUri(&person); err != nil {
if err := c.ShouldBindUri(&request); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
@@ -27,7 +43,45 @@ func (h *APIHandlers) GeminiHandler(c *gin.Context) {
})
return
}
action := strings.Split(person.Action, ":")
if request.Action == "gemini-2.5-pro" {
c.Status(http.StatusOK)
c.Header("Content-Type", "application/json; charset=UTF-8")
_, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-pro","version":"2.5","displayName":"Gemini 2.5 Pro",`))
_, _ = c.Writer.Write([]byte(`"description":"Stable release (June 17th, 2025) of Gemini 2.5 Pro","inputTokenL`))
_, _ = c.Writer.Write([]byte(`imit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["generateC`))
_, _ = c.Writer.Write([]byte(`ontent","countTokens","createCachedContent","batchGenerateContent"],"temperatur`))
_, _ = c.Writer.Write([]byte(`e":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`))
} else if request.Action == "gemini-2.5-flash" {
c.Status(http.StatusOK)
c.Header("Content-Type", "application/json; charset=UTF-8")
_, _ = c.Writer.Write([]byte(`{"name":"models/gemini-2.5-flash","version":"001","displayName":"Gemini 2.5 Fla`))
_, _ = c.Writer.Write([]byte(`sh","description":"Stable version of Gemini 2.5 Flash, our mid-size multimodal `))
_, _ = c.Writer.Write([]byte(`model that supports up to 1 million tokens, released in June of 2025.","inputTo`))
_, _ = c.Writer.Write([]byte(`kenLimit":1048576,"outputTokenLimit":65536,"supportedGenerationMethods":["gener`))
_, _ = c.Writer.Write([]byte(`ateContent","countTokens","createCachedContent","batchGenerateContent"],"temper`))
_, _ = c.Writer.Write([]byte(`ature":1,"topP":0.95,"topK":64,"maxTemperature":2,"thinking":true}`))
} else {
c.Status(http.StatusNotFound)
_, _ = c.Writer.Write([]byte(
`{"error":{"message":"Not Found","code":404,"status":"NOT_FOUND"}}`,
))
}
}
func (h *APIHandlers) GeminiHandler(c *gin.Context) {
var request struct {
Action string `uri:"action" binding:"required"`
}
if err := c.ShouldBindUri(&request); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
action := strings.Split(request.Action, ":")
if len(action) != 2 {
c.JSON(http.StatusNotFound, ErrorResponse{
Error: ErrorDetail{
@@ -47,10 +101,21 @@ func (h *APIHandlers) GeminiHandler(c *gin.Context) {
h.geminiGenerateContent(c, rawJson)
} else if method == "streamGenerateContent" {
h.geminiStreamGenerateContent(c, rawJson)
} else if method == "countTokens" {
h.geminiCountTokens(c, rawJson)
}
}
func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte) {
alt := h.getAlt(c)
if alt == "" {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
@@ -118,7 +183,7 @@ outLoop:
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson)
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson, alt)
for {
select {
// Handle client disconnection.
@@ -135,14 +200,33 @@ outLoop:
return
} else {
if cliClient.GetGenerativeLanguageAPIKey() == "" {
responseResult := gjson.GetBytes(chunk, "response")
if responseResult.Exists() {
chunk = []byte(responseResult.Raw)
if alt == "" {
responseResult := gjson.GetBytes(chunk, "response")
if responseResult.Exists() {
chunk = []byte(responseResult.Raw)
}
} else {
chunkTemplate := "[]"
responseResult := gjson.ParseBytes(chunk)
if responseResult.IsArray() {
responseResultItems := responseResult.Array()
for i := 0; i < len(responseResultItems); i++ {
responseResultItem := responseResultItems[i]
if responseResultItem.Get("response").Exists() {
chunkTemplate, _ = sjson.SetRaw(chunkTemplate, "-1", responseResultItem.Get("response").Raw)
}
}
}
chunk = []byte(chunkTemplate)
}
}
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
if alt == "" {
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
} else {
_, _ = c.Writer.Write(chunk)
}
flusher.Flush()
}
// Handle errors from the backend.
@@ -165,9 +249,79 @@ outLoop:
}
}
// geminiCountTokens handles the Gemini `countTokens` action.
//
// It reads the model name from the top-level "model" field of rawJson, picks
// a client via getClient (passing false so the round-robin index is not
// advanced), forwards the request with SendRawTokenCount and writes the
// upstream answer to the response.
//
// For clients without a Generative Language API key the body is wrapped as
// {"request": ...}, taking "generateContentRequest" when present and falling
// back to "contents". The wrapped payload is rebuilt from the untouched
// rawJson on every attempt, so a retry after a 429 (when
// QuotaExceeded.SwitchProject is enabled) sends the same content as the first
// attempt instead of an already-wrapped, now-empty body.
func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {
	c.Header("Content-Type", "application/json")

	alt := h.getAlt(c)
	modelName := gjson.GetBytes(rawJson, "model").String()

	cliCtx, cliCancel := context.WithCancel(context.Background())
	// Release the request context on every exit path.
	defer cliCancel()

	var cliClient *client.Client
	defer func() {
		// NOTE(review): only the most recently selected client is unlocked here;
		// presumably getClient returns the client with RequestMutex held — confirm.
		if cliClient != nil {
			cliClient.RequestMutex.Unlock()
		}
	}()

	for {
		var errorResponse *client.ErrorMessage
		cliClient, errorResponse = h.getClient(modelName, false)
		if errorResponse != nil {
			c.Status(errorResponse.StatusCode)
			_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
			return
		}

		// Derive the payload for this attempt from the original body; rawJson
		// itself is never overwritten so retries start from the same input.
		payload := rawJson
		if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
			log.Debugf("Request use generative language API Key: %s", glAPIKey)
		} else {
			log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())

			template := `{"request":{}}`
			if generateContentRequest := gjson.GetBytes(rawJson, "generateContentRequest"); generateContentRequest.Exists() {
				template, _ = sjson.SetRaw(template, "request", generateContentRequest.Raw)
			} else if contents := gjson.GetBytes(rawJson, "contents"); contents.Exists() {
				template, _ = sjson.SetRaw(template, "request.contents", contents.Raw)
			}
			payload = []byte(template)
		}

		resp, err := cliClient.SendRawTokenCount(cliCtx, payload, alt)
		if err != nil {
			// Quota exhausted: try again with whatever client getClient hands out next.
			if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
				continue
			}
			c.Status(err.StatusCode)
			_, _ = c.Writer.Write([]byte(err.Error.Error()))
			return
		}

		// Responses obtained without an API key are wrapped in a "response"
		// envelope; unwrap it when present.
		if cliClient.GetGenerativeLanguageAPIKey() == "" {
			if responseResult := gjson.GetBytes(resp, "response"); responseResult.Exists() {
				resp = []byte(responseResult.Raw)
			}
		}
		_, _ = c.Writer.Write(resp)
		return
	}
}
func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "application/json")
alt := h.getAlt(c)
modelResult := gjson.GetBytes(rawJson, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
@@ -217,7 +371,7 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendRawMessage(cliCtx, rawJson)
resp, err := cliClient.SendRawMessage(cliCtx, rawJson, alt)
if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue
@@ -240,3 +394,16 @@ func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
}
}
}
// getAlt returns the requested alternate response format, read from the
// "alt" query parameter or, when that parameter is absent, from "$alt".
// The value "sse" is reported as the empty string, the same result as when
// neither parameter is supplied.
func (h *APIHandlers) getAlt(c *gin.Context) string {
	value, found := c.GetQuery("alt")
	if !found {
		value, _ = c.GetQuery("$alt")
	}
	if value == "sse" {
		return ""
	}
	return value
}

View File

@@ -36,6 +36,12 @@ func NewAPIHandlers(cliClients []*client.Client, cfg *config.Config) *APIHandler
}
}
// UpdateClients replaces the client pool and the configuration used by the
// handlers with the supplied values.
func (h *APIHandlers) UpdateClients(clients []*client.Client, cfg *config.Config) {
	h.cliClients, h.cfg = clients, cfg
}
// Models handles the /v1/models endpoint.
// It returns a hardcoded list of available AI models.
func (h *APIHandlers) Models(c *gin.Context) {
@@ -85,7 +91,7 @@ func (h *APIHandlers) Models(c *gin.Context) {
})
}
func (h *APIHandlers) getClient(modelName string) (*client.Client, *client.ErrorMessage) {
func (h *APIHandlers) getClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) {
if len(h.cliClients) == 0 {
return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")}
}
@@ -95,8 +101,10 @@ func (h *APIHandlers) getClient(modelName string) (*client.Client, *client.Error
// Lock the mutex to update the last used client index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 {
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
}
mutex.Unlock()
// Reorder the client to start from the last used index

View File

@@ -75,8 +75,9 @@ func (s *Server) setupRoutes() {
v1beta := s.engine.Group("/v1beta")
v1beta.Use(AuthMiddleware(s.cfg))
{
v1beta.GET("/models", s.handlers.Models)
v1beta.GET("/models", s.handlers.GeminiModels)
v1beta.POST("/models/:action", s.handlers.GeminiHandler)
v1beta.GET("/models/:action", s.handlers.GeminiGetHandler)
}
// Root endpoint
@@ -138,6 +139,13 @@ func corsMiddleware() gin.HandlerFunc {
}
}
// UpdateClients swaps in a new client list and configuration for the server
// and forwards both to the API handlers, then logs the new client count.
//
// NOTE(review): route groups are registered with AuthMiddleware(s.cfg) at
// setup time; confirm whether the middleware observes the replaced config.
func (s *Server) UpdateClients(clients []*client.Client, cfg *config.Config) {
	s.handlers.UpdateClients(clients, cfg)
	s.cfg = cfg

	clientCount := len(clients)
	log.Infof("server clients and configuration updated: %d clients", clientCount)
}
// AuthMiddleware returns a Gin middleware handler that authenticates requests
// using API keys. If no API keys are configured, it allows all requests.
func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
@@ -151,7 +159,11 @@ func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
authHeader := c.GetHeader("Authorization")
authHeaderGoogle := c.GetHeader("X-Goog-Api-Key")
authHeaderAnthropic := c.GetHeader("X-Api-Key")
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" {
// Get the API key from the query parameter
apiKeyQuery, _ := c.GetQuery("key")
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && apiKeyQuery == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "Missing API key",
})
@@ -170,7 +182,7 @@ func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
// Find the API key in the in-memory list
var foundKey string
for i := range cfg.ApiKeys {
if cfg.ApiKeys[i] == apiKey || cfg.ApiKeys[i] == authHeaderGoogle || cfg.ApiKeys[i] == authHeaderAnthropic {
if cfg.ApiKeys[i] == apiKey || cfg.ApiKeys[i] == authHeaderGoogle || cfg.ApiKeys[i] == authHeaderAnthropic || cfg.ApiKeys[i] == apiKeyQuery {
foundKey = cfg.ApiKeys[i]
break
}

View File

@@ -28,7 +28,7 @@ const (
apiVersion = "v1internal"
pluginVersion = "0.1.9"
glEndPoint = "https://generativelanguage.googleapis.com/"
glEndPoint = "https://generativelanguage.googleapis.com"
glApiVersion = "v1beta"
)
@@ -241,7 +241,7 @@ func (c *Client) makeAPIRequest(ctx context.Context, endpoint, method string, bo
}
// APIRequest handles making requests to the CLI API endpoints.
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, stream bool) (io.ReadCloser, *ErrorMessage) {
func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface{}, alt string, stream bool) (io.ReadCloser, *ErrorMessage) {
var jsonBody []byte
var err error
if byteBody, ok := body.([]byte); ok {
@@ -257,25 +257,39 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
if c.glAPIKey == "" {
// Add alt=sse for streaming
url = fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, apiVersion, endpoint)
if stream {
if alt == "" && stream {
url = url + "?alt=sse"
} else {
if alt != "" {
url = url + fmt.Sprintf("?$alt=%s", alt)
}
}
} else {
modelResult := gjson.GetBytes(jsonBody, "model")
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
if stream {
url = url + "?alt=sse"
}
jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction")
if systemInstructionResult.Exists() {
jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw))
jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction")
jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id")
if endpoint == "countTokens" {
modelResult := gjson.GetBytes(jsonBody, "model")
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
} else {
modelResult := gjson.GetBytes(jsonBody, "model")
url = fmt.Sprintf("%s/%s/models/%s:%s", glEndPoint, glApiVersion, modelResult.String(), endpoint)
if alt == "" && stream {
url = url + "?alt=sse"
} else {
if alt != "" {
url = url + fmt.Sprintf("?$alt=%s", alt)
}
}
jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction")
if systemInstructionResult.Exists() {
jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw))
jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction")
jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id")
}
}
}
// log.Debug(string(jsonBody))
// log.Debug(url)
reqBody := bytes.NewBuffer(jsonBody)
req, err := http.NewRequestWithContext(ctx, "POST", url, reqBody)
@@ -392,7 +406,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
}
}
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, false)
respBody, err := c.APIRequest(ctx, "generateContent", byteRequestBody, "", false)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
@@ -544,7 +558,7 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
// Attempt to establish a streaming connection with the API
var err *ErrorMessage
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, true)
stream, err = c.APIRequest(ctx, "streamGenerateContent", byteRequestBody, "", true)
if err != nil {
// Handle quota exceeded errors by marking the model and potentially retrying
if err.StatusCode == 429 {
@@ -593,8 +607,49 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
return dataChan, errChan
}
// SendRawTokenCount sends a countTokens request built from rawJson and
// returns the raw response body.
//
// The model is read from the top-level "model" field of rawJson. If that
// model is currently marked as quota exceeded and preview-model switching is
// enabled (and no Generative Language API key is in use), the request is
// retried with the preview model returned by getPreviewModel; otherwise a 429
// ErrorMessage is returned. A 429 from the upstream marks the model as quota
// exceeded; any other error is returned unchanged. A failure while reading
// the response body is reported as a 500 ErrorMessage.
func (c *Client) SendRawTokenCount(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) {
	modelResult := gjson.GetBytes(rawJson, "model")
	model := modelResult.String()
	modelName := model
	for {
		if c.isModelQuotaExceeded(modelName) {
			if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
				modelName = c.getPreviewModel(model)
				if modelName != "" {
					log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
					rawJson, _ = sjson.SetBytes(rawJson, "model", modelName)
					continue
				}
			}
			return nil, &ErrorMessage{
				StatusCode: 429,
				Error:      fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
			}
		}

		respBody, err := c.APIRequest(ctx, "countTokens", rawJson, alt, false)
		if err != nil {
			if err.StatusCode == 429 {
				// Remember when the quota was hit so later calls can skip this model.
				now := time.Now()
				c.modelQuotaExceeded[modelName] = &now
				if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
					continue
				}
			}
			return nil, err
		}

		// The request succeeded, so the model is no longer considered exhausted.
		delete(c.modelQuotaExceeded, modelName)

		bodyBytes, errReadAll := io.ReadAll(respBody)
		// Always release the response body, whether or not the read succeeded;
		// leaving it open leaks the underlying reader.
		_ = respBody.Close()
		if errReadAll != nil {
			return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
		}
		return bodyBytes, nil
	}
}
// SendRawMessage handles a single conversational turn, including tool calls.
func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *ErrorMessage) {
func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte, alt string) ([]byte, *ErrorMessage) {
if c.glAPIKey == "" {
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
}
@@ -618,7 +673,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *E
}
}
respBody, err := c.APIRequest(ctx, "generateContent", rawJson, false)
respBody, err := c.APIRequest(ctx, "generateContent", rawJson, alt, false)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
@@ -639,7 +694,7 @@ func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *E
}
// SendRawMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-chan []byte, <-chan *ErrorMessage) {
func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte, alt string) (<-chan []byte, <-chan *ErrorMessage) {
dataTag := []byte("data: ")
errChan := make(chan *ErrorMessage)
dataChan := make(chan []byte)
@@ -672,7 +727,7 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-ch
return
}
var err *ErrorMessage
stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, true)
stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, alt, true)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
@@ -688,21 +743,32 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-ch
break
}
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Bytes()
if bytes.HasPrefix(line, dataTag) {
dataChan <- line[6:]
if alt == "" {
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Bytes()
if bytes.HasPrefix(line, dataTag) {
dataChan <- line[6:]
}
}
}
if errScanner := scanner.Err(); errScanner != nil {
errChan <- &ErrorMessage{500, errScanner}
_ = stream.Close()
return
}
if errScanner := scanner.Err(); errScanner != nil {
errChan <- &ErrorMessage{500, errScanner}
_ = stream.Close()
return
}
} else {
data, err := io.ReadAll(stream)
if err != nil {
errChan <- &ErrorMessage{500, err}
_ = stream.Close()
return
}
dataChan <- data
}
_ = stream.Close()
}()
return dataChan, errChan
@@ -754,7 +820,7 @@ func (c *Client) CheckCloudAPIIsEnabled() (bool, error) {
// A simple request to test the API endpoint.
requestBody := fmt.Sprintf(`{"project":"%s","request":{"contents":[{"role":"user","parts":[{"text":"Be concise. What is the capital of France?"}]}],"generationConfig":{"thinkingConfig":{"include_thoughts":false,"thinkingBudget":0}}},"model":"gemini-2.5-flash"}`, c.tokenStorage.ProjectID)
stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), true)
stream, err := c.APIRequest(ctx, "streamGenerateContent", []byte(requestBody), "", true)
if err != nil {
// If a 403 Forbidden error occurs, it likely means the API is not enabled.
if err.StatusCode == 403 {

View File

@@ -8,6 +8,7 @@ import (
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
"github.com/luispater/CLIProxyAPI/internal/util"
"github.com/luispater/CLIProxyAPI/internal/watcher"
log "github.com/sirupsen/logrus"
"io/fs"
"net/http"
@@ -22,7 +23,7 @@ import (
// StartService initializes and starts the main API proxy service.
// It loads all available authentication tokens, creates a pool of clients,
// starts the API server, and handles graceful shutdown signals.
func StartService(cfg *config.Config) {
func StartService(cfg *config.Config, configPath string) {
// Create a pool of API clients, one for each token file found.
cliClients := make([]*client.Client, 0)
err := filepath.Walk(cfg.AuthDir, func(path string, info fs.FileInfo, err error) error {
@@ -82,10 +83,46 @@ func StartService(cfg *config.Config) {
// Create and start the API server with the pool of clients.
apiServer := api.NewServer(cfg, cliClients)
log.Infof("Starting API server on port %d", cfg.Port)
if err = apiServer.Start(); err != nil {
log.Fatalf("API server failed to start: %v", err)
// Start the API server in a goroutine so it doesn't block the main thread
go func() {
if err = apiServer.Start(); err != nil {
log.Fatalf("API server failed to start: %v", err)
}
}()
// Give the server a moment to start up
time.Sleep(100 * time.Millisecond)
log.Info("API server started successfully")
// Setup file watcher for config and auth directory changes
fileWatcher, errNewWatcher := watcher.NewWatcher(configPath, cfg.AuthDir, func(newClients []*client.Client, newCfg *config.Config) {
// Update the API server with new clients and configuration
apiServer.UpdateClients(newClients, newCfg)
})
if errNewWatcher != nil {
log.Fatalf("failed to create file watcher: %v", errNewWatcher)
}
// Set initial state for the watcher
fileWatcher.SetConfig(cfg)
fileWatcher.SetClients(cliClients)
// Start the file watcher
watcherCtx, watcherCancel := context.WithCancel(context.Background())
if errStartWatcher := fileWatcher.Start(watcherCtx); errStartWatcher != nil {
log.Fatalf("failed to start file watcher: %v", errStartWatcher)
}
log.Info("file watcher started for config and auth directory changes")
defer func() {
watcherCancel()
errStopWatcher := fileWatcher.Stop()
if errStopWatcher != nil {
log.Errorf("error stopping file watcher: %v", errStopWatcher)
}
}()
// Set up a channel to listen for OS signals for graceful shutdown.
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)

282
internal/watcher/watcher.go Normal file
View File

@@ -0,0 +1,282 @@
package watcher
import (
"context"
"encoding/json"
"github.com/fsnotify/fsnotify"
"github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
"github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus"
"io/fs"
"net/http"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
// Watcher manages file watching for configuration and authentication files
type Watcher struct {
	// configPath is the configuration file that is watched and reloaded.
	configPath string
	// authDir is the directory watched for changes to .json auth files.
	authDir string
	// config is the most recently set or reloaded configuration.
	config *config.Config
	// clients is the client list most recently stored via SetClients.
	clients []*client.Client
	// clientsMutex guards both config and clients.
	clientsMutex sync.RWMutex
	// reloadCallback is supplied by the caller of NewWatcher; presumably it is
	// invoked with the new clients and configuration after a reload (the call
	// site is outside this view).
	reloadCallback func([]*client.Client, *config.Config)
	// watcher is the underlying fsnotify watcher delivering events and errors.
	watcher *fsnotify.Watcher
}
// NewWatcher builds a Watcher for the given configuration file and auth
// directory, storing reloadCallback for later use. It returns the error from
// fsnotify.NewWatcher if the underlying watcher cannot be created. Nothing is
// watched until Start is called.
func NewWatcher(configPath, authDir string, reloadCallback func([]*client.Client, *config.Config)) (*Watcher, error) {
	fsWatcher, err := fsnotify.NewWatcher()
	if err != nil {
		return nil, err
	}

	w := &Watcher{
		configPath:     configPath,
		authDir:        authDir,
		reloadCallback: reloadCallback,
		watcher:        fsWatcher,
	}
	return w, nil
}
// Start registers the configuration file and the authentication directory
// with the underlying fsnotify watcher and then launches the goroutine that
// processes events until ctx is cancelled or the watcher channels close.
// If either path cannot be registered the error is logged and returned, and
// the event goroutine is not started.
func (w *Watcher) Start(ctx context.Context) error {
	// Register the config file first.
	err := w.watcher.Add(w.configPath)
	if err != nil {
		log.Errorf("failed to watch config file %s: %v", w.configPath, err)
		return err
	}
	log.Debugf("watching config file: %s", w.configPath)

	// Then the directory holding the auth files.
	err = w.watcher.Add(w.authDir)
	if err != nil {
		log.Errorf("failed to watch auth directory %s: %v", w.authDir, err)
		return err
	}
	log.Debugf("watching auth directory: %s", w.authDir)

	// Events are consumed asynchronously.
	go w.processEvents(ctx)
	return nil
}
// Stop stops the file watcher by closing the underlying fsnotify watcher and
// returns the error, if any, reported by that Close call.
func (w *Watcher) Stop() error {
	return w.watcher.Close()
}
// SetConfig stores cfg as the watcher's current configuration while holding
// the write lock.
func (w *Watcher) SetConfig(cfg *config.Config) {
	w.clientsMutex.Lock()
	w.config = cfg
	w.clientsMutex.Unlock()
}
// SetClients replaces the client list held by the watcher.
// The slice is stored as given (not copied) while holding clientsMutex.
func (w *Watcher) SetClients(clients []*client.Client) {
	w.clientsMutex.Lock()
	w.clients = clients
	w.clientsMutex.Unlock()
}
// processEvents is the watcher's event loop.
//
// It forwards every file system event to handleEvent and logs every watcher
// error. The loop ends when ctx is cancelled or when either of the fsnotify
// channels has been closed.
func (w *Watcher) processEvents(ctx context.Context) {
	done := ctx.Done()
	for {
		select {
		case <-done:
			return

		case fsEvent, open := <-w.watcher.Events:
			if open {
				w.handleEvent(fsEvent)
				continue
			}
			return

		case watchErr, open := <-w.watcher.Errors:
			if open {
				log.Errorf("file watcher error: %v", watchErr)
				continue
			}
			return
		}
	}
}
// handleEvent dispatches a single fsnotify event to the matching reload routine.
//
//   - A Write or Create on the config file triggers reloadConfig.
//   - A Create, Write, Remove or Rename of a ".json" file located under the
//     auth directory triggers reloadClients.
//   - Everything else is ignored. In particular, Chmod-only events on auth
//     files no longer cause a full client reload: the fsnotify documentation
//     advises against acting on Chmod because indexers, anti-virus and backup
//     tools can emit it very frequently.
//
// Paths are compared in filepath.Clean form, and auth-directory membership is
// decided with filepath.Rel rather than a raw string prefix. This keeps
// configured values such as "./auth" or "auth/" matching event names that may
// be reported in cleaned form, and stops a sibling such as "auth-backup/x.json"
// from being mistaken for a file inside "auth".
//
// NOTE(review): a config file that is replaced via rename (atomic save) may
// only produce Rename/Remove events on the watched file and can drop the
// watch; that case is not handled here — confirm against fsnotify's guidance
// on watching individual files.
func (w *Watcher) handleEvent(event fsnotify.Event) {
	now := time.Now()
	log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)

	eventPath := filepath.Clean(event.Name)

	// Handle config file changes.
	const configOps = fsnotify.Write | fsnotify.Create
	if eventPath == filepath.Clean(w.configPath) && event.Op&configOps != 0 {
		log.Infof("config file changed, reloading: %s", w.configPath)
		log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000"))
		w.reloadConfig()
		return
	}

	// Handle auth directory changes (only for .json files).
	if !strings.HasSuffix(eventPath, ".json") {
		return
	}

	// Only operations that can change the set or content of token files matter;
	// this filters out Chmod-only notifications.
	const authOps = fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
	if event.Op&authOps == 0 {
		return
	}

	// The event must refer to a path inside authDir: a relative path that does
	// not climb out through "..".
	rel, errRel := filepath.Rel(w.authDir, eventPath)
	if errRel != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return
	}

	log.Infof("auth file changed (%s): %s, reloading clients", event.Op.String(), filepath.Base(event.Name))
	log.Debugf("auth file change details - operation: %s, file: %s, timestamp: %s",
		event.Op.String(), filepath.Base(event.Name), now.Format("2006-01-02 15:04:05.000"))
	w.reloadClients()
}
// reloadConfig re-reads the configuration file, stores the result on the
// watcher and then rebuilds the clients via reloadClients.
//
// If loading fails the error is logged and the previous configuration and
// clients are left untouched. When a previous configuration exists, the
// differences in a handful of fields are written to the debug log; list-valued
// settings are compared by length only.
func (w *Watcher) reloadConfig() {
	log.Debugf("starting config reload from: %s", w.configPath)

	loaded, errLoad := config.LoadConfig(w.configPath)
	if errLoad != nil {
		log.Errorf("failed to reload config: %v", errLoad)
		return
	}

	// Swap in the new configuration and keep the old one for the diff below.
	previous := func() *config.Config {
		w.clientsMutex.Lock()
		defer w.clientsMutex.Unlock()
		old := w.config
		w.config = loaded
		return old
	}()

	// Debug-only summary of what changed.
	if previous != nil {
		log.Debugf("config changes detected:")
		if loaded.Port != previous.Port {
			log.Debugf(" port: %d -> %d", previous.Port, loaded.Port)
		}
		if loaded.AuthDir != previous.AuthDir {
			log.Debugf(" auth-dir: %s -> %s", previous.AuthDir, loaded.AuthDir)
		}
		if loaded.Debug != previous.Debug {
			log.Debugf(" debug: %t -> %t", previous.Debug, loaded.Debug)
		}
		if loaded.ProxyUrl != previous.ProxyUrl {
			log.Debugf(" proxy-url: %s -> %s", previous.ProxyUrl, loaded.ProxyUrl)
		}
		if oldKeys, newKeys := len(previous.ApiKeys), len(loaded.ApiKeys); oldKeys != newKeys {
			log.Debugf(" api-keys count: %d -> %d", oldKeys, newKeys)
		}
		if oldGl, newGl := len(previous.GlAPIKey), len(loaded.GlAPIKey); oldGl != newGl {
			log.Debugf(" generative-language-api-key count: %d -> %d", oldGl, newGl)
		}
	}

	log.Infof("config successfully reloaded, triggering client reload")

	// The clients depend on the configuration, so rebuild them now.
	w.reloadClients()
}
// reloadClients rebuilds the complete client list from the current config.
//
// It walks cfg.AuthDir, and for every ".json" file that decodes as an
// auth.TokenStorage and authenticates successfully it creates a client. It then
// adds one client per configured Generative Language API key. The new list
// replaces w.clients and, if a reload callback is set, is handed to it together
// with the config.
//
// Files that cannot be opened, decoded or authenticated are logged and skipped.
// A nil config or an error returned from the directory walk aborts the reload
// and leaves the existing client list in place.
func (w *Watcher) reloadClients() {
	log.Debugf("starting client reload process")

	// Snapshot the shared state under the read lock.
	w.clientsMutex.RLock()
	cfg, previousCount := w.config, len(w.clients)
	w.clientsMutex.RUnlock()

	if cfg == nil {
		log.Error("config is nil, cannot reload clients")
		return
	}

	log.Debugf("scanning auth directory: %s", cfg.AuthDir)

	rebuilt := []*client.Client{}
	var jsonFilesSeen, authenticated int

	// visit handles one entry of the walk; each call has its own deferred close.
	visit := func(p string, info fs.FileInfo, walkErr error) error {
		if walkErr != nil {
			log.Debugf("error accessing path %s: %v", p, walkErr)
			return walkErr
		}

		// Only regular entries named *.json are token files.
		if info.IsDir() || !strings.HasSuffix(info.Name(), ".json") {
			return nil
		}

		jsonFilesSeen++
		base := filepath.Base(p)
		log.Debugf("processing auth file %d: %s", jsonFilesSeen, base)

		tokenFile, errOpen := os.Open(p)
		if errOpen != nil {
			log.Errorf("failed to open token file %s: %v", p, errOpen)
			return nil // keep going with the remaining files
		}
		defer func() {
			if errClose := tokenFile.Close(); errClose != nil {
				log.Errorf("failed to close token file %s: %v", p, errClose)
			}
		}()

		// Decode the token storage file.
		var ts auth.TokenStorage
		if errDecode := json.NewDecoder(tokenFile).Decode(&ts); errDecode != nil {
			log.Errorf(" failed to decode token file %s: %v", p, errDecode)
			return nil
		}

		// A valid token yields one authenticated client.
		log.Debugf(" initializing authentication for token from %s...", base)
		httpClient, errAuth := auth.GetAuthenticatedClient(context.Background(), &ts, cfg)
		if errAuth != nil {
			log.Errorf(" failed to get authenticated client for token %s: %v", p, errAuth)
			return nil // keep going with the remaining files
		}
		log.Debugf(" authentication successful for token from %s", base)

		rebuilt = append(rebuilt, client.NewClient(httpClient, &ts, cfg))
		authenticated++
		return nil
	}

	if errWalk := filepath.Walk(cfg.AuthDir, visit); errWalk != nil {
		log.Errorf("error walking auth directory: %v", errWalk)
		return
	}

	log.Debugf("auth directory scan complete - found %d .json files, %d successful authentications", jsonFilesSeen, authenticated)

	// One additional client per Generative Language API key, if any are configured.
	glClients := 0
	if keyCount := len(cfg.GlAPIKey); keyCount > 0 {
		log.Debugf("processing %d Generative Language API keys", keyCount)
		for idx := range cfg.GlAPIKey {
			ordinal := idx + 1
			httpClient, errProxy := util.SetProxy(cfg, &http.Client{})
			if errProxy != nil {
				log.Errorf("set proxy failed for GL API key %d: %v", ordinal, errProxy)
				continue
			}

			log.Debugf(" initializing with Generative Language API key %d...", ordinal)
			rebuilt = append(rebuilt, client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[idx]))
			glClients++
		}
		log.Debugf("successfully initialized %d Generative Language API key clients", glClients)
	}

	// Publish the new list.
	w.clientsMutex.Lock()
	w.clients = rebuilt
	w.clientsMutex.Unlock()

	log.Infof("client reload complete - old: %d clients, new: %d clients (%d auth files + %d GL API keys)",
		previousCount, len(rebuilt), authenticated, glClients)

	// Let the server pick up the new clients and config.
	if w.reloadCallback == nil {
		return
	}
	log.Debugf("triggering server update callback")
	w.reloadCallback(rebuilt, cfg)
}