From 23b45efae9d86c429ee8744dac8bc1bf939448f1 Mon Sep 17 00:00:00 2001 From: Connor <963408438@qq.com> Date: Wed, 14 Jan 2026 02:25:41 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0chat=20gemini=E3=80=81chatfir?= =?UTF-8?q?e=E7=AB=AF=E7=82=B9=E3=80=81=20=E5=9B=BE=E7=89=87=E7=94=9F?= =?UTF-8?q?=E6=88=90=20gemini=E3=80=81chatfire=20=E6=9B=B4=E8=BD=BB?= =?UTF-8?q?=E6=9D=BE=E7=9A=84AI=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 + Dockerfile | 13 +- application/services/ai_service.go | 133 +++++++- .../services/image_generation_service.go | 111 +++++-- .../services/script_generation_service.go | 26 +- .../services/video_generation_service.go | 19 +- application/services/video_merge_service.go | 16 +- docker-compose.yml | 4 +- domain/models/ai_config.go | 1 + infrastructure/database/custom_logger.go | 103 +++++++ infrastructure/database/database.go | 3 +- migrations/init.sql | 5 +- pkg/ai/client.go | 7 + pkg/ai/gemini_client.go | 195 ++++++++++++ pkg/ai/openai_client.go | 41 ++- pkg/image/gemini_image_client.go | 277 +++++++++++++++++ pkg/image/image_client.go | 291 ------------------ pkg/image/openai_image_client.go | 128 ++++++++ pkg/image/volcengine_image_client.go | 158 ++++++++++ pkg/video/chatfire_client.go | 184 +++++++++++ pkg/video/openai_sora_client.go | 2 +- web/src/views/settings/AIConfig.vue | 198 +++++++++--- 22 files changed, 1512 insertions(+), 405 deletions(-) create mode 100644 infrastructure/database/custom_logger.go create mode 100644 pkg/ai/client.go create mode 100644 pkg/ai/gemini_client.go create mode 100644 pkg/image/gemini_image_client.go create mode 100644 pkg/image/openai_image_client.go create mode 100644 pkg/image/volcengine_image_client.go create mode 100644 pkg/video/chatfire_client.go diff --git a/.gitignore b/.gitignore index db47673..e7cdb4c 100644 --- a/.gitignore +++ b/.gitignore @@ -63,3 +63,5 @@ configs/config.yaml # Docker publish documentation (optional) DOCKER_PUBLISH.md build.sh +/data/storage/ +/web/package-lock.json diff --git a/Dockerfile b/Dockerfile index 12e72e4..8d03e94 100644 --- a/Dockerfile +++ b/Dockerfile @@ -3,6 +3,9 @@ # ==================== 阶段1: 构建前端 ==================== FROM node:20-alpine AS frontend-builder +# 配置 npm 镜像源(国内加速) +RUN npm config set registry https://registry.npmmirror.com + WORKDIR /app/web # 复制前端依赖文件 @@ -59,15 +62,12 @@ RUN apk add --no-cache \ tzdata \ ffmpeg \ sqlite-libs \ + wget \ && rm -rf /var/cache/apk/* # 设置时区 ENV TZ=Asia/Shanghai -# 创建非 root 用户 -RUN addgroup -g 1000 app && \ - adduser -D -u 1000 -G app app - WORKDIR /app # 从构建阶段复制可执行文件 @@ -83,10 +83,7 @@ RUN cp ./configs/config.example.yaml ./configs/config.yaml # 复制数据库迁移文件 COPY migrations ./migrations/ -# 切换到非 root 用户 -USER app - -# 创建数据目录(在 app 用户下创建,确保权限正确) +# 创建数据目录(root 用户运行,无需权限设置) RUN mkdir -p /app/data/storage # 暴露端口 diff --git a/application/services/ai_service.go b/application/services/ai_service.go index 31d4f6a..2140c2d 100644 --- a/application/services/ai_service.go +++ b/application/services/ai_service.go @@ -25,6 +25,7 @@ func NewAIService(db *gorm.DB, log *logger.Logger) *AIService { type CreateAIConfigRequest struct { ServiceType string `json:"service_type" binding:"required,oneof=text image video"` Name string `json:"name" binding:"required,min=1,max=100"` + Provider string `json:"provider" binding:"required"` BaseURL string `json:"base_url" binding:"required,url"` APIKey string `json:"api_key" binding:"required"` Model models.ModelField `json:"model" binding:"required"` @@ -37,6 +38,7 @@ type CreateAIConfigRequest struct { type UpdateAIConfigRequest struct { Name string `json:"name" binding:"omitempty,min=1,max=100"` + Provider string `json:"provider"` BaseURL string `json:"base_url" binding:"omitempty,url"` APIKey string `json:"api_key"` Model *models.ModelField `json:"model"` @@ -52,18 +54,53 @@ type TestConnectionRequest struct { BaseURL string `json:"base_url" binding:"required,url"` APIKey string `json:"api_key" binding:"required"` Model models.ModelField `json:"model" binding:"required"` + Provider string `json:"provider"` Endpoint string `json:"endpoint"` } func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) { + // 根据 provider 和 service_type 自动设置 endpoint + endpoint := req.Endpoint + queryEndpoint := req.QueryEndpoint + + if endpoint == "" { + switch req.Provider { + case "gemini", "google": + if req.ServiceType == "text" { + endpoint = "/v1beta/models/{model}:generateContent" + } else if req.ServiceType == "image" { + endpoint = "/v1beta/models/{model}:generateContent" + } + case "openai", "chatfire": + if req.ServiceType == "text" { + endpoint = "/chat/completions" + } else if req.ServiceType == "image" { + endpoint = "/images/generations" + } else if req.ServiceType == "video" { + endpoint = "/video/generations" + if queryEndpoint == "" { + queryEndpoint = "/v1/video/task/{taskId}" + } + } + default: + // 默认使用 OpenAI 格式 + if req.ServiceType == "text" { + endpoint = "/chat/completions" + } else if req.ServiceType == "image" { + endpoint = "/images/generations" + } + } + } + config := &models.AIServiceConfig{ ServiceType: req.ServiceType, Name: req.Name, + Provider: req.Provider, BaseURL: req.BaseURL, APIKey: req.APIKey, Model: req.Model, - Endpoint: req.Endpoint, - QueryEndpoint: req.QueryEndpoint, + Endpoint: endpoint, + QueryEndpoint: queryEndpoint, Priority: req.Priority, IsDefault: req.IsDefault, IsActive: true, @@ -75,7 +112,7 @@ func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceC return nil, err } - s.log.Infow("AI config created", "config_id", config.ID) + s.log.Infow("AI config created", "config_id", config.ID, "provider", req.Provider, "endpoint", endpoint) return config, nil } @@ -125,6 +162,9 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo if req.Name != "" { updates["name"] = req.Name } + if req.Provider != "" { + updates["provider"] = req.Provider + } if req.BaseURL != "" { updates["base_url"] = req.BaseURL } @@ -137,9 +177,30 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo if req.Priority != nil { updates["priority"] = *req.Priority } - if req.Endpoint != "" { + + // 如果提供了 provider,根据 provider 和 service_type 自动设置 endpoint + if req.Provider != "" && req.Endpoint == "" { + provider := req.Provider + serviceType := config.ServiceType + + switch provider { + case "gemini", "google": + if serviceType == "text" || serviceType == "image" { + updates["endpoint"] = "/v1beta/models/{model}:generateContent" + } + case "openai", "chatfire": + if serviceType == "text" { + updates["endpoint"] = "/chat/completions" + } else if serviceType == "image" { + updates["endpoint"] = "/images/generations" + } else if serviceType == "video" { + updates["endpoint"] = "/video/generations" + } + } + } else if req.Endpoint != "" { updates["endpoint"] = req.Endpoint } + // 允许清空query_endpoint,所以不检查是否为空 updates["query_endpoint"] = req.QueryEndpoint if req.Settings != "" { @@ -179,13 +240,51 @@ func (s *AIService) DeleteConfig(configID uint) error { } func (s *AIService) TestConnection(req *TestConnectionRequest) error { + s.log.Infow("TestConnection called", "baseURL", req.BaseURL, "provider", req.Provider, "endpoint", req.Endpoint, "modelCount", len(req.Model)) + // 使用第一个模型进行测试 model := "" if len(req.Model) > 0 { model = req.Model[0] } - client := ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, req.Endpoint) - return client.TestConnection() + s.log.Infow("Using model for test", "model", model, "provider", req.Provider) + + // 根据 provider 参数选择客户端 + var client ai.AIClient + var endpoint string + + switch req.Provider { + case "gemini", "google": + // Gemini + s.log.Infow("Using Gemini client", "baseURL", req.BaseURL) + endpoint = "/v1beta/models/{model}:generateContent" + client = ai.NewGeminiClient(req.BaseURL, req.APIKey, model, endpoint) + case "openai", "chatfire": + // OpenAI 格式(包括 chatfire 等) + s.log.Infow("Using OpenAI-compatible client", "baseURL", req.BaseURL, "provider", req.Provider) + endpoint = req.Endpoint + if endpoint == "" { + endpoint = "/chat/completions" + } + client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint) + default: + // 默认使用 OpenAI 格式 + s.log.Infow("Using default OpenAI-compatible client", "baseURL", req.BaseURL) + endpoint = req.Endpoint + if endpoint == "" { + endpoint = "/chat/completions" + } + client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint) + } + + s.log.Infow("Calling TestConnection on client", "endpoint", endpoint) + err := client.TestConnection() + if err != nil { + s.log.Errorw("TestConnection failed", "error", err) + } else { + s.log.Infow("TestConnection succeeded") + } + return err } func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) { @@ -228,7 +327,7 @@ func (s *AIService) GetConfigForModel(serviceType string, modelName string) (*mo return nil, errors.New("no config found for model: " + modelName) } -func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) { +func (s *AIService) GetAIClient(serviceType string) (ai.AIClient, error) { config, err := s.GetDefaultConfig(serviceType) if err != nil { return nil, err @@ -240,7 +339,25 @@ func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) { model = config.Model[0] } - return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, config.Endpoint), nil + // 使用数据库配置中的 endpoint,如果为空则根据 provider 设置默认值 + endpoint := config.Endpoint + if endpoint == "" { + switch config.Provider { + case "gemini", "google": + endpoint = "/v1beta/models/{model}:generateContent" + default: + endpoint = "/chat/completions" + } + } + + // 根据 provider 创建对应的客户端 + switch config.Provider { + case "gemini", "google": + return ai.NewGeminiClient(config.BaseURL, config.APIKey, model, endpoint), nil + default: + // openai, chatfire 等其他厂商都使用 OpenAI 格式 + return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, endpoint), nil + } } func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) { diff --git a/application/services/image_generation_service.go b/application/services/image_generation_service.go index 0368fdf..a67ad6f 100644 --- a/application/services/image_generation_service.go +++ b/application/services/image_generation_service.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "strconv" + "strings" "time" models "github.com/drama-generator/backend/domain/models" @@ -23,6 +24,24 @@ type ImageGenerationService struct { log *logger.Logger } +// truncateImageURL 截断图片 URL,避免 base64 格式的 URL 占满日志 +func truncateImageURL(url string) string { + if url == "" { + return "" + } + // 如果是 data URI 格式(base64),只显示前缀 + if strings.HasPrefix(url, "data:") { + if len(url) > 50 { + return url[:50] + "...[base64 data]" + } + } + // 普通 URL 如果过长也截断 + if len(url) > 100 { + return url[:100] + "..." + } + return url +} + func NewImageGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, log *logger.Logger) *ImageGenerationService { return &ImageGenerationService{ db: db, @@ -246,17 +265,23 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result now := time.Now() // 下载图片到本地存储(仅用于缓存,不更新数据库) - if s.localStorage != nil && result.ImageURL != "" { + // 仅下载 HTTP/HTTPS URL,跳过 data URI + if s.localStorage != nil && result.ImageURL != "" && + (strings.HasPrefix(result.ImageURL, "http://") || strings.HasPrefix(result.ImageURL, "https://")) { _, err := s.localStorage.DownloadFromURL(result.ImageURL, "images") if err != nil { + errStr := err.Error() + if len(errStr) > 200 { + errStr = errStr[:200] + "..." + } s.log.Warnw("Failed to download image to local storage", - "error", err, + "error", errStr, "id", imageGenID, - "original_url", result.ImageURL) + "original_url", truncateImageURL(result.ImageURL)) } else { s.log.Infow("Image downloaded to local storage for caching", "id", imageGenID, - "original_url", result.ImageURL) + "original_url", truncateImageURL(result.ImageURL)) } } @@ -291,7 +316,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result } else { s.log.Infow("Storyboard updated with composed image", "storyboard_id", *imageGen.StoryboardID, - "composed_image", result.ImageURL) + "composed_image", truncateImageURL(result.ImageURL)) } } @@ -306,7 +331,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result } else { s.log.Infow("Scene updated with generated image", "scene_id", *imageGen.SceneID, - "image_url", result.ImageURL) + "image_url", truncateImageURL(result.ImageURL)) } } @@ -317,7 +342,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result } else { s.log.Infow("Character updated with generated image", "character_id", *imageGen.CharacterID, - "image_url", result.ImageURL) + "image_url", truncateImageURL(result.ImageURL)) } } } @@ -356,13 +381,33 @@ func (s *ImageGenerationService) getImageClient(provider string) (image.ImageCli model = config.Model[0] } - switch provider { + // 使用配置中的 provider,如果没有则使用传入的 provider + actualProvider := config.Provider + if actualProvider == "" { + actualProvider = provider + } + + // 根据 provider 自动设置默认端点 + var endpoint string + var queryEndpoint string + + switch actualProvider { case "openai", "dalle": - return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil - case "stable_diffusion", "sd": - return image.NewStableDiffusionClient(config.BaseURL, config.APIKey, model), nil + endpoint = "/images/generations" + return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil + case "chatfire": + endpoint = "/images/generations" + return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil + case "volcengine", "volces", "doubao": + endpoint = "/images/generations" + queryEndpoint = "" + return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil + case "gemini", "google": + endpoint = "/v1beta/models/{model}:generateContent" + return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil default: - return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil + endpoint = "/images/generations" + return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil } } @@ -394,13 +439,33 @@ func (s *ImageGenerationService) getImageClientWithModel(provider string, modelN model = config.Model[0] } - switch provider { + // 使用配置中的 provider,如果没有则使用传入的 provider + actualProvider := config.Provider + if actualProvider == "" { + actualProvider = provider + } + + // 根据 provider 自动设置默认端点 + var endpoint string + var queryEndpoint string + + switch actualProvider { case "openai", "dalle": - return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil - case "stable_diffusion", "sd": - return image.NewStableDiffusionClient(config.BaseURL, config.APIKey, model), nil + endpoint = "/images/generations" + return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil + case "chatfire": + endpoint = "/images/generations" + return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil + case "volcengine", "volces", "doubao": + endpoint = "/images/generations" + queryEndpoint = "" + return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil + case "gemini", "google": + endpoint = "/v1beta/models/{model}:generateContent" + return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil default: - return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil + endpoint = "/images/generations" + return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil } } @@ -716,21 +781,11 @@ func (s *ImageGenerationService) extractBackgroundsFromScript(scriptContent stri 请严格按照JSON格式输出,确保所有字段都使用中文。`, scriptContent) - messages := []ai.ChatMessage{ - {Role: "user", Content: prompt}, - } - - resp, err := client.ChatCompletion(messages, ai.WithTemperature(0.7), ai.WithMaxTokens(8000)) + response, err := client.GenerateText(prompt, "", ai.WithTemperature(0.7), ai.WithMaxTokens(8000)) if err != nil { s.log.Errorw("Failed to extract backgrounds with AI", "error", err) return nil, fmt.Errorf("AI提取场景失败: %w", err) } - - if len(resp.Choices) == 0 { - return nil, fmt.Errorf("AI未返回有效响应") - } - - response := resp.Choices[0].Message.Content s.log.Infow("AI backgrounds extraction response", "length", len(response)) // 解析JSON响应 diff --git a/application/services/script_generation_service.go b/application/services/script_generation_service.go index 8b45b19..8a0479c 100644 --- a/application/services/script_generation_service.go +++ b/application/services/script_generation_service.go @@ -185,17 +185,17 @@ func (s *ScriptGenerationService) GenerateCharacters(req *GenerateCharactersRequ count = 5 } - systemPrompt := `你是一个专业的角色设计师,擅长创作立体丰富的剧中角色。 + systemPrompt := `你是一个专业的角色分析师,擅长从剧本中提取和分析角色信息。 -你的任务是根据提供的剧本大纲,创作符合故事需求的角色设定。 +你的任务是根据提供的剧本内容,提取并整理剧中出现的所有角色的详细设定。 要求: -1. 角色必须服务于大纲中的故事情节和冲突 -2. 角色性格鲜明,有辨识度,符合故事类型 -3. 每个角色都有清晰的动机和目标,与大纲中的矛盾冲突相关 -4. 角色之间有合理的关系和联系 -5. 外貌描述必须极其详细,便于AI绘画生成角色形象 -6. 根据大纲的关键场景,合理设置角色数量(通常3-6个主要角色) +1. 仔细阅读剧本,识别所有出现的角色 +2. 根据剧本中的对话、行为和描述,总结角色的性格特点 +3. 提取角色在剧本中的关键信息:背景、动机、目标、关系等 +4. 角色之间的关系必须基于剧本中的实际描述 +5. 外貌描述必须极其详细,如果剧本中有描述则使用,如果没有则根据角色设定合理推断,便于AI绘画生成角色形象 +6. 优先提取主要角色和重要配角,次要角色可以简略 请严格按照以下 JSON 格式输出,不要添加任何其他文字: @@ -213,21 +213,21 @@ func (s *ScriptGenerationService) GenerateCharacters(req *GenerateCharactersRequ } 注意: -- 角色数量根据故事复杂度确定,不要过多 -- 每个角色都要与大纲中的故事线有明确关联 +- 必须基于剧本内容提取角色,不要凭空创作 +- 优先提取主要角色和重要配角,数量根据剧本实际情况确定 - description、personality、appearance、voice_style都必须详细描述,字数要充足 - appearance外貌描述是重中之重,必须极其详细具体,要能让AI准确生成角色形象 -- 避免模糊描述,多用具体的视觉特征和细节` +- 如果剧本中角色信息不完整,可以根据角色设定合理补充,但要符合剧本整体风格` outlineText := req.Outline if outlineText == "" { outlineText = fmt.Sprintf("剧名:%s\n简介:%s\n类型:%s", drama.Title, drama.Description, drama.Genre) } - userPrompt := fmt.Sprintf(`剧本大纲: + userPrompt := fmt.Sprintf(`剧本内容: %s -请创作 %d 个角色的详细设定。`, outlineText, count) +请从剧本中提取并整理最多 %d 个主要角色的详细设定。`, outlineText, count) temperature := req.Temperature if temperature == 0 { diff --git a/application/services/video_generation_service.go b/application/services/video_generation_service.go index 92e8c7a..ac99469 100644 --- a/application/services/video_generation_service.go +++ b/application/services/video_generation_service.go @@ -414,24 +414,33 @@ func (s *VideoGenerationService) getVideoClient(provider string, modelName strin // 使用配置中的信息创建客户端 baseURL := config.BaseURL apiKey := config.APIKey - endpoint := config.Endpoint - queryEndpoint := config.QueryEndpoint model := modelName if model == "" && len(config.Model) > 0 { model = config.Model[0] } + // 根据 provider 自动设置默认端点 + var endpoint string + var queryEndpoint string + switch provider { - case "doubao": + case "chatfire": + endpoint = "/video/generations" + queryEndpoint = "/v1/video/task/{taskId}" + return video.NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil + case "doubao", "volcengine", "volces": + endpoint = "/contents/generations/tasks" + queryEndpoint = "/generations/tasks/{taskId}" return video.NewVolcesArkClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil + case "openai": + // OpenAI Sora 使用 /v1/videos 端点 + return video.NewOpenAISoraClient(baseURL, apiKey, model), nil case "runway": return video.NewRunwayClient(baseURL, apiKey, model), nil case "pika": return video.NewPikaClient(baseURL, apiKey, model), nil case "minimax": return video.NewMinimaxClient(baseURL, apiKey, model), nil - case "openai": - return video.NewOpenAISoraClient(baseURL, apiKey, model), nil default: return nil, fmt.Errorf("unsupported video provider: %s", provider) } diff --git a/application/services/video_merge_service.go b/application/services/video_merge_service.go index 0cabcf0..3b8b21e 100644 --- a/application/services/video_merge_service.go +++ b/application/services/video_merge_service.go @@ -297,6 +297,10 @@ func (s *VideoMergeService) getVideoClient(provider string) (video.VideoClient, model = config.Model[0] } + // 根据 provider 自动设置默认端点 + var endpoint string + var queryEndpoint string + switch provider { case "runway": return video.NewRunwayClient(config.BaseURL, config.APIKey, model), nil @@ -306,10 +310,18 @@ func (s *VideoMergeService) getVideoClient(provider string) (video.VideoClient, return video.NewOpenAISoraClient(config.BaseURL, config.APIKey, model), nil case "minimax": return video.NewMinimaxClient(config.BaseURL, config.APIKey, model), nil + case "chatfire": + endpoint = "/video/generations" + queryEndpoint = "/v1/video/task/{taskId}" + return video.NewChatfireClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil case "doubao", "volces", "ark": - return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, config.Endpoint, config.QueryEndpoint), nil + endpoint = "/contents/generations/tasks" + queryEndpoint = "/generations/tasks/{taskId}" + return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil default: - return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, config.Endpoint, config.QueryEndpoint), nil + endpoint = "/contents/generations/tasks" + queryEndpoint = "/generations/tasks/{taskId}" + return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil } } diff --git a/docker-compose.yml b/docker-compose.yml index 30c2c0b..b7ecdf2 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -8,10 +8,12 @@ services: ports: - "5678:5678" volumes: - # 持久化数据目录(使用命名卷) + # 持久化数据目录(使用命名卷,容器内以 root 运行) - huobao-data:/app/data # 挂载配置文件(可选,如需自定义配置请取消注释) # - ./configs/config.yaml:/app/configs/config.yaml:ro + # 注意:如果使用本地目录挂载,需要确保目录权限正确 + # 例如:- ./data:/app/data (需要 chmod 777 ./data) environment: - TZ=Asia/Shanghai restart: unless-stopped diff --git a/domain/models/ai_config.go b/domain/models/ai_config.go index 2aade18..3236b6a 100644 --- a/domain/models/ai_config.go +++ b/domain/models/ai_config.go @@ -10,6 +10,7 @@ import ( type AIServiceConfig struct { ID uint `gorm:"primaryKey;autoIncrement" json:"id"` ServiceType string `gorm:"type:varchar(50);not null" json:"service_type"` // text, image, video + Provider string `gorm:"type:varchar(50)" json:"provider"` // openai, gemini, volcengine, etc. Name string `gorm:"type:varchar(100);not null" json:"name"` BaseURL string `gorm:"type:varchar(255);not null" json:"base_url"` APIKey string `gorm:"type:varchar(255);not null" json:"api_key"` diff --git a/infrastructure/database/custom_logger.go b/infrastructure/database/custom_logger.go new file mode 100644 index 0000000..7496e1b --- /dev/null +++ b/infrastructure/database/custom_logger.go @@ -0,0 +1,103 @@ +package database + +import ( + "context" + "strings" + "time" + + "gorm.io/gorm/logger" +) + +// CustomLogger 自定义 GORM logger,截断过长的 SQL 参数(如 base64 数据) +type CustomLogger struct { + logger.Interface +} + +// NewCustomLogger 创建自定义 logger +func NewCustomLogger() logger.Interface { + return &CustomLogger{ + Interface: logger.Default.LogMode(logger.Silent), + } +} + +// Trace 重写 Trace 方法,禁用 SQL 日志输出 +func (l *CustomLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + // 不输出任何 SQL 日志 + // 如果需要调试,可以临时取消注释下面的代码 + /* + sql, rows := fc() + sql = truncateLongValues(sql) + elapsed := time.Since(begin) + if err != nil { + l.Interface.Error(ctx, "SQL error: %v [%v] %s", err, elapsed, sql) + } else { + l.Interface.Info(ctx, "[%.3fms] [rows:%d] %s", float64(elapsed.Nanoseconds())/1e6, rows, sql) + } + */ +} + +// truncateLongValues 截断 SQL 中的长字符串值 +func truncateLongValues(sql string) string { + // 查找 base64 格式的数据 (data:image/...;base64,...) + if strings.Contains(sql, "data:image/") && strings.Contains(sql, ";base64,") { + parts := strings.Split(sql, "\"") + for i, part := range parts { + if strings.HasPrefix(part, "data:image/") && strings.Contains(part, ";base64,") { + if len(part) > 100 { + // 保留前50字符,添加截断标记 + parts[i] = part[:50] + "...[base64 data truncated]" + } + } + } + sql = strings.Join(parts, "\"") + } + + // 截断其他过长的值 + if len(sql) > 5000 { + // 查找 VALUES 或 SET 后的内容 + if idx := strings.Index(sql, " VALUES "); idx > 0 && len(sql) > idx+5000 { + sql = sql[:idx+5000] + "...[truncated]" + } else if idx := strings.Index(sql, " SET "); idx > 0 && len(sql) > idx+3000 { + sql = sql[:idx+3000] + "...[truncated]" + } else if len(sql) > 5000 { + sql = sql[:5000] + "...[truncated]" + } + } + + return sql +} + +// Info 实现 Info 方法 +func (l *CustomLogger) Info(ctx context.Context, msg string, data ...interface{}) { + l.Interface.Info(ctx, msg, data...) +} + +// Warn 实现 Warn 方法 +func (l *CustomLogger) Warn(ctx context.Context, msg string, data ...interface{}) { + l.Interface.Warn(ctx, msg, data...) +} + +// Error 实现 Error 方法 +func (l *CustomLogger) Error(ctx context.Context, msg string, data ...interface{}) { + // 检查并截断 data 中的长字符串 + truncatedData := make([]interface{}, len(data)) + for i, d := range data { + if str, ok := d.(string); ok && len(str) > 200 { + if strings.HasPrefix(str, "data:image/") { + truncatedData[i] = str[:50] + "...[base64 data]" + } else { + truncatedData[i] = str[:200] + "..." + } + } else { + truncatedData[i] = d + } + } + l.Interface.Error(ctx, msg, truncatedData...) +} + +// LogMode 实现 LogMode 方法 +func (l *CustomLogger) LogMode(level logger.LogLevel) logger.Interface { + newLogger := *l + newLogger.Interface = l.Interface.LogMode(level) + return &newLogger +} diff --git a/infrastructure/database/database.go b/infrastructure/database/database.go index 7655c02..bd29d06 100644 --- a/infrastructure/database/database.go +++ b/infrastructure/database/database.go @@ -11,7 +11,6 @@ import ( "gorm.io/driver/mysql" "gorm.io/driver/sqlite" "gorm.io/gorm" - "gorm.io/gorm/logger" ) func NewDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) { @@ -25,7 +24,7 @@ func NewDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) { } gormConfig := &gorm.Config{ - Logger: logger.Default.LogMode(logger.Info), + Logger: NewCustomLogger(), } var db *gorm.DB diff --git a/migrations/init.sql b/migrations/init.sql index e27b564..3662adb 100644 --- a/migrations/init.sql +++ b/migrations/init.sql @@ -445,12 +445,14 @@ CREATE INDEX IF NOT EXISTS idx_asset_collection_relations_collection_id ON asset CREATE TABLE IF NOT EXISTS ai_service_configs ( id INTEGER PRIMARY KEY AUTOINCREMENT, service_type TEXT NOT NULL, -- text, image, video + provider TEXT, -- openai, gemini, volcengine, etc. name TEXT NOT NULL, base_url TEXT NOT NULL, api_key TEXT NOT NULL, model TEXT, endpoint TEXT, query_endpoint TEXT, + priority INTEGER NOT NULL DEFAULT 0, is_default INTEGER NOT NULL DEFAULT 0, is_active INTEGER NOT NULL DEFAULT 1, settings TEXT, -- JSON存储 @@ -489,7 +491,8 @@ INSERT OR IGNORE INTO ai_service_providers (name, display_name, service_type, de ('openai-dalle', 'OpenAI DALL-E', 'image', 'https://api.openai.com/v1', 'OpenAI DALL-E图片生成'), ('openai-sora', 'OpenAI Sora', 'video', 'https://api.openai.com/v1', 'OpenAI Sora视频生成'), ('midjourney', 'Midjourney', 'image', '', 'Midjourney图片生成'), -('stable-diffusion', 'Stable Diffusion', 'image', '', 'Stable Diffusion图片生成'), +('doubao-image', '豆包(火山引擎)', 'image', 'https://ark.cn-beijing.volces.com', '火山引擎豆包图片生成'), +('gemini-image', 'Google Gemini', 'image', 'https://generativelanguage.googleapis.com', 'Google Gemini原生图片生成(base64)'), ('runway', 'Runway', 'video', '', 'Runway视频生成'), ('pika', 'Pika Labs', 'video', '', 'Pika视频生成'), ('doubao', '豆包(火山引擎)', 'video', 'https://ark.cn-beijing.volces.com', '火山引擎豆包视频生成'), diff --git a/pkg/ai/client.go b/pkg/ai/client.go new file mode 100644 index 0000000..81da6ed --- /dev/null +++ b/pkg/ai/client.go @@ -0,0 +1,7 @@ +package ai + +// AIClient 定义文本生成客户端接口 +type AIClient interface { + GenerateText(prompt string, systemPrompt string, options ...func(*ChatCompletionRequest)) (string, error) + TestConnection() error +} diff --git a/pkg/ai/gemini_client.go b/pkg/ai/gemini_client.go new file mode 100644 index 0000000..348387f --- /dev/null +++ b/pkg/ai/gemini_client.go @@ -0,0 +1,195 @@ +package ai + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +type GeminiClient struct { + BaseURL string + APIKey string + Model string + Endpoint string + HTTPClient *http.Client +} + +type GeminiTextRequest struct { + Contents []GeminiContent `json:"contents"` + SystemInstruction *GeminiInstruction `json:"systemInstruction,omitempty"` +} + +type GeminiContent struct { + Parts []GeminiPart `json:"parts"` + Role string `json:"role,omitempty"` +} + +type GeminiPart struct { + Text string `json:"text"` +} + +type GeminiInstruction struct { + Parts []GeminiPart `json:"parts"` +} + +type GeminiTextResponse struct { + Candidates []struct { + Content struct { + Parts []struct { + Text string `json:"text"` + } `json:"parts"` + Role string `json:"role"` + } `json:"content"` + FinishReason string `json:"finishReason"` + Index int `json:"index"` + SafetyRatings []struct { + Category string `json:"category"` + Probability string `json:"probability"` + } `json:"safetyRatings"` + } `json:"candidates"` + UsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + } `json:"usageMetadata"` +} + +func NewGeminiClient(baseURL, apiKey, model, endpoint string) *GeminiClient { + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + if endpoint == "" { + endpoint = "/v1beta/models/{model}:generateContent" + } + if model == "" { + model = "gemini-3-pro" + } + return &GeminiClient{ + BaseURL: baseURL, + APIKey: apiKey, + Model: model, + Endpoint: endpoint, + HTTPClient: &http.Client{ + Timeout: 10 * time.Minute, + }, + } +} + +func (c *GeminiClient) GenerateText(prompt string, systemPrompt string, options ...func(*ChatCompletionRequest)) (string, error) { + model := c.Model + + // 构建请求体 + reqBody := GeminiTextRequest{ + Contents: []GeminiContent{ + { + Parts: []GeminiPart{{Text: prompt}}, + Role: "user", + }, + }, + } + + // 使用 systemInstruction 字段处理系统提示 + if systemPrompt != "" { + reqBody.SystemInstruction = &GeminiInstruction{ + Parts: []GeminiPart{{Text: systemPrompt}}, + } + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + fmt.Printf("Gemini: Failed to marshal request: %v\n", err) + return "", fmt.Errorf("marshal request: %w", err) + } + + // 替换端点中的 {model} 占位符 + endpoint := c.BaseURL + c.Endpoint + endpoint = strings.ReplaceAll(endpoint, "{model}", model) + url := fmt.Sprintf("%s?key=%s", endpoint, c.APIKey) + + // 打印请求信息(隐藏 API Key) + safeURL := strings.Replace(url, c.APIKey, "***", 1) + fmt.Printf("Gemini: Sending request to: %s\n", safeURL) + requestPreview := string(jsonData) + if len(jsonData) > 300 { + requestPreview = string(jsonData[:300]) + "..." + } + fmt.Printf("Gemini: Request body: %s\n", requestPreview) + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + fmt.Printf("Gemini: Failed to create request: %v\n", err) + return "", fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + fmt.Printf("Gemini: Executing HTTP request...\n") + resp, err := c.HTTPClient.Do(req) + if err != nil { + fmt.Printf("Gemini: HTTP request failed: %v\n", err) + return "", fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + fmt.Printf("Gemini: Received response with status: %d\n", resp.StatusCode) + + body, err := io.ReadAll(resp.Body) + if err != nil { + fmt.Printf("Gemini: Failed to read response body: %v\n", err) + return "", fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + fmt.Printf("Gemini: API error (status %d): %s\n", resp.StatusCode, string(body)) + return "", fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + // 打印响应体用于调试 + bodyPreview := string(body) + if len(body) > 500 { + bodyPreview = string(body[:500]) + "..." + } + fmt.Printf("Gemini: Response body: %s\n", bodyPreview) + + var result GeminiTextResponse + if err := json.Unmarshal(body, &result); err != nil { + errorPreview := string(body) + if len(body) > 200 { + errorPreview = string(body[:200]) + } + fmt.Printf("Gemini: Failed to parse response: %v\n", err) + return "", fmt.Errorf("parse response: %w, body preview: %s", err, errorPreview) + } + + fmt.Printf("Gemini: Successfully parsed response, candidates count: %d\n", len(result.Candidates)) + + if len(result.Candidates) == 0 { + fmt.Printf("Gemini: No candidates in response\n") + return "", fmt.Errorf("no candidates in response") + } + + if len(result.Candidates[0].Content.Parts) == 0 { + fmt.Printf("Gemini: No parts in first candidate\n") + return "", fmt.Errorf("no parts in response") + } + + responseText := result.Candidates[0].Content.Parts[0].Text + fmt.Printf("Gemini: Generated text: %s\n", responseText) + + return responseText, nil +} + +func (c *GeminiClient) TestConnection() error { + fmt.Printf("Gemini: TestConnection called with BaseURL=%s, Model=%s, Endpoint=%s\n", c.BaseURL, c.Model, c.Endpoint) + _, err := c.GenerateText("Hello", "") + if err != nil { + fmt.Printf("Gemini: TestConnection failed: %v\n", err) + } else { + fmt.Printf("Gemini: TestConnection succeeded\n") + } + return err +} diff --git a/pkg/ai/openai_client.go b/pkg/ai/openai_client.go index 987b0ef..1264e1d 100644 --- a/pkg/ai/openai_client.go +++ b/pkg/ai/openai_client.go @@ -91,30 +91,48 @@ func (c *OpenAIClient) ChatCompletion(messages []ChatMessage, options ...func(*C func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) { jsonData, err := json.Marshal(req) if err != nil { + fmt.Printf("OpenAI: Failed to marshal request: %v\n", err) return nil, fmt.Errorf("failed to marshal request: %w", err) } url := c.BaseURL + c.Endpoint + + // 打印请求信息 + fmt.Printf("OpenAI: Sending request to: %s\n", url) + fmt.Printf("OpenAI: BaseURL=%s, Endpoint=%s, Model=%s\n", c.BaseURL, c.Endpoint, c.Model) + requestPreview := string(jsonData) + if len(jsonData) > 300 { + requestPreview = string(jsonData[:300]) + "..." + } + fmt.Printf("OpenAI: Request body: %s\n", requestPreview) + httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) if err != nil { + fmt.Printf("OpenAI: Failed to create request: %v\n", err) return nil, fmt.Errorf("failed to create request: %w", err) } httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+c.APIKey) + fmt.Printf("OpenAI: Executing HTTP request...\n") resp, err := c.HTTPClient.Do(httpReq) if err != nil { + fmt.Printf("OpenAI: HTTP request failed: %v\n", err) return nil, fmt.Errorf("failed to send request: %w", err) } defer resp.Body.Close() + fmt.Printf("OpenAI: Received response with status: %d\n", resp.StatusCode) + body, err := io.ReadAll(resp.Body) if err != nil { + fmt.Printf("OpenAI: Failed to read response body: %v\n", err) return nil, fmt.Errorf("failed to read response: %w", err) } if resp.StatusCode != http.StatusOK { + fmt.Printf("OpenAI: API error (status %d): %s\n", resp.StatusCode, string(body)) var errResp ErrorResponse if err := json.Unmarshal(body, &errResp); err != nil { return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) @@ -122,11 +140,25 @@ func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatComplet return nil, fmt.Errorf("API error: %s", errResp.Error.Message) } + // 打印响应体用于调试 + bodyPreview := string(body) + if len(body) > 500 { + bodyPreview = string(body[:500]) + "..." + } + fmt.Printf("OpenAI: Response body: %s\n", bodyPreview) + var chatResp ChatCompletionResponse if err := json.Unmarshal(body, &chatResp); err != nil { - return nil, fmt.Errorf("failed to unmarshal response: %w", err) + errorPreview := string(body) + if len(body) > 200 { + errorPreview = string(body[:200]) + } + fmt.Printf("OpenAI: Failed to parse response: %v\n", err) + return nil, fmt.Errorf("failed to unmarshal response: %w, body preview: %s", err, errorPreview) } + fmt.Printf("OpenAI: Successfully parsed response, choices count: %d\n", len(chatResp.Choices)) + return &chatResp, nil } @@ -176,6 +208,8 @@ func (c *OpenAIClient) GenerateText(prompt string, systemPrompt string, options } func (c *OpenAIClient) TestConnection() error { + fmt.Printf("OpenAI: TestConnection called with BaseURL=%s, Endpoint=%s, Model=%s\n", c.BaseURL, c.Endpoint, c.Model) + messages := []ChatMessage{ { Role: "user", @@ -184,5 +218,10 @@ func (c *OpenAIClient) TestConnection() error { } _, err := c.ChatCompletion(messages, WithMaxTokens(10)) + if err != nil { + fmt.Printf("OpenAI: TestConnection failed: %v\n", err) + } else { + fmt.Printf("OpenAI: TestConnection succeeded\n") + } return err } diff --git a/pkg/image/gemini_image_client.go b/pkg/image/gemini_image_client.go new file mode 100644 index 0000000..3ffa427 --- /dev/null +++ b/pkg/image/gemini_image_client.go @@ -0,0 +1,277 @@ +package image + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +type GeminiImageClient struct { + BaseURL string + APIKey string + Model string + Endpoint string + HTTPClient *http.Client +} + +type GeminiImageRequest struct { + Contents []struct { + Parts []GeminiPart `json:"parts"` + } `json:"contents"` + GenerationConfig struct { + ResponseModalities []string `json:"responseModalities"` + } `json:"generationConfig"` +} + +type GeminiPart struct { + Text string `json:"text,omitempty"` + InlineData *GeminiInlineData `json:"inlineData,omitempty"` +} + +type GeminiInlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` // base64 编码的图片数据 +} + +type GeminiImageResponse struct { + Candidates []struct { + Content struct { + Parts []struct { + InlineData struct { + MimeType string `json:"mimeType"` + Data string `json:"data"` + } `json:"inlineData,omitempty"` + Text string `json:"text,omitempty"` + } `json:"parts"` + } `json:"content"` + } `json:"candidates"` + UsageMetadata struct { + PromptTokenCount int `json:"promptTokenCount"` + CandidatesTokenCount int `json:"candidatesTokenCount"` + TotalTokenCount int `json:"totalTokenCount"` + } `json:"usageMetadata"` +} + +// downloadImageToBase64 下载图片 URL 并转换为 base64 +func downloadImageToBase64(imageURL string) (string, string, error) { + resp, err := http.Get(imageURL) + if err != nil { + return "", "", fmt.Errorf("download image: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", "", fmt.Errorf("download image failed with status: %d", resp.StatusCode) + } + + imageData, err := io.ReadAll(resp.Body) + if err != nil { + return "", "", fmt.Errorf("read image data: %w", err) + } + + // 根据 Content-Type 确定 mimeType + mimeType := resp.Header.Get("Content-Type") + if mimeType == "" { + mimeType = "image/jpeg" + } + + base64Data := base64.StdEncoding.EncodeToString(imageData) + return base64Data, mimeType, nil +} + +func NewGeminiImageClient(baseURL, apiKey, model, endpoint string) *GeminiImageClient { + if baseURL == "" { + baseURL = "https://generativelanguage.googleapis.com" + } + if endpoint == "" { + endpoint = "/v1beta/models/{model}:generateContent" + } + if model == "" { + model = "gemini-3-pro-image-preview" + } + return &GeminiImageClient{ + BaseURL: baseURL, + APIKey: apiKey, + Model: model, + Endpoint: endpoint, + HTTPClient: &http.Client{ + Timeout: 10 * time.Minute, + }, + } +} + +func (c *GeminiImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) { + options := &ImageOptions{ + Size: "1024x1024", + Quality: "standard", + } + + for _, opt := range opts { + opt(options) + } + + model := c.Model + if options.Model != "" { + model = options.Model + } + + promptText := prompt + if options.NegativePrompt != "" { + promptText += fmt.Sprintf("\n\nNegative prompt: %s", options.NegativePrompt) + } + if options.Size != "" { + promptText += fmt.Sprintf("\n\nImage size: %s", options.Size) + } + + // 构建请求的 parts,支持参考图 + parts := []GeminiPart{} + + // 如果有参考图,先添加参考图 + if len(options.ReferenceImages) > 0 { + for _, refImg := range options.ReferenceImages { + var base64Data string + var mimeType string + var err error + + // 检查是否是 HTTP/HTTPS URL + if strings.HasPrefix(refImg, "http://") || strings.HasPrefix(refImg, "https://") { + // 下载图片并转换为 base64 + base64Data, mimeType, err = downloadImageToBase64(refImg) + if err != nil { + continue + } + } else if strings.HasPrefix(refImg, "data:") { + // 如果是 data URI 格式,需要解析 + // 格式: data:image/jpeg;base64,xxxxx + mimeType = "image/jpeg" + parts := []byte(refImg) + for i := 0; i < len(parts); i++ { + if parts[i] == ',' { + base64Data = refImg[i+1:] + // 提取 mime type + if i > 11 { + mimeTypeEnd := i + for j := 5; j < i; j++ { + if parts[j] == ';' { + mimeTypeEnd = j + break + } + } + mimeType = refImg[5:mimeTypeEnd] + } + break + } + } + } else { + // 假设已经是 base64 编码 + base64Data = refImg + mimeType = "image/jpeg" + } + + if base64Data != "" { + parts = append(parts, GeminiPart{ + InlineData: &GeminiInlineData{ + MimeType: mimeType, + Data: base64Data, + }, + }) + } + } + } + + // 添加文本提示词 + parts = append(parts, GeminiPart{ + Text: promptText, + }) + + reqBody := GeminiImageRequest{ + Contents: []struct { + Parts []GeminiPart `json:"parts"` + }{ + { + Parts: parts, + }, + }, + GenerationConfig: struct { + ResponseModalities []string `json:"responseModalities"` + }{ + ResponseModalities: []string{"IMAGE"}, + }, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + endpoint := c.BaseURL + c.Endpoint + endpoint = replaceModelPlaceholder(endpoint, model) + url := fmt.Sprintf("%s?key=%s", endpoint, c.APIKey) + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + bodyStr := string(body) + if len(bodyStr) > 1000 { + bodyStr = fmt.Sprintf("%s ... %s", bodyStr[:500], bodyStr[len(bodyStr)-500:]) + } + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, bodyStr) + } + + var result GeminiImageResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("parse response: %w", err) + } + + if len(result.Candidates) == 0 || len(result.Candidates[0].Content.Parts) == 0 { + return nil, fmt.Errorf("no image generated in response") + } + + base64Data := result.Candidates[0].Content.Parts[0].InlineData.Data + if base64Data == "" { + return nil, fmt.Errorf("no base64 image data in response") + } + + dataURI := fmt.Sprintf("data:image/jpeg;base64,%s", base64Data) + + return &ImageResult{ + Status: "completed", + ImageURL: dataURI, + Completed: true, + Width: 1024, + Height: 1024, + }, nil +} + +func (c *GeminiImageClient) GetTaskStatus(taskID string) (*ImageResult, error) { + return nil, fmt.Errorf("not supported for Gemini (synchronous generation)") +} + +func replaceModelPlaceholder(endpoint, model string) string { + result := endpoint + if bytes.Contains([]byte(result), []byte("{model}")) { + result = string(bytes.ReplaceAll([]byte(result), []byte("{model}"), []byte(model))) + } + return result +} diff --git a/pkg/image/image_client.go b/pkg/image/image_client.go index 8d575e5..ae5f2f6 100644 --- a/pkg/image/image_client.go +++ b/pkg/image/image_client.go @@ -1,14 +1,5 @@ package image -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "time" -) - type ImageClient interface { GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) GetTaskStatus(taskID string) (*ImageResult, error) @@ -100,285 +91,3 @@ func WithReferenceImages(images []string) ImageOption { o.ReferenceImages = images } } - -type OpenAIImageClient struct { - BaseURL string - APIKey string - Model string - HTTPClient *http.Client -} - -type DALLERequest struct { - Model string `json:"model"` - Prompt string `json:"prompt"` - Size string `json:"size,omitempty"` - Quality string `json:"quality,omitempty"` - N int `json:"n"` - Image []string `json:"image,omitempty"` // 参考图片URL列表 -} - -type DALLEResponse struct { - Created int64 `json:"created"` - Data []struct { - URL string `json:"url"` - RevisedPrompt string `json:"revised_prompt,omitempty"` - } `json:"data"` -} - -func NewOpenAIImageClient(baseURL, apiKey, model string) *OpenAIImageClient { - return &OpenAIImageClient{ - BaseURL: baseURL, - APIKey: apiKey, - Model: model, - HTTPClient: &http.Client{ - Timeout: 10 * time.Minute, - }, - } -} - -func (c *OpenAIImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) { - options := &ImageOptions{ - Size: "1920x1920", - Quality: "standard", - } - - for _, opt := range opts { - opt(options) - } - - model := c.Model - if options.Model != "" { - model = options.Model - } - - reqBody := DALLERequest{ - Model: model, - Prompt: prompt, - Size: options.Size, - Quality: options.Quality, - N: 1, - Image: options.ReferenceImages, - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("marshal request: %w", err) - } - - endpoint := c.BaseURL + "/v1/images/generations" - req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData)) - if err != nil { - return nil, fmt.Errorf("create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+c.APIKey) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("send request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) - } - - // 打印原始响应以便调试 - fmt.Printf("OpenAI API Response: %s\n", string(body)) - - var result DALLEResponse - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body)) - } - - if len(result.Data) == 0 { - return nil, fmt.Errorf("no image generated, response: %s", string(body)) - } - - return &ImageResult{ - Status: "completed", - ImageURL: result.Data[0].URL, - Completed: true, - }, nil -} - -func (c *OpenAIImageClient) GetTaskStatus(taskID string) (*ImageResult, error) { - return nil, fmt.Errorf("not supported for OpenAI/DALL-E") -} - -type StableDiffusionClient struct { - BaseURL string - APIKey string - Model string - HTTPClient *http.Client -} - -type SDRequest struct { - Prompt string `json:"prompt"` - NegativePrompt string `json:"negative_prompt,omitempty"` - Model string `json:"model,omitempty"` - Width int `json:"width,omitempty"` - Height int `json:"height,omitempty"` - Steps int `json:"steps,omitempty"` - CfgScale float64 `json:"cfg_scale,omitempty"` - Seed int64 `json:"seed,omitempty"` - Samples int `json:"samples"` - Image []string `json:"image,omitempty"` // 参考图片URL列表 -} - -type SDResponse struct { - Status string `json:"status"` - TaskID string `json:"task_id,omitempty"` - Output []struct { - URL string `json:"url"` - } `json:"output,omitempty"` - Error string `json:"error,omitempty"` -} - -func NewStableDiffusionClient(baseURL, apiKey, model string) *StableDiffusionClient { - return &StableDiffusionClient{ - BaseURL: baseURL, - APIKey: apiKey, - Model: model, - HTTPClient: &http.Client{ - Timeout: 10 * time.Minute, - }, - } -} - -func (c *StableDiffusionClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) { - options := &ImageOptions{ - Width: 1024, - Height: 1024, - Steps: 30, - CfgScale: 7.5, - } - - for _, opt := range opts { - opt(options) - } - - model := c.Model - if options.Model != "" { - model = options.Model - } - - reqBody := SDRequest{ - Prompt: prompt, - NegativePrompt: options.NegativePrompt, - Model: model, - Width: options.Width, - Height: options.Height, - Steps: options.Steps, - CfgScale: options.CfgScale, - Seed: options.Seed, - Samples: 1, - Image: options.ReferenceImages, - } - - jsonData, err := json.Marshal(reqBody) - if err != nil { - return nil, fmt.Errorf("marshal request: %w", err) - } - - endpoint := c.BaseURL + "/v1/images/generations" - req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData)) - if err != nil { - return nil, fmt.Errorf("create request: %w", err) - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("Authorization", "Bearer "+c.APIKey) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("send request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read response: %w", err) - } - - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) - } - - var result SDResponse - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("parse response: %w", err) - } - - if result.Error != "" { - return nil, fmt.Errorf("SD error: %s", result.Error) - } - - if result.Status == "processing" { - return &ImageResult{ - TaskID: result.TaskID, - Status: "processing", - Completed: false, - }, nil - } - - if len(result.Output) == 0 { - return nil, fmt.Errorf("no image generated") - } - - return &ImageResult{ - Status: "completed", - ImageURL: result.Output[0].URL, - Width: options.Width, - Height: options.Height, - Completed: true, - }, nil -} - -func (c *StableDiffusionClient) GetTaskStatus(taskID string) (*ImageResult, error) { - endpoint := c.BaseURL + "/v1/images/status/" + taskID - req, err := http.NewRequest("GET", endpoint, nil) - if err != nil { - return nil, fmt.Errorf("create request: %w", err) - } - - req.Header.Set("Authorization", "Bearer "+c.APIKey) - - resp, err := c.HTTPClient.Do(req) - if err != nil { - return nil, fmt.Errorf("send request: %w", err) - } - defer resp.Body.Close() - - body, err := io.ReadAll(resp.Body) - if err != nil { - return nil, fmt.Errorf("read response: %w", err) - } - - var result SDResponse - if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("parse response: %w", err) - } - - imageResult := &ImageResult{ - TaskID: taskID, - Status: result.Status, - Completed: result.Status == "completed", - } - - if result.Error != "" { - imageResult.Error = result.Error - } - - if len(result.Output) > 0 { - imageResult.ImageURL = result.Output[0].URL - } - - return imageResult, nil -} diff --git a/pkg/image/openai_image_client.go b/pkg/image/openai_image_client.go new file mode 100644 index 0000000..c238113 --- /dev/null +++ b/pkg/image/openai_image_client.go @@ -0,0 +1,128 @@ +package image + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +type OpenAIImageClient struct { + BaseURL string + APIKey string + Model string + Endpoint string + HTTPClient *http.Client +} + +type DALLERequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Size string `json:"size,omitempty"` + Quality string `json:"quality,omitempty"` + N int `json:"n"` + Image []string `json:"image,omitempty"` +} + +type DALLEResponse struct { + Created int64 `json:"created"` + Data []struct { + URL string `json:"url"` + RevisedPrompt string `json:"revised_prompt,omitempty"` + } `json:"data"` +} + +func NewOpenAIImageClient(baseURL, apiKey, model, endpoint string) *OpenAIImageClient { + if endpoint == "" { + endpoint = "/v1/images/generations" + } + return &OpenAIImageClient{ + BaseURL: baseURL, + APIKey: apiKey, + Model: model, + Endpoint: endpoint, + HTTPClient: &http.Client{ + Timeout: 10 * time.Minute, + }, + } +} + +func (c *OpenAIImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) { + options := &ImageOptions{ + Size: "1920x1920", + Quality: "standard", + } + + for _, opt := range opts { + opt(options) + } + + model := c.Model + if options.Model != "" { + model = options.Model + } + + reqBody := DALLERequest{ + Model: model, + Prompt: prompt, + Size: options.Size, + Quality: options.Quality, + N: 1, + Image: options.ReferenceImages, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + url := c.BaseURL + c.Endpoint + fmt.Printf("[OpenAI Image] Request URL: %s\n", url) + fmt.Printf("[OpenAI Image] Request Body: %s\n", string(jsonData)) + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + fmt.Printf("OpenAI API Response: %s\n", string(body)) + + var result DALLEResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body)) + } + + if len(result.Data) == 0 { + return nil, fmt.Errorf("no image generated, response: %s", string(body)) + } + + return &ImageResult{ + Status: "completed", + ImageURL: result.Data[0].URL, + Completed: true, + }, nil +} + +func (c *OpenAIImageClient) GetTaskStatus(taskID string) (*ImageResult, error) { + return nil, fmt.Errorf("not supported for OpenAI/DALL-E") +} diff --git a/pkg/image/volcengine_image_client.go b/pkg/image/volcengine_image_client.go new file mode 100644 index 0000000..87b920d --- /dev/null +++ b/pkg/image/volcengine_image_client.go @@ -0,0 +1,158 @@ +package image + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "time" +) + +type VolcEngineImageClient struct { + BaseURL string + APIKey string + Model string + Endpoint string + QueryEndpoint string + HTTPClient *http.Client +} + +type VolcEngineImageRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Image []string `json:"image,omitempty"` + SequentialImageGeneration string `json:"sequential_image_generation,omitempty"` + Size string `json:"size,omitempty"` + Watermark bool `json:"watermark,omitempty"` +} + +type VolcEngineImageResponse struct { + Model string `json:"model"` + Created int64 `json:"created"` + Data []struct { + URL string `json:"url"` + Size string `json:"size"` + } `json:"data"` + Usage struct { + GeneratedImages int `json:"generated_images"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` + } `json:"usage"` + Error interface{} `json:"error,omitempty"` +} + +func NewVolcEngineImageClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *VolcEngineImageClient { + if endpoint == "" { + endpoint = "/api/v3/images/generations" + } + if queryEndpoint == "" { + queryEndpoint = endpoint + } + return &VolcEngineImageClient{ + BaseURL: baseURL, + APIKey: apiKey, + Model: model, + Endpoint: endpoint, + QueryEndpoint: queryEndpoint, + HTTPClient: &http.Client{ + Timeout: 10 * time.Minute, + }, + } +} + +func (c *VolcEngineImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) { + options := &ImageOptions{ + Size: "1024x1024", + Quality: "standard", + } + + for _, opt := range opts { + opt(options) + } + + model := c.Model + if options.Model != "" { + model = options.Model + } + + promptText := prompt + if options.NegativePrompt != "" { + promptText += fmt.Sprintf(". Negative: %s", options.NegativePrompt) + } + + size := options.Size + if size == "" { + if model == "doubao-seedream-4-5-251128" { + size = "2K" + } else { + size = "1K" + } + } + + reqBody := VolcEngineImageRequest{ + Model: model, + Prompt: promptText, + Image: options.ReferenceImages, + SequentialImageGeneration: "disabled", + Size: size, + Watermark: false, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + url := c.BaseURL + c.Endpoint + fmt.Printf("[VolcEngine Image] Request URL: %s\n", url) + fmt.Printf("[VolcEngine Image] Request Body: %s\n", string(jsonData)) + + req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + fmt.Printf("VolcEngine Image API Response: %s\n", string(body)) + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + var result VolcEngineImageResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("parse response: %w", err) + } + + if result.Error != nil { + return nil, fmt.Errorf("volcengine error: %v", result.Error) + } + + if len(result.Data) == 0 { + return nil, fmt.Errorf("no image generated") + } + + return &ImageResult{ + Status: "completed", + ImageURL: result.Data[0].URL, + Completed: true, + }, nil +} + +func (c *VolcEngineImageClient) GetTaskStatus(taskID string) (*ImageResult, error) { + return nil, fmt.Errorf("not supported for VolcEngine Seedream (synchronous generation)") +} diff --git a/pkg/video/chatfire_client.go b/pkg/video/chatfire_client.go new file mode 100644 index 0000000..f9b6a91 --- /dev/null +++ b/pkg/video/chatfire_client.go @@ -0,0 +1,184 @@ +package video + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// ChatfireClient Chatfire 视频生成客户端 +type ChatfireClient struct { + BaseURL string + APIKey string + Model string + Endpoint string + QueryEndpoint string + HTTPClient *http.Client +} + +type ChatfireRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + ImageURL string `json:"image_url,omitempty"` + Duration int `json:"duration,omitempty"` + Size string `json:"size,omitempty"` +} + +type ChatfireResponse struct { + TaskID string `json:"task_id"` + Status string `json:"status"` + Error string `json:"error,omitempty"` +} + +type ChatfireTaskResponse struct { + TaskID string `json:"task_id"` + Status string `json:"status"` + VideoURL string `json:"video_url,omitempty"` + Error string `json:"error,omitempty"` +} + +func NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *ChatfireClient { + if endpoint == "" { + endpoint = "/video/generations" + } + if queryEndpoint == "" { + queryEndpoint = "/v1/video/task/{taskId}" + } + return &ChatfireClient{ + BaseURL: baseURL, + APIKey: apiKey, + Model: model, + Endpoint: endpoint, + QueryEndpoint: queryEndpoint, + HTTPClient: &http.Client{ + Timeout: 300 * time.Second, + }, + } +} + +func (c *ChatfireClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) { + options := &VideoOptions{ + Duration: 5, + AspectRatio: "16:9", + } + + for _, opt := range opts { + opt(options) + } + + model := c.Model + if options.Model != "" { + model = options.Model + } + + reqBody := ChatfireRequest{ + Model: model, + Prompt: prompt, + ImageURL: imageURL, + Duration: options.Duration, + Size: options.AspectRatio, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return nil, fmt.Errorf("marshal request: %w", err) + } + + endpoint := c.BaseURL + c.Endpoint + req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData)) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+c.APIKey) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated { + return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) + } + + var result ChatfireResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("parse response: %w", err) + } + + if result.Error != "" { + return nil, fmt.Errorf("chatfire error: %s", result.Error) + } + + videoResult := &VideoResult{ + TaskID: result.TaskID, + Status: result.Status, + Completed: result.Status == "completed" || result.Status == "succeeded", + Duration: options.Duration, + } + + return videoResult, nil +} + +func (c *ChatfireClient) GetTaskStatus(taskID string) (*VideoResult, error) { + queryPath := c.QueryEndpoint + if strings.Contains(queryPath, "{taskId}") { + queryPath = strings.ReplaceAll(queryPath, "{taskId}", taskID) + } else if strings.Contains(queryPath, "{task_id}") { + queryPath = strings.ReplaceAll(queryPath, "{task_id}", taskID) + } else { + queryPath = queryPath + "/" + taskID + } + + endpoint := c.BaseURL + queryPath + req, err := http.NewRequest("GET", endpoint, nil) + if err != nil { + return nil, fmt.Errorf("create request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+c.APIKey) + + resp, err := c.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("send request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read response: %w", err) + } + + var result ChatfireTaskResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, fmt.Errorf("parse response: %w", err) + } + + videoResult := &VideoResult{ + TaskID: result.TaskID, + Status: result.Status, + Completed: result.Status == "completed" || result.Status == "succeeded", + } + + if result.Error != "" { + videoResult.Error = result.Error + } + + if result.VideoURL != "" { + videoResult.VideoURL = result.VideoURL + videoResult.Completed = true + } + + return videoResult, nil +} diff --git a/pkg/video/openai_sora_client.go b/pkg/video/openai_sora_client.go index 39aae3b..772bb84 100644 --- a/pkg/video/openai_sora_client.go +++ b/pkg/video/openai_sora_client.go @@ -83,7 +83,7 @@ func (c *OpenAISoraClient) GenerateVideo(imageURL, prompt string, opts ...VideoO writer.Close() - endpoint := c.BaseURL + "/v1/videos" + endpoint := c.BaseURL + "/videos" req, err := http.NewRequest("POST", endpoint, body) if err != nil { return nil, fmt.Errorf("create request: %w", err) diff --git a/web/src/views/settings/AIConfig.vue b/web/src/views/settings/AIConfig.vue index 3333536..1ba032c 100644 --- a/web/src/views/settings/AIConfig.vue +++ b/web/src/views/settings/AIConfig.vue @@ -9,7 +9,7 @@ - + -
API 服务的基础地址
+
+ API 服务的基础地址,如 Chatfire: https://api.chatfire.site/v1,Gemini: https://generativelanguage.googleapis.com(无需 /v1) +
+ 完整调用路径: {{ fullEndpointExample }} +
@@ -127,16 +131,6 @@
您的 API 密钥
- - -
API 端点路径,默认为 /v1/chat/completions
-
- - - -
异步任务查询端点(仅视频服务需要),支持 {taskId} 占位符
-
- @@ -181,8 +175,6 @@ const form = reactive = { text: [ - { id: 'openai', name: 'OpenAI', models: ['gpt-5.2', 'gemini-3-pro-preview'], disabled: true } + { id: 'openai', name: 'OpenAI', models: ['gpt-5.2', 'gemini-3-pro-preview'] }, + { + id: 'chatfire', + name: 'Chatfire', + models: [ + 'gpt-4o', + 'claude-sonnet-4-5-20250929', + 'doubao-seed-1-8-251228', + 'kimi-k2-thinking', + 'gemini-3-pro', + 'gemini-2.5-pro', + 'gemini-3-pro-preview' + ] + }, + { + id: 'gemini', + name: 'Google Gemini', + models: [ + 'gemini-2.5-pro', + 'gemini-3-pro-preview' + ] + } ], image: [ - { id: 'openai', name: 'OpenAI', models: ['nano-banana-pro', 'doubao-seedream-4-5-251128'] } + { + id: 'volcengine', + name: '火山引擎', + models: [ + 'doubao-seedream-4-5-251128', + 'doubao-seedream-4-0-250828', + ] + }, + { + id: 'chatfire', + name: 'Chatfire', + models: [ + 'doubao-seedream-4-5-251128', + 'nano-banana-pro', + ] + }, + { + id: 'gemini', + name: 'Google Gemini', + models: [ + 'gemini-3-pro-image-preview', + ] + }, + { id: 'openai', name: 'OpenAI', models: ['dall-e-3', 'dall-e-2'] } ], video: [ { @@ -214,6 +250,19 @@ const providerConfigs: Record = { 'doubao-seedance-1-0-pro-fast-251015' ] }, + { + id: 'chatfire', + name: 'Chatfire', + models: [ + 'doubao-seedance-1-5-pro-251215', + 'doubao-seedance-1-0-lite-i2v-250428', + 'doubao-seedance-1-0-lite-t2v-250428', + 'doubao-seedance-1-0-pro-250528', + 'doubao-seedance-1-0-pro-fast-251015', + 'sora', + 'sora-pro' + ] + }, { id: 'openai', name: 'OpenAI', models: ['sora-2', 'sora-2-pro'] }, // { id: 'minimax', name: 'MiniMax', models: ['MiniMax-Hailuo-2.3', 'MiniMax-Hailuo-2.3-Fast', 'MiniMax-Hailuo-02'] } ] @@ -231,6 +280,41 @@ const availableModels = computed(() => { return provider?.models || [] }) +// 完整端点示例 +const fullEndpointExample = computed(() => { + const baseUrl = form.base_url || 'https://api.example.com' + const provider = form.provider + const serviceType = form.service_type + + let endpoint = '' + + if (serviceType === 'text') { + if (provider === 'gemini' || provider === 'google') { + endpoint = '/v1beta/models/{model}:generateContent' + } else { + endpoint = '/chat/completions' + } + } else if (serviceType === 'image') { + if (provider === 'gemini' || provider === 'google') { + endpoint = '/v1beta/models/{model}:generateContent' + } else { + endpoint = '/images/generations' + } + } else if (serviceType === 'video') { + if (provider === 'chatfire') { + endpoint = '/video/generations' + } else if (provider === 'doubao' || provider === 'volcengine' || provider === 'volces') { + endpoint = '/contents/generations/tasks' + } else if (provider === 'openai') { + endpoint = '/videos' + } else { + endpoint = '/video/generations' + } + } + + return baseUrl + endpoint +}) + const rules: FormRules = { name: [ { required: true, message: '请输入配置名称', trigger: 'blur' } @@ -274,17 +358,39 @@ const loadConfigs = async () => { } } +// 生成随机配置名称 +const generateConfigName = (provider: string, serviceType: AIServiceType): string => { + const providerNames: Record = { + 'chatfire': 'ChatFire', + 'openai': 'OpenAI', + 'gemini': 'Gemini', + 'google': 'Google' + } + + const serviceNames: Record = { + 'text': '文本', + 'image': '图片', + 'video': '视频' + } + + const randomNum = Math.floor(Math.random() * 10000).toString().padStart(4, '0') + const providerName = providerNames[provider] || provider + const serviceName = serviceNames[serviceType] || serviceType + + return `${providerName}-${serviceName}-${randomNum}` +} + const showCreateDialog = () => { isEdit.value = false editingId.value = undefined resetForm() form.service_type = activeTab.value - // 根据服务类型设置默认端点路径 - form.endpoint = getDefaultEndpoint(activeTab.value) - // 文本生成默认选择openai - if (activeTab.value === 'text') { - form.provider = 'openai' - } + // 默认选择 chatfire + form.provider = 'chatfire' + // 设置默认 base_url + form.base_url = 'https://api.chatfire.site/v1' + // 自动生成随机配置名称 + form.name = generateConfigName('chatfire', activeTab.value) dialogVisible.value = true } @@ -292,26 +398,13 @@ const handleEdit = (config: AIServiceConfig) => { isEdit.value = true editingId.value = config.id - // 根据模型名称推断厂商 - const inferProvider = (model: string, serviceType: AIServiceType): string => { - const providers = providerConfigs[serviceType] - for (const provider of providers) { - if (provider.models.includes(model)) { - return provider.id - } - } - return providers[0]?.id || '' - } - Object.assign(form, { service_type: config.service_type, - provider: inferProvider(Array.isArray(config.model) ? config.model[0] : config.model, config.service_type), + provider: config.provider || 'chatfire', // 直接使用配置中的 provider,默认为 chatfire name: config.name, base_url: config.base_url, api_key: config.api_key, model: Array.isArray(config.model) ? config.model : [config.model], // 统一转换为数组 - endpoint: config.endpoint, - query_endpoint: config.query_endpoint || '', priority: config.priority || 0, is_active: config.is_active }) @@ -359,7 +452,7 @@ const testConnection = async () => { base_url: form.base_url, api_key: form.api_key, model: form.model, - endpoint: form.endpoint + provider: form.provider }) ElMessage.success('连接测试成功!') } catch (error: any) { @@ -376,7 +469,7 @@ const handleTest = async (config: AIServiceConfig) => { base_url: config.base_url, api_key: config.api_key, model: config.model, - endpoint: config.endpoint + provider: config.provider }) ElMessage.success('连接测试成功!') } catch (error: any) { @@ -397,11 +490,10 @@ const handleSubmit = async () => { if (isEdit.value && editingId.value) { const updateData: UpdateAIConfigRequest = { name: form.name, + provider: form.provider, base_url: form.base_url, api_key: form.api_key, model: form.model, - endpoint: form.endpoint, - query_endpoint: form.query_endpoint, priority: form.priority, is_active: form.is_active } @@ -422,16 +514,36 @@ const handleSubmit = async () => { }) } +const handleTabChange = (tabName: string | number) => { + // 标签页切换时重新加载对应服务类型的配置 + activeTab.value = tabName as AIServiceType + loadConfigs() +} + const handleProviderChange = () => { // 切换厂商时清空已选模型 form.model = [] + + // 根据厂商自动设置默认 base_url + if (form.provider === 'gemini' || form.provider === 'google') { + form.base_url = 'https://api.chatfire.site' + } else { + // openai, chatfire 等其他厂商 + form.base_url = 'https://api.chatfire.site/v1' + } + + // 仅在新建配置时自动更新名称 + if (!isEdit.value) { + form.name = generateConfigName(form.provider, form.service_type) + } } -// 根据服务类型获取默认端点路径 +// getDefaultEndpoint 已移除,端点由后端根据 provider 自动设置 +// 保留该函数定义以避免编译错误 const getDefaultEndpoint = (serviceType: AIServiceType): string => { switch (serviceType) { case 'text': - return '/v1/chat/completions' + return '' case 'image': return '/v1/images/generations' case 'video': @@ -450,8 +562,6 @@ const resetForm = () => { base_url: '', api_key: '', model: [], // 改为空数组 - endpoint: getDefaultEndpoint(serviceType), - query_endpoint: '', priority: 0, is_active: true })