mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 04:10:51 +08:00
feat(vertex): add Imagen image generation model support
Add support for Imagen 3.0 and 4.0 image generation models in Vertex AI: - Add 5 Imagen model definitions (4.0, 4.0-ultra, 4.0-fast, 3.0, 3.0-fast) - Implement :predict action routing for Imagen models - Convert Imagen request/response format to match Gemini structure like gemini-3-pro-image - Transform prompts to Imagen's instances/parameters format - Convert base64 image responses to Gemini-compatible inline data
This commit is contained in:
@@ -287,6 +287,67 @@ func GetGeminiVertexModels() []*ModelInfo {
|
|||||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||||
},
|
},
|
||||||
|
// Imagen image generation models - use :predict action
|
||||||
|
{
|
||||||
|
ID: "imagen-4.0-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-4.0-generate-001",
|
||||||
|
Version: "4.0",
|
||||||
|
DisplayName: "Imagen 4.0 Generate",
|
||||||
|
Description: "Imagen 4.0 image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-4.0-ultra-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-4.0-ultra-generate-001",
|
||||||
|
Version: "4.0",
|
||||||
|
DisplayName: "Imagen 4.0 Ultra Generate",
|
||||||
|
Description: "Imagen 4.0 Ultra high-quality image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-3.0-generate-002",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1740000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-3.0-generate-002",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Imagen 3.0 Generate",
|
||||||
|
Description: "Imagen 3.0 image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-3.0-fast-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1740000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-3.0-fast-generate-001",
|
||||||
|
Version: "3.0",
|
||||||
|
DisplayName: "Imagen 3.0 Fast Generate",
|
||||||
|
Description: "Imagen 3.0 fast image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ID: "imagen-4.0-fast-generate-001",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1750000000,
|
||||||
|
OwnedBy: "google",
|
||||||
|
Type: "gemini",
|
||||||
|
Name: "models/imagen-4.0-fast-generate-001",
|
||||||
|
Version: "4.0",
|
||||||
|
DisplayName: "Imagen 4.0 Fast Generate",
|
||||||
|
Description: "Imagen 4.0 fast image generation model",
|
||||||
|
SupportedGenerationMethods: []string{"predict"},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
@@ -31,6 +32,143 @@ const (
|
|||||||
vertexAPIVersion = "v1"
|
vertexAPIVersion = "v1"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// isImagenModel reports whether model refers to an Imagen image generation
// model. Imagen models are routed to the :predict action instead of
// :generateContent, so callers branch request handling on this check.
func isImagenModel(model string) bool {
	return strings.Contains(strings.ToLower(model), "imagen")
}

// getVertexAction selects the Vertex AI endpoint action for the given model.
// Imagen models always use "predict"; Gemini models use
// "streamGenerateContent" when streaming is requested and "generateContent"
// otherwise.
func getVertexAction(model string, isStream bool) string {
	switch {
	case isImagenModel(model):
		return "predict"
	case isStream:
		return "streamGenerateContent"
	default:
		return "generateContent"
	}
}
|
||||||
|
|
||||||
|
// convertImagenToGeminiResponse converts Imagen API response to Gemini format
|
||||||
|
// so it can be processed by the standard translation pipeline.
|
||||||
|
// This ensures Imagen models return responses in the same format as gemini-3-pro-image-preview.
|
||||||
|
func convertImagenToGeminiResponse(data []byte, model string) []byte {
|
||||||
|
predictions := gjson.GetBytes(data, "predictions")
|
||||||
|
if !predictions.Exists() || !predictions.IsArray() {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build Gemini-compatible response with inlineData
|
||||||
|
parts := make([]map[string]any, 0)
|
||||||
|
for _, pred := range predictions.Array() {
|
||||||
|
imageData := pred.Get("bytesBase64Encoded").String()
|
||||||
|
mimeType := pred.Get("mimeType").String()
|
||||||
|
if mimeType == "" {
|
||||||
|
mimeType = "image/png"
|
||||||
|
}
|
||||||
|
if imageData != "" {
|
||||||
|
parts = append(parts, map[string]any{
|
||||||
|
"inlineData": map[string]any{
|
||||||
|
"mimeType": mimeType,
|
||||||
|
"data": imageData,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate unique response ID using timestamp
|
||||||
|
responseId := fmt.Sprintf("imagen-%d", time.Now().UnixNano())
|
||||||
|
|
||||||
|
response := map[string]any{
|
||||||
|
"candidates": []map[string]any{{
|
||||||
|
"content": map[string]any{
|
||||||
|
"parts": parts,
|
||||||
|
"role": "model",
|
||||||
|
},
|
||||||
|
"finishReason": "STOP",
|
||||||
|
}},
|
||||||
|
"responseId": responseId,
|
||||||
|
"modelVersion": model,
|
||||||
|
// Imagen API doesn't return token counts, set to 0 for tracking purposes
|
||||||
|
"usageMetadata": map[string]any{
|
||||||
|
"promptTokenCount": 0,
|
||||||
|
"candidatesTokenCount": 0,
|
||||||
|
"totalTokenCount": 0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
result, err := json.Marshal(response)
|
||||||
|
if err != nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertToImagenRequest converts a Gemini-style request to Imagen API format.
|
||||||
|
// Imagen API uses a different structure: instances[].prompt instead of contents[].
|
||||||
|
func convertToImagenRequest(payload []byte) ([]byte, error) {
|
||||||
|
// Extract prompt from Gemini-style contents
|
||||||
|
prompt := ""
|
||||||
|
|
||||||
|
// Try to get prompt from contents[0].parts[0].text
|
||||||
|
contentsText := gjson.GetBytes(payload, "contents.0.parts.0.text")
|
||||||
|
if contentsText.Exists() {
|
||||||
|
prompt = contentsText.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no contents, try messages format (OpenAI-compatible)
|
||||||
|
if prompt == "" {
|
||||||
|
messagesText := gjson.GetBytes(payload, "messages.#.content")
|
||||||
|
if messagesText.Exists() && messagesText.IsArray() {
|
||||||
|
for _, msg := range messagesText.Array() {
|
||||||
|
if msg.String() != "" {
|
||||||
|
prompt = msg.String()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If still no prompt, try direct prompt field
|
||||||
|
if prompt == "" {
|
||||||
|
directPrompt := gjson.GetBytes(payload, "prompt")
|
||||||
|
if directPrompt.Exists() {
|
||||||
|
prompt = directPrompt.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if prompt == "" {
|
||||||
|
return nil, fmt.Errorf("imagen: no prompt found in request")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build Imagen API request
|
||||||
|
imagenReq := map[string]any{
|
||||||
|
"instances": []map[string]any{
|
||||||
|
{
|
||||||
|
"prompt": prompt,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"parameters": map[string]any{
|
||||||
|
"sampleCount": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract optional parameters
|
||||||
|
if aspectRatio := gjson.GetBytes(payload, "aspectRatio"); aspectRatio.Exists() {
|
||||||
|
imagenReq["parameters"].(map[string]any)["aspectRatio"] = aspectRatio.String()
|
||||||
|
}
|
||||||
|
if sampleCount := gjson.GetBytes(payload, "sampleCount"); sampleCount.Exists() {
|
||||||
|
imagenReq["parameters"].(map[string]any)["sampleCount"] = int(sampleCount.Int())
|
||||||
|
}
|
||||||
|
if negativePrompt := gjson.GetBytes(payload, "negativePrompt"); negativePrompt.Exists() {
|
||||||
|
imagenReq["instances"].([]map[string]any)[0]["negativePrompt"] = negativePrompt.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Marshal(imagenReq)
|
||||||
|
}
|
||||||
|
|
||||||
// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
|
// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
|
||||||
type GeminiVertexExecutor struct {
|
type GeminiVertexExecutor struct {
|
||||||
cfg *config.Config
|
cfg *config.Config
|
||||||
@@ -160,26 +298,38 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||||
defer reporter.trackFailure(ctx, &err)
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
from := opts.SourceFormat
|
var body []byte
|
||||||
to := sdktranslator.FromString("gemini")
|
|
||||||
|
|
||||||
originalPayload := bytes.Clone(req.Payload)
|
// Handle Imagen models with special request format
|
||||||
if len(opts.OriginalRequest) > 0 {
|
if isImagenModel(baseModel) {
|
||||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
imagenBody, errImagen := convertToImagenRequest(req.Payload)
|
||||||
}
|
if errImagen != nil {
|
||||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
return resp, errImagen
|
||||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
}
|
||||||
|
body = imagenBody
|
||||||
|
} else {
|
||||||
|
// Standard Gemini translation flow
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("gemini")
|
||||||
|
|
||||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
originalPayload := bytes.Clone(req.Payload)
|
||||||
if err != nil {
|
if len(opts.OriginalRequest) > 0 {
|
||||||
return resp, err
|
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||||
|
}
|
||||||
|
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||||
|
body = sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||||
|
|
||||||
|
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String())
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body = fixGeminiImageAspectRatio(baseModel, body)
|
||||||
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||||
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
}
|
}
|
||||||
|
|
||||||
body = fixGeminiImageAspectRatio(baseModel, body)
|
action := getVertexAction(baseModel, false)
|
||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
|
||||||
|
|
||||||
action := "generateContent"
|
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
||||||
action = "countTokens"
|
action = "countTokens"
|
||||||
@@ -249,6 +399,16 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
|||||||
}
|
}
|
||||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||||
reporter.publish(ctx, parseGeminiUsage(data))
|
reporter.publish(ctx, parseGeminiUsage(data))
|
||||||
|
|
||||||
|
// For Imagen models, convert response to Gemini format before translation
|
||||||
|
// This ensures Imagen responses use the same format as gemini-3-pro-image-preview
|
||||||
|
if isImagenModel(baseModel) {
|
||||||
|
data = convertImagenToGeminiResponse(data, baseModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Standard Gemini translation (works for both Gemini and converted Imagen responses)
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("gemini")
|
||||||
var param any
|
var param any
|
||||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||||
@@ -281,7 +441,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
|||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
action := "generateContent"
|
action := getVertexAction(baseModel, false)
|
||||||
if req.Metadata != nil {
|
if req.Metadata != nil {
|
||||||
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
||||||
action = "countTokens"
|
action = "countTokens"
|
||||||
@@ -384,12 +544,16 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
|||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
|
action := getVertexAction(baseModel, true)
|
||||||
baseURL := vertexBaseURL(location)
|
baseURL := vertexBaseURL(location)
|
||||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, "streamGenerateContent")
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, action)
|
||||||
if opts.Alt == "" {
|
// Imagen models don't support streaming, skip SSE params
|
||||||
url = url + "?alt=sse"
|
if !isImagenModel(baseModel) {
|
||||||
} else {
|
if opts.Alt == "" {
|
||||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
url = url + "?alt=sse"
|
||||||
|
} else {
|
||||||
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
body, _ = sjson.DeleteBytes(body, "session_id")
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||||
|
|
||||||
@@ -503,15 +667,19 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
|||||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated)
|
||||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||||
|
|
||||||
|
action := getVertexAction(baseModel, true)
|
||||||
// For API key auth, use simpler URL format without project/location
|
// For API key auth, use simpler URL format without project/location
|
||||||
if baseURL == "" {
|
if baseURL == "" {
|
||||||
baseURL = "https://generativelanguage.googleapis.com"
|
baseURL = "https://generativelanguage.googleapis.com"
|
||||||
}
|
}
|
||||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "streamGenerateContent")
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action)
|
||||||
if opts.Alt == "" {
|
// Imagen models don't support streaming, skip SSE params
|
||||||
url = url + "?alt=sse"
|
if !isImagenModel(baseModel) {
|
||||||
} else {
|
if opts.Alt == "" {
|
||||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
url = url + "?alt=sse"
|
||||||
|
} else {
|
||||||
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
body, _ = sjson.DeleteBytes(body, "session_id")
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user