Add system instruction support and enhance internal API handlers

- Introduced `SystemInstruction` field in `PrepareRequest` and `GenerateContentRequest` for better message parsing.
- Updated `SendMessage` and `SendMessageStream` to handle system instructions in client API calls.
- Enhanced error handling and manual flushing logic in response flows.
- Added new internal API endpoints `/v1internal:generateContent` and `/v1internal:streamGenerateContent`.
- Improved proxy handling and transport logic in HTTP client initialization.
This commit is contained in:
Luis Pater
2025-07-10 05:16:54 +08:00
parent 65f47c196a
commit 273e1d9cbe
5 changed files with 442 additions and 35 deletions

View File

@@ -1,6 +1,7 @@
package api
import (
"bytes"
"context"
"fmt"
"github.com/luispater/CLIProxyAPI/internal/api/translator"
@@ -8,7 +9,12 @@ import (
"github.com/luispater/CLIProxyAPI/internal/config"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/net/proxy"
"io"
"net"
"net/http"
"net/url"
"sync"
"time"
@@ -196,19 +202,7 @@ func (h *APIHandlers) ChatCompletions(c *gin.Context) {
func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "application/json")
// Handle streaming manually
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelName, contents, tools := translator.PrepareRequest(rawJson)
modelName, systemInstruction, contents, tools := translator.PrepareRequest(rawJson)
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
@@ -239,8 +233,7 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
if len(reorderedClients) == 0 {
c.Status(429)
_, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
flusher.Flush()
_, _ = c.Writer.Write([]byte(fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)))
cliCancel()
return
}
@@ -266,22 +259,20 @@ func (h *APIHandlers) handleNonStreamingResponse(c *gin.Context, rawJson []byte)
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, contents, tools)
resp, err := cliClient.SendMessage(cliCtx, rawJson, modelName, systemInstruction, contents, tools)
if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
_, _ = c.Writer.Write([]byte(err.Error.Error()))
cliCancel()
}
break
} else {
openAIFormat := translator.ConvertCliToOpenAINonStream(resp, time.Now().Unix(), isGlAPIKey)
if openAIFormat != "" {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", openAIFormat)
flusher.Flush()
_, _ = c.Writer.Write([]byte(openAIFormat))
}
cliCancel()
break
@@ -309,7 +300,7 @@ func (h *APIHandlers) handleStreamingResponse(c *gin.Context, rawJson []byte) {
}
// Prepare the request for the backend client.
modelName, contents, tools := translator.PrepareRequest(rawJson)
modelName, systemInstruction, contents, tools := translator.PrepareRequest(rawJson)
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
@@ -369,7 +360,7 @@ outLoop:
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, contents, tools)
respChan, errChan := cliClient.SendMessageStream(cliCtx, rawJson, modelName, systemInstruction, contents, tools)
hasFirstResponse := false
for {
select {
@@ -420,3 +411,295 @@ outLoop:
}
}
}
func (h *APIHandlers) Internal(c *gin.Context) {
rawJson, _ := c.GetRawData()
requestRawURI := c.Request.URL.Path
if requestRawURI == "/v1internal:generateContent" {
h.internalGenerateContent(c, rawJson)
} else if requestRawURI == "/v1internal:streamGenerateContent" {
h.internalStreamGenerateContent(c, rawJson)
} else {
reqBody := bytes.NewBuffer(rawJson)
req, err := http.NewRequest("POST", fmt.Sprintf("https://cloudcode-pa.googleapis.com%s", c.Request.URL.RequestURI()), reqBody)
if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
for key, value := range c.Request.Header {
req.Header[key] = value
}
var transport *http.Transport
proxyURL, errParse := url.Parse(h.cfg.ProxyUrl)
if errParse == nil {
if proxyURL.Scheme == "socks5" {
username := proxyURL.User.Username()
password, _ := proxyURL.User.Password()
proxyAuth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, proxyAuth, proxy.Direct)
if errSOCKS5 != nil {
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
return dialer.Dial(network, addr)
},
}
} else if proxyURL.Scheme == "http" || proxyURL.Scheme == "https" {
transport = &http.Transport{Proxy: http.ProxyURL(proxyURL)}
}
}
httpClient := &http.Client{}
if transport != nil {
httpClient.Transport = transport
}
resp, err := httpClient.Do(req)
if err != nil {
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: fmt.Sprintf("Invalid request: %v", err),
Type: "invalid_request_error",
},
})
return
}
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
defer func() {
if err = resp.Body.Close(); err != nil {
log.Printf("warn: failed to close response body: %v", err)
}
}()
bodyBytes, _ := io.ReadAll(resp.Body)
c.JSON(http.StatusBadRequest, ErrorResponse{
Error: ErrorDetail{
Message: string(bodyBytes),
Type: "invalid_request_error",
},
})
return
}
defer func() {
_ = resp.Body.Close()
}()
for key, value := range resp.Header {
c.Header(key, value[0])
}
output, err := io.ReadAll(resp.Body)
if err != nil {
log.Errorf("Failed to read response body: %v", err)
return
}
_, _ = c.Writer.Write(output)
}
}
func (h *APIHandlers) internalStreamGenerateContent(c *gin.Context, rawJson []byte) {
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
c.JSON(http.StatusInternalServerError, ErrorResponse{
Error: ErrorDetail{
Message: "Streaming not supported",
Type: "server_error",
},
})
return
}
modelResult := gjson.GetBytes(rawJson, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
// Ensure the client's mutex is unlocked on function exit.
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
outLoop:
for {
// Lock the mutex to update the last used client index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
mutex.Unlock()
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
c.Status(429)
_, _ = fmt.Fprint(c.Writer, fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName))
flusher.Flush()
cliCancel()
return
}
locked := false
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
// Send the message and receive response chunks and errors via channels.
respChan, errChan := cliClient.SendRawMessageStream(cliCtx, rawJson)
hasFirstResponse := false
for {
select {
// Handle client disconnection.
case <-c.Request.Context().Done():
if c.Request.Context().Err().Error() == "context canceled" {
log.Debugf("Client disconnected: %v", c.Request.Context().Err())
cliCancel() // Cancel the backend request.
return
}
// Process incoming response chunks.
case chunk, okStream := <-respChan:
if !okStream {
cliCancel()
return
} else {
hasFirstResponse = true
if cliClient.GetGenerativeLanguageAPIKey() != "" {
chunk, _ = sjson.SetRawBytes(chunk, "response", chunk)
}
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
flusher.Flush()
}
// Handle errors from the backend.
case err, okError := <-errChan:
if okError {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue outLoop
} else {
c.Status(err.StatusCode)
_, _ = fmt.Fprint(c.Writer, err.Error.Error())
flusher.Flush()
cliCancel()
}
return
}
// Send a keep-alive signal to the client.
case <-time.After(500 * time.Millisecond):
if hasFirstResponse {
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
}
}
}
}
}
func (h *APIHandlers) internalGenerateContent(c *gin.Context, rawJson []byte) {
c.Header("Content-Type", "application/json")
modelResult := gjson.GetBytes(rawJson, "model")
modelName := modelResult.String()
cliCtx, cliCancel := context.WithCancel(context.Background())
var cliClient *client.Client
defer func() {
if cliClient != nil {
cliClient.RequestMutex.Unlock()
}
}()
for {
// Lock the mutex to update the last used client index
mutex.Lock()
startIndex := lastUsedClientIndex
currentIndex := (startIndex + 1) % len(h.cliClients)
lastUsedClientIndex = currentIndex
mutex.Unlock()
// Reorder the client to start from the last used index
reorderedClients := make([]*client.Client, 0)
for i := 0; i < len(h.cliClients); i++ {
cliClient = h.cliClients[(startIndex+1+i)%len(h.cliClients)]
if cliClient.IsModelQuotaExceeded(modelName) {
log.Debugf("Model %s is quota exceeded for account %s, project id: %s", modelName, cliClient.GetEmail(), cliClient.GetProjectID())
cliClient = nil
continue
}
reorderedClients = append(reorderedClients, cliClient)
}
if len(reorderedClients) == 0 {
c.Status(429)
_, _ = c.Writer.Write([]byte(fmt.Sprintf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, modelName)))
cliCancel()
return
}
locked := false
for i := 0; i < len(reorderedClients); i++ {
cliClient = reorderedClients[i]
if cliClient.RequestMutex.TryLock() {
locked = true
break
}
}
if !locked {
cliClient = h.cliClients[0]
cliClient.RequestMutex.Lock()
}
if glAPIKey := cliClient.GetGenerativeLanguageAPIKey(); glAPIKey != "" {
log.Debugf("Request use generative language API Key: %s", glAPIKey)
} else {
log.Debugf("Request use account: %s, project id: %s", cliClient.GetEmail(), cliClient.GetProjectID())
}
resp, err := cliClient.SendRawMessage(cliCtx, rawJson)
if err != nil {
if err.StatusCode == 429 && h.cfg.QuotaExceeded.SwitchProject {
continue
} else {
c.Status(err.StatusCode)
_, _ = c.Writer.Write([]byte(err.Error.Error()))
cliCancel()
}
break
} else {
_, _ = c.Writer.Write(resp)
cliCancel()
break
}
}
}

View File

@@ -81,6 +81,8 @@ func (s *Server) setupRoutes() {
},
})
})
s.engine.POST("/v1internal:method", s.handlers.Internal)
}
// Start begins listening for and serving HTTP requests.

View File

@@ -12,7 +12,7 @@ import (
// PrepareRequest translates a raw JSON request from an OpenAI-compatible format
// to the internal format expected by the backend client. It parses messages,
// roles, content types (text, image, file), and tool calls.
func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDeclaration) {
func PrepareRequest(rawJson []byte) (string, *client.Content, []client.Content, []client.ToolDeclaration) {
// Extract the model name from the request, defaulting to "gemini-2.5-pro".
modelName := "gemini-2.5-pro"
modelResult := gjson.GetBytes(rawJson, "model")
@@ -22,6 +22,7 @@ func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDecl
// Process the array of messages.
contents := make([]client.Content, 0)
var systemInstruction *client.Content
messagesResult := gjson.GetBytes(rawJson, "messages")
if messagesResult.IsArray() {
messagesResults := messagesResult.Array()
@@ -37,13 +38,11 @@ func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDecl
// System messages are converted to a user message followed by a model's acknowledgment.
case "system":
if contentResult.Type == gjson.String {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}})
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.String()}}}
} else if contentResult.IsObject() {
// Handle object-based system messages.
if contentResult.Get("type").String() == "text" {
contents = append(contents, client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}})
contents = append(contents, client.Content{Role: "model", Parts: []client.Part{{Text: "Understood. I will follow these instructions and use my tools to assist you."}}})
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{{Text: contentResult.Get("text").String()}}}
}
}
// User messages can contain simple text or a multi-part body.
@@ -159,5 +158,5 @@ func PrepareRequest(rawJson []byte) (string, []client.Content, []client.ToolDecl
tools = make([]client.ToolDeclaration, 0)
}
return modelName, contents, tools
return modelName, systemInstruction, contents, tools
}

View File

@@ -266,6 +266,12 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
url = url + "?alt=sse"
}
jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction")
if systemInstructionResult.Exists() {
jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw))
jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction")
jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id")
}
}
// log.Debug(string(jsonBody))
@@ -303,15 +309,14 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
}
}()
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))}
}
return resp.Body, nil
}
// SendMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
// SendMessage handles a single conversational turn, including tool calls.
func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
request := GenerateContentRequest{
Contents: contents,
GenerationConfig: GenerationConfig{
@@ -320,6 +325,9 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
},
},
}
request.SystemInstruction = systemInstruction
request.Tools = tools
requestBody := map[string]interface{}{
@@ -402,7 +410,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
}
// SendMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
dataTag := []byte("data: ")
errChan := make(chan *ErrorMessage)
dataChan := make(chan []byte)
@@ -418,6 +426,9 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
},
},
}
request.SystemInstruction = systemInstruction
request.Tools = tools
requestBody := map[string]interface{}{
@@ -519,6 +530,117 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
return dataChan, errChan
}
// SendRawMessage handles a single conversational turn, including tool calls.
func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *ErrorMessage) {
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
modelResult := gjson.GetBytes(rawJson, "model")
model := modelResult.String()
modelName := model
for {
if c.isModelQuotaExceeded(modelName) {
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
rawJson, _ = sjson.SetBytes(rawJson, "model", modelName)
continue
}
}
return nil, &ErrorMessage{
StatusCode: 429,
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
}
}
respBody, err := c.APIRequest(ctx, "generateContent", rawJson, false)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
c.modelQuotaExceeded[modelName] = &now
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
continue
}
}
return nil, err
}
delete(c.modelQuotaExceeded, modelName)
bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil {
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
}
return bodyBytes, nil
}
}
// SendRawMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-chan []byte, <-chan *ErrorMessage) {
dataTag := []byte("data: ")
errChan := make(chan *ErrorMessage)
dataChan := make(chan []byte)
go func() {
defer close(errChan)
defer close(dataChan)
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
modelResult := gjson.GetBytes(rawJson, "model")
model := modelResult.String()
modelName := model
var stream io.ReadCloser
for {
if c.isModelQuotaExceeded(modelName) {
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
rawJson, _ = sjson.SetBytes(rawJson, "model", modelName)
continue
}
}
errChan <- &ErrorMessage{
StatusCode: 429,
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
}
return
}
var err *ErrorMessage
stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, true)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
c.modelQuotaExceeded[modelName] = &now
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
continue
}
}
errChan <- err
return
}
delete(c.modelQuotaExceeded, modelName)
break
}
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Bytes()
if bytes.HasPrefix(line, dataTag) {
dataChan <- line[6:]
}
}
if errScanner := scanner.Err(); errScanner != nil {
errChan <- &ErrorMessage{500, errScanner}
_ = stream.Close()
return
}
_ = stream.Close()
}()
return dataChan, errChan
}
func (c *Client) isModelQuotaExceeded(model string) bool {
if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
duration := time.Now().Sub(*lastExceededTime)

View File

@@ -64,6 +64,7 @@ type FunctionResponse struct {
// GenerateContentRequest is the top-level request structure for the streamGenerateContent endpoint.
type GenerateContentRequest struct {
SystemInstruction *Content `json:"systemInstruction,omitempty"`
Contents []Content `json:"contents"`
Tools []ToolDeclaration `json:"tools,omitempty"`
GenerationConfig `json:"generationConfig"`