mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 12:30:50 +08:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3b4634e2dc | ||
|
|
00bd6a3e46 | ||
|
|
5812229d9b | ||
|
|
0b026933a7 | ||
|
|
3b2ab0d7bd | ||
|
|
e64fa48823 |
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@@ -0,0 +1 @@
|
||||
config.yaml
|
||||
@@ -15,4 +15,4 @@ archives:
|
||||
- LICENSE
|
||||
- README.md
|
||||
- README_CN.md
|
||||
- config.yaml
|
||||
- config.example.yaml
|
||||
@@ -99,6 +99,15 @@ func (h *APIHandlers) CLIHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) {
|
||||
alt := h.getAlt(c)
|
||||
|
||||
if alt == "" {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
|
||||
@@ -107,6 +107,15 @@ func (h *APIHandlers) GeminiHandler(c *gin.Context) {
|
||||
}
|
||||
|
||||
func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte) {
|
||||
alt := h.getAlt(c)
|
||||
|
||||
if alt == "" {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
@@ -122,8 +131,6 @@ func (h *APIHandlers) geminiStreamGenerateContent(c *gin.Context, rawJson []byte
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
modelName := modelResult.String()
|
||||
|
||||
alt := h.getAlt(c)
|
||||
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
var cliClient *client.Client
|
||||
defer func() {
|
||||
@@ -246,7 +253,7 @@ func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {
|
||||
c.Header("Content-Type", "application/json")
|
||||
|
||||
alt := h.getAlt(c)
|
||||
|
||||
// orgRawJson := rawJson
|
||||
modelResult := gjson.GetBytes(rawJson, "model")
|
||||
modelName := modelResult.String()
|
||||
cliCtx, cliCancel := context.WithCancel(context.Background())
|
||||
@@ -259,7 +266,7 @@ func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {
|
||||
|
||||
for {
|
||||
var errorResponse *client.ErrorMessage
|
||||
cliClient, errorResponse = h.getClient(modelName)
|
||||
cliClient, errorResponse = h.getClient(modelName, false)
|
||||
if errorResponse != nil {
|
||||
c.Status(errorResponse.StatusCode)
|
||||
_, _ = fmt.Fprint(c.Writer, errorResponse.Error)
|
||||
@@ -273,8 +280,13 @@ func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {
|
||||
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
|
||||
|
||||
template := `{"request":{}}`
|
||||
template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJson, "generateContentRequest").Raw)
|
||||
template, _ = sjson.Delete(template, "generateContentRequest")
|
||||
if gjson.GetBytes(rawJson, "generateContentRequest").Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request", gjson.GetBytes(rawJson, "generateContentRequest").Raw)
|
||||
template, _ = sjson.Delete(template, "generateContentRequest")
|
||||
} else if gjson.GetBytes(rawJson, "contents").Exists() {
|
||||
template, _ = sjson.SetRaw(template, "request.contents", gjson.GetBytes(rawJson, "contents").Raw)
|
||||
template, _ = sjson.Delete(template, "contents")
|
||||
}
|
||||
rawJson = []byte(template)
|
||||
}
|
||||
|
||||
@@ -286,6 +298,9 @@ func (h *APIHandlers) geminiCountTokens(c *gin.Context, rawJson []byte) {
|
||||
c.Status(err.StatusCode)
|
||||
_, _ = c.Writer.Write([]byte(err.Error.Error()))
|
||||
cliCancel()
|
||||
// log.Debugf(err.Error.Error())
|
||||
// log.Debugf(string(rawJson))
|
||||
// log.Debugf(string(orgRawJson))
|
||||
}
|
||||
break
|
||||
} else {
|
||||
|
||||
@@ -85,7 +85,7 @@ func (h *APIHandlers) Models(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func (h *APIHandlers) getClient(modelName string) (*client.Client, *client.ErrorMessage) {
|
||||
func (h *APIHandlers) getClient(modelName string, isGenerateContent ...bool) (*client.Client, *client.ErrorMessage) {
|
||||
if len(h.cliClients) == 0 {
|
||||
return nil, &client.ErrorMessage{StatusCode: 500, Error: fmt.Errorf("no clients available")}
|
||||
}
|
||||
@@ -95,8 +95,10 @@ func (h *APIHandlers) getClient(modelName string) (*client.Client, *client.Error
|
||||
// Lock the mutex to update the last used client index
|
||||
mutex.Lock()
|
||||
startIndex := lastUsedClientIndex
|
||||
currentIndex := (startIndex + 1) % len(h.cliClients)
|
||||
lastUsedClientIndex = currentIndex
|
||||
if (len(isGenerateContent) > 0 && isGenerateContent[0]) || len(isGenerateContent) == 0 {
|
||||
currentIndex := (startIndex + 1) % len(h.cliClients)
|
||||
lastUsedClientIndex = currentIndex
|
||||
}
|
||||
mutex.Unlock()
|
||||
|
||||
// Reorder the client to start from the last used index
|
||||
|
||||
Reference in New Issue
Block a user