Compare commits

..

4 Commits

Author SHA1 Message Date
Luis Pater
aa9fd057fe Add FixCLIToolResponse for enhanced function call-response mapping
- Introduced `FixCLIToolResponse` in `translator` to group function calls with corresponding responses.
- Updated Gemini handlers to integrate new function for improved response handling.
- Enhanced error handling in case response mapping fails.
2025-07-11 10:17:25 +08:00
Luis Pater
b3607d3981 Add Gemini-compatible API and improve error handling
- Introduced a new Gemini-compatible API with routes under `/v1beta`.
- Added `GeminiHandler` to manage `generateContent` and `streamGenerateContent` actions.
- Enhanced `AuthMiddleware` to support `X-Goog-Api-Key` header.
- Improved client metadata handling and added conditional project ID updates in API calls.
- Updated logging to debug raw API request payloads for better traceability.
2025-07-11 04:01:45 +08:00
Luis Pater
fa8d94971f Enhance response and request handling in translators
- Refactored response handling to process multiple content parts effectively.
- Improved `tool_calls` structure with unique ID generation and enhanced mapping logic.
- Simplified `SystemInstruction` and tool message parsing in requests for better accuracy.
- Enhanced handling of function calls and tool responses with improved data integration.
2025-07-10 22:26:04 +08:00
Luis Pater
ef68a97526 Refactor API handlers and proxy logic
- Centralized `getClient` logic into a dedicated function to reduce redundancy.
- Moved proxy initialization to a new utility function `SetProxy` in `internal/util/proxy.go`.
- Replaced `Internal` handler with `CLIHandler` in `server.go` for improved clarity and consistency.
- Removed unused functions and redundant HTTP client setup across the codebase for better maintainability.
2025-07-10 17:45:28 +08:00
9 changed files with 840 additions and 446 deletions

View File

@@ -0,0 +1,228 @@
package api
import (
"bytes"
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"io"
"net/http"
"time"
)
func (h *APIHandlers) CLIHandler(c *gin.Context) {
rawJson, _ := c.GetRawData()
requestRawURI := c.Request.URL.Path
if requestRawURI == "/v1internal:generateContent" {
h.internalGenerateContent(c, rawJson)
} else if requestRawURI == "/v1internal:streamGenerateContent" {
h.internalStreamGenerateContent(c, rawJson)
} else {
reqBody := bytes.NewBuffer(rawJson)
req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody)
if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
for key, value := range c.Request.Header {
req.Header[key] = value
}
httpClient, err := util.SetProxy(h.cfg, &http.Client{})
if err != nil {
log.Fatalf("set proxy failed: %v", err)
}
resp, err := httpClient.Do(req)
if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
defer func() {
if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}()
bodyBytes, _ := io.ReadAll(resp.Body)
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: string(bodyBytes),
Type: "invalid_request_error",
},
})
return
}
defer func() {
_ = resp.Body.Close()
}()
for key, value := range resp.Header {
c.Header(key, value[0])
}
output, err := io.ReadAll(resp.Body)
if err != nil {
log.Errorf("Failed to read response body: %v", err)
return
}
_, _ = c.Writer.Write(output)
}
}
func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) {
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelResult := gjson.GetBytes(rawJson, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
// Ensure the client's mutex is unlocked on function exit.
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
outLoop:
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
flusher.Flush()
cliCancel()
return
}
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson)
hasFirstResponse := false
for {
select {
// Handle client disconnection.
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() // Cancel the backend request.
return
}
// Process incoming response chunks.
case chunk, okStream := <-respChan:
if !okStream {
cliCancel()
return
} else {
hasFirstResponse = true
if cliClient.GetGenerativeLanguageAPIKey() != "" {
chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
}
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
flusher.Flush()
}
// Handle errors from the backend.
case err, okError := <-errChan:
if okError {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue outLoop
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Send a keep-alive signal to the client.
case <-time.After(500 * time.Millisecond):
if hasFirstResponse {
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
}
}
}
}
}
func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "application/json")
modelResult := gjson.GetBytes(rawJson, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
cliCancel()
return
}
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendRawMessage(cliCtx, rawJson)
if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode)
_, _ = c.Writer.Write([]byte(err.Error.Error()))
cliCancel()
}
break
} else {
_, _ = c.Writer.Write(resp)
cliCancel()
break
}
}
}

View File

@@ -0,0 +1,242 @@
package api
import (
"context"
"fmt"
"github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/api/translator"
"github.com/luispater/CLIProxyAPI/internal/client"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"net/http"
"strings"
"time"
)
func (h *APIHandlers) GeminiHandler(c *gin.Context) {
var person struct {
Action string `uri:"action" binding:"required"`
}
if err := c.ShouldBindUri(&person); err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
action := strings.Split(person.Action, ":")
if len(action) != 2 {
c.JSON(http.StatusNotFound, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("%s not found.", c.Request.URL.Path),
Type: "invalid_request_error",
},
})
return
}
modelName := action[0]
method := action[1]
rawJson, _ := c.GetRawData()
rawJson, _ = sjson.SetBytes(rawJson, "model", []byte(modelName))
if method == "generateContent" {
h.geminiGenerateContent(c, rawJson)
} else if method == "streamGenerateContent" {
h.geminiStreamGenerateContent(c, rawJson)
}
}
func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte) {
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelResult := gjson.GetBytes(rawJson, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
// Ensure the client's mutex is unlocked on function exit.
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
outLoop:
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
flusher.Flush()
cliCancel()
return
}
template := `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJson))
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
template, _ = sjson.Delete(template, "request.model")
template, errFixCLIToolResponse := translator.FixCLIToolResponse(template)
if errFixCLIToolResponse != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: errFixCLIToolResponse.Error(),
Type: "server_error",
},
})
cliCancel()
return
}
systemInstructionResult := gjson.Get(template, "request.system_instruction")
if systemInstructionResult.Exists() {
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
template, _ = sjson.Delete(template, "request.system_instruction")
}
rawJson = []byte(template)
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson)
for {
select {
// Handle client disconnection.
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() // Cancel the backend request.
return
}
// Process incoming response chunks.
case chunk, okStream := <-respChan:
if !okStream {
cliCancel()
return
} else {
if cliClient.GetGenerativeLanguageAPIKey() == "" {
responseResult := gjson.GetBytes(chunk, "response")
if responseResult.Exists() {
chunk = []byte(responseResult.Raw)
}
}
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
flusher.Flush()
}
// Handle errors from the backend.
case err, okError := <-errChan:
if okError {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue outLoop
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Send a keep-alive signal to the client.
case <-time.After(500 * time.Millisecond):
}
}
}
}
func (h *APIHandlers) geminiGenerateContent(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "application/json")
modelResult := gjson.GetBytes(rawJson, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
for {
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
cliCancel()
return
}
template := `{"project":"","request":{},"model":""}`
template, _ = sjson.SetRaw(template, "request", string(rawJson))
template, _ = sjson.Set(template, "model", gjson.Get(template, "request.model").String())
template, _ = sjson.Delete(template, "request.model")
template, errFixCLIToolResponse := translator.FixCLIToolResponse(template)
if errFixCLIToolResponse != nil {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: errFixCLIToolResponse.Error(),
Type: "server_error",
},
})
cliCancel()
return
}
systemInstructionResult := gjson.Get(template, "request.system_instruction")
if systemInstructionResult.Exists() {
template, _ = sjson.SetRaw(template, "request.systemInstruction", systemInstructionResult.Raw)
template, _ = sjson.Delete(template, "request.system_instruction")
}
rawJson = []byte(template)
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendRawMessage(cliCtx, rawJson)
if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode)
_, _ = c.Writer.Write([]byte(err.Error.Error()))
cliCancel()
}
break
} else {
if cliClient.GetGenerativeLanguageAPIKey() == "" {
responseResult := gjson.GetBytes(resp, "response")
if responseResult.Exists() {
resp = []byte(responseResult.Raw)
}
}
_, _ = c.Writer.Write(resp)
cliCancel()
break
}
}
}

View File

@@ -1,7 +1,6 @@
package api
import (
"bytes"
"context"
"fmt"
"github.com/luispater/CLIProxyAPI/internal/api/translator"
@@ -9,12 +8,7 @@ import (
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/net/proxy"
"io"
"net"
"net/http"
"net/url"
"sync"
"time"
@@ -171,6 +165,48 @@ func (h *APIHandlers) Models(c *gin.Context) {
})
}
func (h *APIHandlers) getClient(modelName string) (*client.Client, *client.ErrorMessage) {
var cliClient *client.Client
// Lock the mutex to update the last used client index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
mutex.Unlock()
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
return nil, &client.ErrorMessage{StatusCode: 429, Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)}
}
locked := false
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
return cliClient, nil
}
// ChatCompletions handles the /v1/chat/completions endpoint.
// It determines whether the request is for a streaming or non-streaming response
// and calls the appropriate handler.
@@ -212,45 +248,15 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
}()
for {
// Lock the mutex to update the last used client index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
mutex.Unlock()
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
c.Status(429)
_, _ = c.Writer.Write([]byte(fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)))
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
cliCancel()
return
}
locked := false
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
isGlAPIKey := false
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
@@ -312,46 +318,16 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
outLoop:
for {
// Lock the mutex to update the last used client index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
mutex.Unlock()
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
c.Status(429)
_, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
var errorResponse *client.ErrorMessage
cliClient, errorResponse = h.getClient(modelName)
if errorResponse != nil {
c.Status(errorResponse.StatusCode)
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
flusher.Flush()
cliCancel()
return
}
locked := false
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
isGlAPIKey := false
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
@@ -411,295 +387,3 @@ outLoop:
}
}
}
func (h *APIHandlers) Internal(c *gin.Context) {
rawJson, _ := c.GetRawData()
requestRawURI := c.Request.URL.Path
if requestRawURI == "/v1internal:generateContent" {
h.internalGenerateContent(c, rawJson)
} else if requestRawURI == "/v1internal:streamGenerateContent" {
h.internalStreamGenerateContent(c, rawJson)
} else {
reqBody := bytes.NewBuffer(rawJson)
req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody)
if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
for key, value := range c.Request.Header {
req.Header[key] = value
}
var transport *http.Transport
proxyURL, errParse := url.Parse(h.cfg.ProxyUrl)
if errParse == nil {
if proxyURL.Scheme == "socks5" {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
}
httpClient := &http.Client{}
if transport != nil {
httpClient.Transport = transport
}
resp, err := httpClient.Do(req)
if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
defer func() {
if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}()
bodyBytes, _ := io.ReadAll(resp.Body)
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: string(bodyBytes),
Type: "invalid_request_error",
},
})
return
}
defer func() {
_ = resp.Body.Close()
}()
for key, value := range resp.Header {
c.Header(key, value[0])
}
output, err := io.ReadAll(resp.Body)
if err != nil {
log.Errorf("Failed to read response body: %v", err)
return
}
_, _ = c.Writer.Write(output)
}
}
func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) {
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelResult := gjson.GetBytes(rawJson, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
// Ensure the client's mutex is unlocked on function exit.
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
outLoop:
for {
// Lock the mutex to update the last used client index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
mutex.Unlock()
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
c.Status(429)
_, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
flusher.Flush()
cliCancel()
return
}
locked := false
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson)
hasFirstResponse := false
for {
select {
// Handle client disconnection.
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() // Cancel the backend request.
return
}
// Process incoming response chunks.
case chunk, okStream := <-respChan:
if !okStream {
cliCancel()
return
} else {
hasFirstResponse = true
if cliClient.GetGenerativeLanguageAPIKey() != "" {
chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
}
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
flusher.Flush()
}
// Handle errors from the backend.
case err, okError := <-errChan:
if okError {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue outLoop
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Send a keep-alive signal to the client.
case <-time.After(500 * time.Millisecond):
if hasFirstResponse {
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
}
}
}
}
}
func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "application/json")
modelResult := gjson.GetBytes(rawJson, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
for {
// Lock the mutex to update the last used client index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
mutex.Unlock()
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
c.Status(429)
_, _ = c.Writer.Write([]byte(fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)))
cliCancel()
return
}
locked := false
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendRawMessage(cliCtx, rawJson)
if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode)
_, _ = c.Writer.Write([]byte(err.Error.Error()))
cliCancel()
}
break
} else {
_, _ = c.Writer.Write(resp)
cliCancel()
break
}
}
}

View File

@@ -70,6 +70,14 @@ func (s *Server) setupRoutes() {
v1.POST("/chat/completions", s.handlers.ChatCompletions)
}
// Gemini compatible API routes
v1beta := s.engine.Group("/v1beta")
v1beta.Use(AuthMiddleware(s.cfg))
{
v1beta.GET("/models", s.handlers.Models)
v1beta.POST("/models/:action", s.handlers.GeminiHandler)
}
// Root endpoint
s.engine.GET("/", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
@@ -81,7 +89,7 @@ func (s *Server) setupRoutes() {
},
})
})
s.engine.POST("/v1internal:method", s.handlers.Internal)
s.engine.POST("/v1internal:method", s.handlers.CLIHandler)
}
@@ -140,7 +148,8 @@ func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
// Get the Authorization header
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
authHeaderGoogle := c.GetHeader("X-Goog-Api-Key")
if authHeader == "" && authHeaderGoogle == "" {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{
"error": "Missing API key",
})
@@ -159,7 +168,7 @@ func AuthMiddleware(cfg *config.Config) gin.HandlerFunc {
// Find the API key in the in-memory list
var foundKey string
for i := range cfg.ApiKeys {
if cfg.ApiKeys[i] == apiKey {
if cfg.ApiKeys[i] == apiKey || cfg.ApiKeys[i] == authHeaderGoogle {
foundKey = cfg.ApiKeys[i]
break
}

View File

@@ -2,6 +2,8 @@ package translator
import (
"encoding/json"
"fmt"
"github.com/tidwall/sjson"
"strings"
"github.com/luispater/CLIProxyAPI/internal/client"
@@ -24,6 +26,39 @@ func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content,
contents := make([]client.Content, 0)
var systemInstruction *client.Content
messagesResult := gjson.GetBytes(rawJson, "messages")
toolItems := make(map[string]*client.FunctionResponse)
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
messageResult := messagesResults[i]
roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String {
continue
}
contentResult := messageResult.Get("content")
if roleResult.String() == "tool" {
toolCallID := messageResult.Get("tool_call_id").String()
if toolCallID != "" {
var responseData string
if contentResult.Type == gjson.String {
responseData = contentResult.String()
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
responseData = contentResult.Get("text").String()
}
// drop the timestamp from the tool call ID
toolCallIDs := strings.Split(toolCallID, "-")
strings.Join(toolCallIDs, "-")
newToolCallID := strings.Join(toolCallIDs[:len(toolCallIDs)-1], "-")
functionResponse := client.FunctionResponse{Name: newToolCallID, Response: map[string]interface{}{"result": responseData}}
toolItems[toolCallID] = &functionResponse
}
}
}
}
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
for i := 0; i < len(messagesResults); i++ {
@@ -97,40 +132,44 @@ func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content,
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: contentResult.String()}}})
} else if !contentResult.Exists() || contentResult.Type == gjson.Null {
// Handle tool calls made by the assistant.
functionIDs := make([]string, 0)
toolCallsResult := messageResult.Get("tool_calls")
if toolCallsResult.IsArray() {
parts := make([]client.Part, 0)
tcsResult := toolCallsResult.Array()
for j := 0; j < len(tcsResult); j++ {
tcResult := tcsResult[j]
functionID := tcResult.Get("id").String()
functionIDs = append(functionIDs, functionID)
functionName := tcResult.Get("function.name").String()
functionArgs := tcResult.Get("function.arguments").String()
var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
contents = append(contents, client.Content{
Role: "model", Parts: []client.Part{{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
}},
parts = append(parts, client.Part{
FunctionCall: &client.FunctionCall{
Name: functionName,
Args: args,
},
})
}
}
if len(parts) > 0 {
contents = append(contents, client.Content{
Role: "model", Parts: parts,
})
toolParts := make([]client.Part, 0)
for _, functionID := range functionIDs {
if functionResponse, ok := toolItems[functionID]; ok {
toolParts = append(toolParts, client.Part{FunctionResponse: functionResponse})
}
}
contents = append(contents, client.Content{Role: "tool", Parts: toolParts})
}
}
}
// Tool messages contain the output of a tool call.
case "tool":
toolCallID := messageResult.Get("tool_call_id").String()
if toolCallID != "" {
var responseData string
if contentResult.Type == gjson.String {
responseData = contentResult.String()
} else if contentResult.IsObject() && contentResult.Get("type").String() == "text" {
responseData = contentResult.Get("text").String()
}
functionResponse := client.FunctionResponse{Name: toolCallID, Response: map[string]interface{}{"result": responseData}}
contents = append(contents, client.Content{Role: "tool", Parts: []client.Part{{FunctionResponse: &functionResponse}}})
}
}
}
}
@@ -160,3 +199,167 @@ func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content,
return modelName, systemInstruction, contents, tools
}
// FunctionCallGroup represents a group of function calls and their responses
type FunctionCallGroup struct {
ModelContent map[string]interface{}
FunctionCalls []gjson.Result
ResponsesNeeded int
}
// FixCLIToolResponse converts the format from 1.json to 2.json
// It groups function calls with their corresponding responses
func FixCLIToolResponse(input string) (string, error) {
// Parse the input JSON
parsed := gjson.Parse(input)
// Get the contents array
contents := parsed.Get("request.contents")
if !contents.Exists() {
return input, fmt.Errorf("contents not found in input")
}
var newContents []interface{}
var pendingGroups []*FunctionCallGroup
var collectedResponses []gjson.Result
// Process each content object
contents.ForEach(func(key, value gjson.Result) bool {
role := value.Get("role").String()
parts := value.Get("parts")
// Check if this content has function responses
var responsePartsInThisContent []gjson.Result
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionResponse").Exists() {
responsePartsInThisContent = append(responsePartsInThisContent, part)
}
return true
})
// If this content has function responses, collect them
if len(responsePartsInThisContent) > 0 {
collectedResponses = append(collectedResponses, responsePartsInThisContent...)
// Check if any pending groups can be satisfied
for i := len(pendingGroups) - 1; i >= 0; i-- {
group := pendingGroups[i]
if len(collectedResponses) >= group.ResponsesNeeded {
// Take the needed responses for this group
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
// Create merged function response content
var responseParts []interface{}
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue
}
responseParts = append(responseParts, responseMap)
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
}
// Remove this group as it's been satisfied
pendingGroups = append(pendingGroups[:i], pendingGroups[i+1:]...)
break
}
}
return true // Skip adding this content, responses are merged
}
// If this is a model with function calls, create a new group
if role == "model" {
var functionCallsInThisModel []gjson.Result
parts.ForEach(func(_, part gjson.Result) bool {
if part.Get("functionCall").Exists() {
functionCallsInThisModel = append(functionCallsInThisModel, part)
}
return true
})
if len(functionCallsInThisModel) > 0 {
// Add the model content
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal model content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
// Create a new group for tracking responses
group := &FunctionCallGroup{
ModelContent: contentMap,
FunctionCalls: functionCallsInThisModel,
ResponsesNeeded: len(functionCallsInThisModel),
}
pendingGroups = append(pendingGroups, group)
} else {
// Regular model content without function calls
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
}
} else {
// Non-model content (user, etc.)
var contentMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(value.Raw), &contentMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal content: %v\n", errUnmarshal)
return true
}
newContents = append(newContents, contentMap)
}
return true
})
// Handle any remaining pending groups with remaining responses
for _, group := range pendingGroups {
if len(collectedResponses) >= group.ResponsesNeeded {
groupResponses := collectedResponses[:group.ResponsesNeeded]
collectedResponses = collectedResponses[group.ResponsesNeeded:]
var responseParts []interface{}
for _, response := range groupResponses {
var responseMap map[string]interface{}
errUnmarshal := json.Unmarshal([]byte(response.Raw), &responseMap)
if errUnmarshal != nil {
log.Warnf("failed to unmarshal function response: %v\n", errUnmarshal)
continue
}
responseParts = append(responseParts, responseMap)
}
if len(responseParts) > 0 {
functionResponseContent := map[string]interface{}{
"parts": responseParts,
"role": "function",
}
newContents = append(newContents, functionResponseContent)
}
}
}
// Update the original JSON with the new contents
result := input
newContentsJSON, _ := json.Marshal(newContents)
result, _ = sjson.Set(result, "request.contents", json.RawMessage(newContentsJSON))
return result, nil
}

View File

@@ -1,6 +1,7 @@
package translator
import (
"fmt"
"time"
"github.com/tidwall/gjson"
@@ -62,32 +63,40 @@ func ConvertCliToOpenAI(rawJson []byte, unixTimestamp int64, isGlAPIKey bool) st
}
// Process the main content part of the response.
partResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts.0")
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
partsResult := gjson.GetBytes(rawJson, "response.candidates.0.content.parts")
if partsResult.IsArray() {
partResults := partsResult.Array()
for i := 0; i < len(partResults); i++ {
partResult := partResults[i]
partTextResult := partResult.Get("text")
functionCallResult := partResult.Get("functionCall")
if partTextResult.Exists() {
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
if partTextResult.Exists() {
// Handle text content, distinguishing between regular content and reasoning/thoughts.
if partResult.Get("thought").Bool() {
template, _ = sjson.Set(template, "choices.0.delta.reasoning_content", partTextResult.String())
} else {
template, _ = sjson.Set(template, "choices.0.delta.content", partTextResult.String())
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
if !toolCallsResult.Exists() || !toolCallsResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", `[]`)
}
functionCallTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.tool_calls.-1", functionCallTemplate)
}
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
} else if functionCallResult.Exists() {
// Handle function call content.
functionCallTemplate := `[{"id": "","type": "function","function": {"name": "","arguments": ""}}]`
fcName := functionCallResult.Get("name").String()
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.id", fcName)
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallTemplate, _ = sjson.Set(functionCallTemplate, "0.function.arguments", fcArgsResult.Raw)
}
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.tool_calls", functionCallTemplate)
} else {
// If no usable content is found, return an empty string.
return ""
}
return template
@@ -163,7 +172,7 @@ func ConvertCliToOpenAINonStream(rawJson []byte, unixTimestamp int64, isGlAPIKey
}
functionCallItemTemplate := `{"id": "","type": "function","function": {"name": "","arguments": ""}}`
fcName := functionCallResult.Get("name").String()
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fcName)
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "id", fmt.Sprintf("%s-%d", fcName, time.Now().UnixNano()))
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.name", fcName)
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
functionCallItemTemplate, _ = sjson.Set(functionCallItemTemplate, "function.arguments", fcArgsResult.Raw)

View File

@@ -311,6 +311,7 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
}
}()
bodyBytes, _ := io.ReadAll(resp.Body)
// log.Debug(string(jsonBody))
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))}
}
@@ -534,7 +535,9 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
// SendRawMessage handles a single conversational turn, including tool calls.
func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *ErrorMessage) {
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
if c.glAPIKey == "" {
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
}
modelResult := gjson.GetBytes(rawJson, "model")
model := modelResult.String()
@@ -584,7 +587,9 @@ func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-ch
defer close(errChan)
defer close(dataChan)
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
if c.glAPIKey == "" {
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
}
modelResult := gjson.GetBytes(rawJson, "model")
model := modelResult.String()

View File

@@ -7,12 +7,10 @@ import (
"github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/client"
"github.com/luispater/CLIProxyAPI/internal/config"
"github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus"
"golang.org/x/net/proxy"
"io/fs"
"net"
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
@@ -69,33 +67,12 @@ func StartService(cfg *config.Config) {
}
if len(cfg.GlAPIKey) > 0 {
var transport *http.Transport
proxyURL, errParse := url.Parse(cfg.ProxyUrl)
if errParse == nil {
if proxyURL.Scheme == "socks5" {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
// Handle HTTP/HTTPS proxy.
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
}
for i := 0; i < len(cfg.GlAPIKey); i++ {
httpClient := &http.Client{}
if transport != nil {
httpClient.Transport = transport
httpClient, errSetProxy := util.SetProxy(cfg, &http.Client{})
if errSetProxy != nil {
log.Fatalf("set proxy failed: %v", errSetProxy)
}
log.Debug("Initializing with Generative Language API key...")
cliClient := client.NewClient(httpClient, nil, cfg, cfg.GlAPIKey[i])
cliClients = append(cliClients, cliClient)

37
internal/util/proxy.go Normal file
View File

@@ -0,0 +1,37 @@
package util
import (
"context"
"github.com/luispater/CLIProxyAPI/internal/config"
"golang.org/x/net/proxy"
"net"
"net/http"
"net/url"
)
func SetProxy(cfg *config.Config, httpClient *http.Client) (*http.Client, error) {
var transport *http.Transport
proxyURL, errParse := url.Parse(cfg.ProxyUrl)
if errParse == nil {
if proxyURL.Scheme == "socks5" {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
return nil, errSOCKS5
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
}
if transport != nil {
httpClient.Transport = transport
}
return httpClient, nil
}