From 8734d4cb9080706ffc4e980452a5cc44ce5097af Mon Sep 17 00:00:00 2001 From: dinhkarate Date: Tue, 20 Jan 2026 01:26:37 +0700 Subject: [PATCH] feat(vertex): add Imagen image generation model support Add support for Imagen 3.0 and 4.0 image generation models in Vertex AI: - Add 5 Imagen model definitions (4.0, 4.0-ultra, 4.0-fast, 3.0, 3.0-fast) - Implement :predict action routing for Imagen models - Convert Imagen request/response format to match Gemini structure like gemini-3-pro-image - Transform prompts to Imagen's instances/parameters format - Convert base64 image responses to Gemini-compatible inline data --- internal/registry/model_definitions.go | 61 +++++ .../executor/gemini_vertex_executor.go | 222 +++++++++++++++--- 2 files changed, 256 insertions(+), 27 deletions(-) diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 080c2726..1d29bda2 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -287,6 +287,67 @@ func GetGeminiVertexModels() []*ModelInfo { SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, }, + // Imagen image generation models - use :predict action + { + ID: "imagen-4.0-generate-001", + Object: "model", + Created: 1750000000, + OwnedBy: "google", + Type: "gemini", + Name: "models/imagen-4.0-generate-001", + Version: "4.0", + DisplayName: "Imagen 4.0 Generate", + Description: "Imagen 4.0 image generation model", + SupportedGenerationMethods: []string{"predict"}, + }, + { + ID: "imagen-4.0-ultra-generate-001", + Object: "model", + Created: 1750000000, + OwnedBy: "google", + Type: "gemini", + Name: "models/imagen-4.0-ultra-generate-001", + Version: "4.0", + DisplayName: "Imagen 4.0 Ultra Generate", + Description: "Imagen 4.0 Ultra high-quality image generation model", + SupportedGenerationMethods: []string{"predict"}, + }, + { + ID: "imagen-3.0-generate-002", + Object: "model", + Created: 1740000000, + OwnedBy: "google", + Type: "gemini", + Name: "models/imagen-3.0-generate-002", + Version: "3.0", + DisplayName: "Imagen 3.0 Generate", + Description: "Imagen 3.0 image generation model", + SupportedGenerationMethods: []string{"predict"}, + }, + { + ID: "imagen-3.0-fast-generate-001", + Object: "model", + Created: 1740000000, + OwnedBy: "google", + Type: "gemini", + Name: "models/imagen-3.0-fast-generate-001", + Version: "3.0", + DisplayName: "Imagen 3.0 Fast Generate", + Description: "Imagen 3.0 fast image generation model", + SupportedGenerationMethods: []string{"predict"}, + }, + { + ID: "imagen-4.0-fast-generate-001", + Object: "model", + Created: 1750000000, + OwnedBy: "google", + Type: "gemini", + Name: "models/imagen-4.0-fast-generate-001", + Version: "4.0", + DisplayName: "Imagen 4.0 Fast Generate", + Description: "Imagen 4.0 fast image generation model", + SupportedGenerationMethods: []string{"predict"}, + }, } } diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go index 20e59b3f..1184c07e 100644 --- a/internal/runtime/executor/gemini_vertex_executor.go +++ b/internal/runtime/executor/gemini_vertex_executor.go @@ -12,6 +12,7 @@ import ( "io" "net/http" "strings" + "time" vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" @@ -31,6 +32,143 @@ const ( vertexAPIVersion = "v1" ) +// isImagenModel checks if the model name is an Imagen image generation model. +// Imagen models use the :predict action instead of :generateContent. +func isImagenModel(model string) bool { + lowerModel := strings.ToLower(model) + return strings.Contains(lowerModel, "imagen") +} + +// getVertexAction returns the appropriate action for the given model. +// Imagen models use "predict", while Gemini models use "generateContent". +func getVertexAction(model string, isStream bool) string { + if isImagenModel(model) { + return "predict" + } + if isStream { + return "streamGenerateContent" + } + return "generateContent" +} + +// convertImagenToGeminiResponse converts Imagen API response to Gemini format +// so it can be processed by the standard translation pipeline. +// This ensures Imagen models return responses in the same format as gemini-3-pro-image-preview. +func convertImagenToGeminiResponse(data []byte, model string) []byte { + predictions := gjson.GetBytes(data, "predictions") + if !predictions.Exists() || !predictions.IsArray() { + return data + } + + // Build Gemini-compatible response with inlineData + parts := make([]map[string]any, 0) + for _, pred := range predictions.Array() { + imageData := pred.Get("bytesBase64Encoded").String() + mimeType := pred.Get("mimeType").String() + if mimeType == "" { + mimeType = "image/png" + } + if imageData != "" { + parts = append(parts, map[string]any{ + "inlineData": map[string]any{ + "mimeType": mimeType, + "data": imageData, + }, + }) + } + } + + // Generate unique response ID using timestamp + responseId := fmt.Sprintf("imagen-%d", time.Now().UnixNano()) + + response := map[string]any{ + "candidates": []map[string]any{{ + "content": map[string]any{ + "parts": parts, + "role": "model", + }, + "finishReason": "STOP", + }}, + "responseId": responseId, + "modelVersion": model, + // Imagen API doesn't return token counts, set to 0 for tracking purposes + "usageMetadata": map[string]any{ + "promptTokenCount": 0, + "candidatesTokenCount": 0, + "totalTokenCount": 0, + }, + } + + result, err := json.Marshal(response) + if err != nil { + return data + } + return result +} + +// convertToImagenRequest converts a Gemini-style request to Imagen API format. +// Imagen API uses a different structure: instances[].prompt instead of contents[]. +func convertToImagenRequest(payload []byte) ([]byte, error) { + // Extract prompt from Gemini-style contents + prompt := "" + + // Try to get prompt from contents[0].parts[0].text + contentsText := gjson.GetBytes(payload, "contents.0.parts.0.text") + if contentsText.Exists() { + prompt = contentsText.String() + } + + // If no contents, try messages format (OpenAI-compatible) + if prompt == "" { + messagesText := gjson.GetBytes(payload, "messages.#.content") + if messagesText.Exists() && messagesText.IsArray() { + for _, msg := range messagesText.Array() { + if msg.String() != "" { + prompt = msg.String() + break + } + } + } + } + + // If still no prompt, try direct prompt field + if prompt == "" { + directPrompt := gjson.GetBytes(payload, "prompt") + if directPrompt.Exists() { + prompt = directPrompt.String() + } + } + + if prompt == "" { + return nil, fmt.Errorf("imagen: no prompt found in request") + } + + // Build Imagen API request + imagenReq := map[string]any{ + "instances": []map[string]any{ + { + "prompt": prompt, + }, + }, + "parameters": map[string]any{ + "sampleCount": 1, + }, + } + + // Extract optional parameters + if aspectRatio := gjson.GetBytes(payload, "aspectRatio"); aspectRatio.Exists() { + imagenReq["parameters"].(map[string]any)["aspectRatio"] = aspectRatio.String() + } + if sampleCount := gjson.GetBytes(payload, "sampleCount"); sampleCount.Exists() { + imagenReq["parameters"].(map[string]any)["sampleCount"] = int(sampleCount.Int()) + } + if negativePrompt := gjson.GetBytes(payload, "negativePrompt"); negativePrompt.Exists() { + imagenReq["instances"].([]map[string]any)[0]["negativePrompt"] = negativePrompt.String() + } + + return json.Marshal(imagenReq) +} + // GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials. type GeminiVertexExecutor struct { cfg *config.Config @@ -160,26 +298,38 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) defer reporter.trackFailure(ctx, &err) - from := opts.SourceFormat - to := sdktranslator.FromString("gemini") + var body []byte - originalPayload := bytes.Clone(req.Payload) - if len(opts.OriginalRequest) > 0 { - originalPayload = bytes.Clone(opts.OriginalRequest) - } - originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) - body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) + // Handle Imagen models with special request format + if isImagenModel(baseModel) { + imagenBody, errImagen := convertToImagenRequest(req.Payload) + if errImagen != nil { + return resp, errImagen + } + body = imagenBody + } else { + // Standard Gemini translation flow + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") - body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) - if err != nil { - return resp, err + originalPayload := bytes.Clone(req.Payload) + if len(opts.OriginalRequest) > 0 { + originalPayload = bytes.Clone(opts.OriginalRequest) + } + originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) + body = sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false) + + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String()) + if err != nil { + return resp, err + } + + body = fixGeminiImageAspectRatio(baseModel, body) + body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated) + body, _ = sjson.SetBytes(body, "model", baseModel) } - body = fixGeminiImageAspectRatio(baseModel, body) - body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated) - body, _ = sjson.SetBytes(body, "model", baseModel) - - action := "generateContent" + action := getVertexAction(baseModel, false) if req.Metadata != nil { if a, _ := req.Metadata["action"].(string); a == "countTokens" { action = "countTokens" @@ -249,6 +399,16 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au } appendAPIResponseChunk(ctx, e.cfg, data) reporter.publish(ctx, parseGeminiUsage(data)) + + // For Imagen models, convert response to Gemini format before translation + // This ensures Imagen responses use the same format as gemini-3-pro-image-preview + if isImagenModel(baseModel) { + data = convertImagenToGeminiResponse(data, baseModel) + } + + // Standard Gemini translation (works for both Gemini and converted Imagen responses) + from := opts.SourceFormat + to := sdktranslator.FromString("gemini") var param any out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m) resp = cliproxyexecutor.Response{Payload: []byte(out)} @@ -281,7 +441,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated) body, _ = sjson.SetBytes(body, "model", baseModel) - action := "generateContent" + action := getVertexAction(baseModel, false) if req.Metadata != nil { if a, _ := req.Metadata["action"].(string); a == "countTokens" { action = "countTokens" @@ -384,12 +544,16 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated) body, _ = sjson.SetBytes(body, "model", baseModel) + action := getVertexAction(baseModel, true) baseURL := vertexBaseURL(location) - url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, "streamGenerateContent") - if opts.Alt == "" { - url = url + "?alt=sse" - } else { - url = url + fmt.Sprintf("?$alt=%s", opts.Alt) + url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, baseModel, action) + // Imagen models don't support streaming, skip SSE params + if !isImagenModel(baseModel) { + if opts.Alt == "" { + url = url + "?alt=sse" + } else { + url = url + fmt.Sprintf("?$alt=%s", opts.Alt) + } } body, _ = sjson.DeleteBytes(body, "session_id") @@ -503,15 +667,19 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated) body, _ = sjson.SetBytes(body, "model", baseModel) + action := getVertexAction(baseModel, true) // For API key auth, use simpler URL format without project/location if baseURL == "" { baseURL = "https://generativelanguage.googleapis.com" } - url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, "streamGenerateContent") - if opts.Alt == "" { - url = url + "?alt=sse" - } else { - url = url + fmt.Sprintf("?$alt=%s", opts.Alt) + url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, baseModel, action) + // Imagen models don't support streaming, skip SSE params + if !isImagenModel(baseModel) { + if opts.Alt == "" { + url = url + "?alt=sse" + } else { + url = url + fmt.Sprintf("?$alt=%s", opts.Alt) + } } body, _ = sjson.DeleteBytes(body, "session_id")