From ef68a97526320e3e05ec1da0326ad6c9a7780c36 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Thu, 10 Jul 2025 17:45:28 +0800 Subject: [PATCH] Refactor API handlers and proxy logic - Centralized `getClient` logic into a dedicated function to reduce redundancy. - Moved proxy initialization to a new utility function `SetProxy` in `internal/util/proxy.go`. - Replaced `Internal` handler with `CLIHandler` in `server.go` for improved clarity and consistency. - Removed unused functions and redundant HTTP client setup across the codebase for better maintainability. --- internal/api/cli-handlers.go | 228 +++++++++++++++++++ internal/api/handlers.go | 420 +++++------------------------------ internal/api/server.go | 2 +- internal/cmd/run.go | 33 +-- internal/util/proxy.go | 37 +++ 5 files changed, 323 insertions(+), 397 deletions(-) create mode 100644 internal/api/cli-handlers.go create mode 100644 internal/util/proxy.go diff --git a/internal/api/cli-handlers.go b/internal/api/cli-handlers.go new file mode 100644 index 00000000..2f342e8b --- /dev/null +++ b/internal/api/cli-handlers.go @@ -0,0 +1,228 @@ +package api + +import ( + "bytes" + "context" + "fmt" + "github.com/gin-gonic/gin" + "github.com/luispater/CLIProxyAPI/internal/client" + "github.com/luispater/CLIProxyAPI/internal/util" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "io" + "net/http" + "time" +) + +func (h *APIHandlers) CLIHandler(c *gin.Context) { + rawJson, _ := c.GetRawData() + requestRawURI := c.Request.URL.Path + if requestRawURI == "/v1internal:generateContent" { + h.internalGenerateContent(c, rawJson) + } else if requestRawURI == "/v1internal:streamGenerateContent" { + h.internalStreamGenerateContent(c, rawJson) + } else { + reqBody := bytes.NewBuffer(rawJson) + req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody) + if err != nil { + c.JSON(http.StatusBadRequest, ErrorResponse{ + Error: ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + for key, value := range c.Request.Header { + req.Header[key] = value + } + + httpClient, err := util.SetProxy(h.cfg, &http.Client{}) + if err != nil { + log.Fatalf("set proxy failed: %v", err) + } + + resp, err := httpClient.Do(req) + if err != nil { + c.JSON(http.StatusBadRequest, ErrorResponse{ + Error: ErrorDetail{ + Message: fmt.Sprintf("Invalid request: %v", err), + Type: "invalid_request_error", + }, + }) + return + } + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + defer func() { + if err = resp.Body.Close(); err != nil { + log.Printf("warn: failed to close response body: %v", err) + } + }() + bodyBytes, _ := io.ReadAll(resp.Body) + + c.JSON(http.StatusBadRequest, ErrorResponse{ + Error: ErrorDetail{ + Message: string(bodyBytes), + Type: "invalid_request_error", + }, + }) + return + } + + defer func() { + _ = resp.Body.Close() + }() + + for key, value := range resp.Header { + c.Header(key, value[0]) + } + output, err := io.ReadAll(resp.Body) + if err != nil { + log.Errorf("Failed to read response body: %v", err) + return + } + _, _ = c.Writer.Write(output) + } +} + +func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) { + // Get the http.Flusher interface to manually flush the response. + flusher, ok := c.Writer.(http.Flusher) + if !ok { + c.JSON(http.StatusInternalServerError, ErrorResponse{ + Error: ErrorDetail{ + Message: "Streaming not supported", + Type: "server_error", + }, + }) + return + } + + modelResult := gjson.GetBytes(rawJson, "model") + modelName := modelResult.String() + + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient *client.Client + defer func() { + // Ensure the client's mutex is unlocked on function exit. + if cliClient != nil { + cliClient.RequestMutex.Unlock() + } + }() + +outLoop: + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.getClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + flusher.Flush() + cliCancel() + return + } + + if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use generative language API Key: %s", glAPIKey) + } else { + log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) + } + // Send the message and receive response chunks and errors via channels. + respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson) + hasFirstResponse := false + for { + select { + // Handle client disconnection. + case <-c.Request.Context().Done(): + if c.Request.Context().Err().Error() == "context canceled" { + log.Debugf("Client disconnected: %v", c.Request.Context().Err()) + cliCancel() // Cancel the backend request. + return + } + // Process incoming response chunks. + case chunk, okStream := <-respChan: + if !okStream { + cliCancel() + return + } else { + hasFirstResponse = true + if cliClient.GetGenerativeLanguageAPIKey() != "" { + chunk, _ = sjson.SetRawBytes(chunk, "response", chunk) + } + _, _ = c.Writer.Write([]byte("data: ")) + _, _ = c.Writer.Write(chunk) + _, _ = c.Writer.Write([]byte("\n\n")) + flusher.Flush() + } + // Handle errors from the backend. + case err, okError := <-errChan: + if okError { + if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + continue outLoop + } else { + c.Status(err.StatusCode) + _, _ = fmt.Fprint(c.Writer, err.Error.Error()) + flusher.Flush() + cliCancel() + } + return + } + // Send a keep-alive signal to the client. + case <-time.After(500 * time.Millisecond): + if hasFirstResponse { + _, _ = c.Writer.Write([]byte("\n")) + flusher.Flush() + } + } + } + } +} + +func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) { + c.Header("Content-Type", "application/json") + + modelResult := gjson.GetBytes(rawJson, "model") + modelName := modelResult.String() + cliCtx, cliCancel := context.WithCancel(context.Background()) + var cliClient *client.Client + defer func() { + if cliClient != nil { + cliClient.RequestMutex.Unlock() + } + }() + + for { + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.getClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) + cliCancel() + return + } + + if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { + log.Debugf("Request use generative language API Key: %s", glAPIKey) + } else { + log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) + } + + resp, err := cliClient.SendRawMessage(cliCtx, rawJson) + if err != nil { + if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { + continue + } else { + c.Status(err.StatusCode) + _, _ = c.Writer.Write([]byte(err.Error.Error())) + cliCancel() + } + break + } else { + _, _ = c.Writer.Write(resp) + cliCancel() + break + } + } +} diff --git a/internal/api/handlers.go b/internal/api/handlers.go index 93404459..f627914d 100644 --- a/internal/api/handlers.go +++ b/internal/api/handlers.go @@ -1,7 +1,6 @@ package api import ( - "bytes" "context" "fmt" "github.com/luispater/CLIProxyAPI/internal/api/translator" @@ -9,12 +8,7 @@ import ( "github.com/luispater/CLIProxyAPI/internal/config" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" - "github.com/tidwall/sjson" - "golang.org/x/net/proxy" - "io" - "net" "net/http" - "net/url" "sync" "time" @@ -171,6 +165,48 @@ func (h *APIHandlers) Models(c *gin.Context) { }) } +func (h *APIHandlers) getClient(modelName string) (*client.Client, *client.ErrorMessage) { + var cliClient *client.Client + + // Lock the mutex to update the last used client index + mutex.Lock() + startIndex := lastUsedClientIndex + currentIndex := (startIndex + 1) % len(h.cliClients) + lastUsedClientIndex = currentIndex + mutex.Unlock() + + // Reorder the client to start from the last used index + reorderedClients := make([]*client.Client, 0) + for i := 0; i < len(h.cliClients); i++ { + cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)] + if cliClient.IsModelQuotaExceeded(modelName) { + log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID()) + cliClient = nil + continue + } + reorderedClients = append(reorderedClients, cliClient) + } + + if len(reorderedClients) == 0 { + return nil, &client.ErrorMessage{StatusCode: 429, Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)} + } + + locked := false + for i := 0; i < len(reorderedClients); i++ { + cliClient = reorderedClients[i] + if cliClient.RequestMutex.TryLock() { + locked = true + break + } + } + if !locked { + cliClient = h.cliClients[0] + cliClient.RequestMutex.Lock() + } + + return cliClient, nil +} + // ChatCompletions handles the /v1/chat/completions endpoint. // It determines whether the request is for a streaming or non-streaming response // and calls the appropriate handler. @@ -212,45 +248,15 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) }() for { - // Lock the mutex to update the last used client index - mutex.Lock() - startIndex := lastUsedClientIndex - currentIndex := (startIndex + 1) % len(h.cliClients) - lastUsedClientIndex = currentIndex - mutex.Unlock() - - // Reorder the client to start from the last used index - reorderedClients := make([]*client.Client, 0) - for i := 0; i < len(h.cliClients); i++ { - cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)] - if cliClient.IsModelQuotaExceeded(modelName) { - log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID()) - cliClient = nil - continue - } - reorderedClients = append(reorderedClients, cliClient) - } - - if len(reorderedClients) == 0 { - c.Status(429) - _, _ = c.Writer.Write([]byte(fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))) + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.getClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) cliCancel() return } - locked := false - for i := 0; i < len(reorderedClients); i++ { - cliClient = reorderedClients[i] - if cliClient.RequestMutex.TryLock() { - locked = true - break - } - } - if !locked { - cliClient = h.cliClients[0] - cliClient.RequestMutex.Lock() - } - isGlAPIKey := false if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { log.Debugf("Request use generative language API Key: %s", glAPIKey) @@ -312,46 +318,16 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) { outLoop: for { - // Lock the mutex to update the last used client index - mutex.Lock() - startIndex := lastUsedClientIndex - currentIndex := (startIndex + 1) % len(h.cliClients) - lastUsedClientIndex = currentIndex - mutex.Unlock() - - // Reorder the client to start from the last used index - reorderedClients := make([]*client.Client, 0) - for i := 0; i < len(h.cliClients); i++ { - cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)] - if cliClient.IsModelQuotaExceeded(modelName) { - log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID()) - cliClient = nil - continue - } - reorderedClients = append(reorderedClients, cliClient) - } - - if len(reorderedClients) == 0 { - c.Status(429) - _, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)) + var errorResponse *client.ErrorMessage + cliClient, errorResponse = h.getClient(modelName) + if errorResponse != nil { + c.Status(errorResponse.StatusCode) + _, _ = fmt.Fprint(c.Writer, errorResponse.Error) flusher.Flush() cliCancel() return } - locked := false - for i := 0; i < len(reorderedClients); i++ { - cliClient = reorderedClients[i] - if cliClient.RequestMutex.TryLock() { - locked = true - break - } - } - if !locked { - cliClient = h.cliClients[0] - cliClient.RequestMutex.Lock() - } - isGlAPIKey := false if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { log.Debugf("Request use generative language API Key: %s", glAPIKey) @@ -411,295 +387,3 @@ outLoop: } } } - -func (h *APIHandlers) Internal(c *gin.Context) { - rawJson, _ := c.GetRawData() - requestRawURI := c.Request.URL.Path - if requestRawURI == "/v1internal:generateContent" { - h.internalGenerateContent(c, rawJson) - } else if requestRawURI == "/v1internal:streamGenerateContent" { - h.internalStreamGenerateContent(c, rawJson) - } else { - reqBody := bytes.NewBuffer(rawJson) - req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody) - if err != nil { - c.JSON(http.StatusBadRequest, ErrorResponse{ - Error: ErrorDetail{ - Message: fmt.Sprintf("Invalid request: %v", err), - Type: "invalid_request_error", - }, - }) - return - } - for key, value := range c.Request.Header { - req.Header[key] = value - } - - var transport *http.Transport - proxyURL, errParse := url.Parse(h.cfg.ProxyUrl) - if errParse == nil { - if proxyURL.Scheme == "socks5" { - username := proxyURL.User.Username() - password, _ := proxyURL.User.Password() - proxyAuth := &proxy.Auth{User: username, Password: password} - dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct) - if errSOCKS5 != nil { - log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5) - } - transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - }, - } - } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} - } - } - httpClient := &http.Client{} - if transport != nil { - httpClient.Transport = transport - } - - resp, err := httpClient.Do(req) - if err != nil { - c.JSON(http.StatusBadRequest, ErrorResponse{ - Error: ErrorDetail{ - Message: fmt.Sprintf("Invalid request: %v", err), - Type: "invalid_request_error", - }, - }) - return - } - - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - defer func() { - if err = resp.Body.Close(); err != nil { - log.Printf("warn: failed to close response body: %v", err) - } - }() - bodyBytes, _ := io.ReadAll(resp.Body) - - c.JSON(http.StatusBadRequest, ErrorResponse{ - Error: ErrorDetail{ - Message: string(bodyBytes), - Type: "invalid_request_error", - }, - }) - return - } - - defer func() { - _ = resp.Body.Close() - }() - - for key, value := range resp.Header { - c.Header(key, value[0]) - } - output, err := io.ReadAll(resp.Body) - if err != nil { - log.Errorf("Failed to read response body: %v", err) - return - } - _, _ = c.Writer.Write(output) - } -} - -func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) { - // Get the http.Flusher interface to manually flush the response. - flusher, ok := c.Writer.(http.Flusher) - if !ok { - c.JSON(http.StatusInternalServerError, ErrorResponse{ - Error: ErrorDetail{ - Message: "Streaming not supported", - Type: "server_error", - }, - }) - return - } - - modelResult := gjson.GetBytes(rawJson, "model") - modelName := modelResult.String() - - cliCtx, cliCancel := context.WithCancel(context.Background()) - var cliClient *client.Client - defer func() { - // Ensure the client's mutex is unlocked on function exit. - if cliClient != nil { - cliClient.RequestMutex.Unlock() - } - }() - -outLoop: - for { - // Lock the mutex to update the last used client index - mutex.Lock() - startIndex := lastUsedClientIndex - currentIndex := (startIndex + 1) % len(h.cliClients) - lastUsedClientIndex = currentIndex - mutex.Unlock() - - // Reorder the client to start from the last used index - reorderedClients := make([]*client.Client, 0) - for i := 0; i < len(h.cliClients); i++ { - cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)] - if cliClient.IsModelQuotaExceeded(modelName) { - log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID()) - cliClient = nil - continue - } - reorderedClients = append(reorderedClients, cliClient) - } - - if len(reorderedClients) == 0 { - c.Status(429) - _, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)) - flusher.Flush() - cliCancel() - return - } - - locked := false - for i := 0; i < len(reorderedClients); i++ { - cliClient = reorderedClients[i] - if cliClient.RequestMutex.TryLock() { - locked = true - break - } - } - if !locked { - cliClient = h.cliClients[0] - cliClient.RequestMutex.Lock() - } - - if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { - log.Debugf("Request use generative language API Key: %s", glAPIKey) - } else { - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) - } - // Send the message and receive response chunks and errors via channels. - respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson) - hasFirstResponse := false - for { - select { - // Handle client disconnection. - case <-c.Request.Context().Done(): - if c.Request.Context().Err().Error() == "context canceled" { - log.Debugf("Client disconnected: %v", c.Request.Context().Err()) - cliCancel() // Cancel the backend request. - return - } - // Process incoming response chunks. - case chunk, okStream := <-respChan: - if !okStream { - cliCancel() - return - } else { - hasFirstResponse = true - if cliClient.GetGenerativeLanguageAPIKey() != "" { - chunk, _ = sjson.SetRawBytes(chunk, "response", chunk) - } - _, _ = c.Writer.Write([]byte("data: ")) - _, _ = c.Writer.Write(chunk) - _, _ = c.Writer.Write([]byte("\n\n")) - flusher.Flush() - } - // Handle errors from the backend. - case err, okError := <-errChan: - if okError { - if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { - continue outLoop - } else { - c.Status(err.StatusCode) - _, _ = fmt.Fprint(c.Writer, err.Error.Error()) - flusher.Flush() - cliCancel() - } - return - } - // Send a keep-alive signal to the client. - case <-time.After(500 * time.Millisecond): - if hasFirstResponse { - _, _ = c.Writer.Write([]byte("\n")) - flusher.Flush() - } - } - } - } -} - -func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) { - c.Header("Content-Type", "application/json") - - modelResult := gjson.GetBytes(rawJson, "model") - modelName := modelResult.String() - cliCtx, cliCancel := context.WithCancel(context.Background()) - var cliClient *client.Client - defer func() { - if cliClient != nil { - cliClient.RequestMutex.Unlock() - } - }() - - for { - // Lock the mutex to update the last used client index - mutex.Lock() - startIndex := lastUsedClientIndex - currentIndex := (startIndex + 1) % len(h.cliClients) - lastUsedClientIndex = currentIndex - mutex.Unlock() - - // Reorder the client to start from the last used index - reorderedClients := make([]*client.Client, 0) - for i := 0; i < len(h.cliClients); i++ { - cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)] - if cliClient.IsModelQuotaExceeded(modelName) { - log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID()) - cliClient = nil - continue - } - reorderedClients = append(reorderedClients, cliClient) - } - - if len(reorderedClients) == 0 { - c.Status(429) - _, _ = c.Writer.Write([]byte(fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))) - cliCancel() - return - } - - locked := false - for i := 0; i < len(reorderedClients); i++ { - cliClient = reorderedClients[i] - if cliClient.RequestMutex.TryLock() { - locked = true - break - } - } - if !locked { - cliClient = h.cliClients[0] - cliClient.RequestMutex.Lock() - } - - if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" { - log.Debugf("Request use generative language API Key: %s", glAPIKey) - } else { - log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID()) - } - - resp, err := cliClient.SendRawMessage(cliCtx, rawJson) - if err != nil { - if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject { - continue - } else { - c.Status(err.StatusCode) - _, _ = c.Writer.Write([]byte(err.Error.Error())) - cliCancel() - } - break - } else { - _, _ = c.Writer.Write(resp) - cliCancel() - break - } - } -} diff --git a/internal/api/server.go b/internal/api/server.go index 531c4a25..52b32006 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -81,7 +81,7 @@ func (s *Server) setupRoutes() { }, }) }) - s.engine.POST("/v1internal:method", s.handlers.Internal) + s.engine.POST("/v1internal:method", s.handlers.CLIHandler) } diff --git a/internal/cmd/run.go b/internal/cmd/run.go index 64df3043..a2ac7511 100644 --- a/internal/cmd/run.go +++ b/internal/cmd/run.go @@ -7,12 +7,10 @@ import ( "github.com/luispater/CLIProxyAPI/internal/auth" "github.com/luispater/CLIProxyAPI/internal/client" "github.com/luispater/CLIProxyAPI/internal/config" + "github.com/luispater/CLIProxyAPI/internal/util" log "github.com/sirupsen/logrus" - "golang.org/x/net/proxy" "io/fs" - "net" "net/http" - "net/url" "os" "os/signal" "path/filepath" @@ -69,33 +67,12 @@ func StartService(cfg *config.Config) { } if len(cfg.GlAPIKey) > 0 { - var transport *http.Transport - proxyURL, errParse := url.Parse(cfg.ProxyUrl) - if errParse == nil { - if proxyURL.Scheme == "socks5" { - username := proxyURL.User.Username() - password, _ := proxyURL.User.Password() - proxyAuth := &proxy.Auth{User: username, Password: password} - dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct) - if errSOCKS5 != nil { - log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5) - } - transport = &http.Transport{ - DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { - return dialer.Dial(network, addr) - }, - } - } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { - // Handle HTTP/HTTPS proxy. - transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} - } - } - for i := 0; i < len(cfg.GlAPIKey); i++ { - httpClient := &http.Client{} - if transport != nil { - httpClient.Transport = transport + httpClient, errSetProxy := util.SetProxy(cfg, &http.Client{}) + if errSetProxy != nil { + log.Fatalf("set proxy failed: %v", errSetProxy) } + log.Debug("Initializing with Generative Language API key...") cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i]) cliClients = append(cliClients, cliClient) diff --git a/internal/util/proxy.go b/internal/util/proxy.go new file mode 100644 index 00000000..f48c6fc3 --- /dev/null +++ b/internal/util/proxy.go @@ -0,0 +1,37 @@ +package util + +import ( + "context" + "github.com/luispater/CLIProxyAPI/internal/config" + "golang.org/x/net/proxy" + "net" + "net/http" + "net/url" +) + +func SetProxy(cfg *config.Config, httpClient *http.Client) (*http.Client, error) { + var transport *http.Transport + proxyURL, errParse := url.Parse(cfg.ProxyUrl) + if errParse == nil { + if proxyURL.Scheme == "socks5" { + username := proxyURL.User.Username() + password, _ := proxyURL.User.Password() + proxyAuth := &proxy.Auth{User: username, Password: password} + dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct) + if errSOCKS5 != nil { + return nil, errSOCKS5 + } + transport = &http.Transport{ + DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) { + return dialer.Dial(network, addr) + }, + } + } else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" { + transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)} + } + } + if transport != nil { + httpClient.Transport = transport + } + return httpClient, nil +}