Add system instruction support and enhance internal API handlers

- Introduced `SystemInstruction` field in `PrepareRequest` and `GenerateContentRequest` for better message parsing.
- Updated `SendMessage` and `SendMessageStream` to handle system instructions in client API calls.
- Enhanced error handling and manual flushing logic in response flows.
- Added new internal API endpoints `/v1internal:generateContent` and `/v1internal:streamGenerateContent`.
- Improved proxy handling and transport logic in HTTP client initialization.
This commit is contained in:
Luis Pater
2025-07-10 05:16:54 +08:00
parent 65f47c196a
commit 273e1d9cbe
5 changed files with 442 additions and 35 deletions

View File

@@ -266,6 +266,12 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
url = url + "?alt=sse"
}
jsonBody = []byte(gjson.GetBytes(jsonBody, "request").Raw)
systemInstructionResult := gjson.GetBytes(jsonBody, "systemInstruction")
if systemInstructionResult.Exists() {
jsonBody, _ = sjson.SetRawBytes(jsonBody, "system_instruction", []byte(systemInstructionResult.Raw))
jsonBody, _ = sjson.DeleteBytes(jsonBody, "systemInstruction")
jsonBody, _ = sjson.DeleteBytes(jsonBody, "session_id")
}
}
// log.Debug(string(jsonBody))
@@ -303,15 +309,14 @@ func (c *Client) APIRequest(ctx context.Context, endpoint string, body interface
}
}()
bodyBytes, _ := io.ReadAll(resp.Body)
return nil, &ErrorMessage{resp.StatusCode, fmt.Errorf(string(bodyBytes))}
}
return resp.Body, nil
}
// SendMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
// SendMessage handles a single conversational turn, including tool calls.
func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) ([]byte, *ErrorMessage) {
request := GenerateContentRequest{
Contents: contents,
GenerationConfig: GenerationConfig{
@@ -320,6 +325,9 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
},
},
}
request.SystemInstruction = systemInstruction
request.Tools = tools
requestBody := map[string]interface{}{
@@ -402,7 +410,7 @@ func (c *Client) SendMessage(ctx context.Context, rawJson []byte, model string,
}
// SendMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model string, systemInstruction *Content, contents []Content, tools []ToolDeclaration) (<-chan []byte, <-chan *ErrorMessage) {
dataTag := []byte("data: ")
errChan := make(chan *ErrorMessage)
dataChan := make(chan []byte)
@@ -418,6 +426,9 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
},
},
}
request.SystemInstruction = systemInstruction
request.Tools = tools
requestBody := map[string]interface{}{
@@ -519,6 +530,117 @@ func (c *Client) SendMessageStream(ctx context.Context, rawJson []byte, model st
return dataChan, errChan
}
// SendRawMessage handles a single conversational turn, including tool calls.
func (c *Client) SendRawMessage(ctx context.Context, rawJson []byte) ([]byte, *ErrorMessage) {
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
modelResult := gjson.GetBytes(rawJson, "model")
model := modelResult.String()
modelName := model
for {
if c.isModelQuotaExceeded(modelName) {
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
rawJson, _ = sjson.SetBytes(rawJson, "model", modelName)
continue
}
}
return nil, &ErrorMessage{
StatusCode: 429,
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
}
}
respBody, err := c.APIRequest(ctx, "generateContent", rawJson, false)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
c.modelQuotaExceeded[modelName] = &now
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
continue
}
}
return nil, err
}
delete(c.modelQuotaExceeded, modelName)
bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil {
return nil, &ErrorMessage{StatusCode: 500, Error: errReadAll}
}
return bodyBytes, nil
}
}
// SendRawMessageStream handles a single conversational turn, including tool calls.
func (c *Client) SendRawMessageStream(ctx context.Context, rawJson []byte) (<-chan []byte, <-chan *ErrorMessage) {
dataTag := []byte("data: ")
errChan := make(chan *ErrorMessage)
dataChan := make(chan []byte)
go func() {
defer close(errChan)
defer close(dataChan)
rawJson, _ = sjson.SetBytes(rawJson, "project", c.GetProjectID())
modelResult := gjson.GetBytes(rawJson, "model")
model := modelResult.String()
modelName := model
var stream io.ReadCloser
for {
if c.isModelQuotaExceeded(modelName) {
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
modelName = c.getPreviewModel(model)
if modelName != "" {
log.Debugf("Model %s is quota exceeded. Switch to preview model %s", model, modelName)
rawJson, _ = sjson.SetBytes(rawJson, "model", modelName)
continue
}
}
errChan <- &ErrorMessage{
StatusCode: 429,
Error: fmt.Errorf(`{"error":{"code":429,"message":"All the models of '%s' are quota exceeded","status":"RESOURCE_EXHAUSTED"}}`, model),
}
return
}
var err *ErrorMessage
stream, err = c.APIRequest(ctx, "streamGenerateContent", rawJson, true)
if err != nil {
if err.StatusCode == 429 {
now := time.Now()
c.modelQuotaExceeded[modelName] = &now
if c.cfg.QuotaExceeded.SwitchPreviewModel && c.glAPIKey == "" {
continue
}
}
errChan <- err
return
}
delete(c.modelQuotaExceeded, modelName)
break
}
scanner := bufio.NewScanner(stream)
for scanner.Scan() {
line := scanner.Bytes()
if bytes.HasPrefix(line, dataTag) {
dataChan <- line[6:]
}
}
if errScanner := scanner.Err(); errScanner != nil {
errChan <- &ErrorMessage{500, errScanner}
_ = stream.Close()
return
}
_ = stream.Close()
}()
return dataChan, errChan
}
func (c *Client) isModelQuotaExceeded(model string) bool {
if lastExceededTime, hasKey := c.modelQuotaExceeded[model]; hasKey {
duration := time.Now().Sub(*lastExceededTime)