Add chat endpoints for Gemini and Chatfire, image generation for Gemini and Chatfire, and easier AI configuration

Author: Connor
Date: 2026-01-14 02:25:41 +08:00
parent 4d38357ff6
commit 23b45efae9
22 changed files with 1512 additions and 405 deletions
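The heart of the change is a new provider field on AI service configs: when a create/update request leaves endpoint empty, the service fills in a provider-specific default. The snippet below is a minimal standalone sketch of that mapping; defaultEndpoint is a hypothetical helper name, not code from this commit.

package main

import "fmt"

// defaultEndpoint mirrors the switch that CreateConfig applies when the
// request's endpoint field is empty; an empty return means "no default".
func defaultEndpoint(provider, serviceType string) string {
    switch provider {
    case "gemini", "google":
        if serviceType == "text" || serviceType == "image" {
            return "/v1beta/models/{model}:generateContent"
        }
    case "openai", "chatfire":
        switch serviceType {
        case "text":
            return "/chat/completions"
        case "image":
            return "/images/generations"
        case "video":
            return "/video/generations"
        }
    default:
        if serviceType == "text" {
            return "/chat/completions"
        } else if serviceType == "image" {
            return "/images/generations"
        }
    }
    return ""
}

func main() {
    fmt.Println(defaultEndpoint("gemini", "text"))    // /v1beta/models/{model}:generateContent
    fmt.Println(defaultEndpoint("chatfire", "video")) // /video/generations
    fmt.Println(defaultEndpoint("openai", "image"))   // /images/generations
}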

.gitignore

@@ -63,3 +63,5 @@ configs/config.yaml
# Docker publish documentation (optional)
DOCKER_PUBLISH.md
build.sh
+/data/storage/
+/web/package-lock.json


@@ -3,6 +3,9 @@
# ==================== 阶段1: 构建前端 ====================
FROM node:20-alpine AS frontend-builder

+# 配置 npm 镜像源(国内加速)
+RUN npm config set registry https://registry.npmmirror.com

WORKDIR /app/web

# 复制前端依赖文件
@@ -59,15 +62,12 @@ RUN apk add --no-cache \
    tzdata \
    ffmpeg \
    sqlite-libs \
    wget \
    && rm -rf /var/cache/apk/*

# 设置时区
ENV TZ=Asia/Shanghai

-# 创建非 root 用户
-RUN addgroup -g 1000 app && \
-    adduser -D -u 1000 -G app app

WORKDIR /app

# 从构建阶段复制可执行文件
@@ -83,10 +83,7 @@ RUN cp ./configs/config.example.yaml ./configs/config.yaml
# 复制数据库迁移文件
COPY migrations ./migrations/

-# 切换到非 root 用户
-USER app
-# 创建数据目录(在 app 用户下创建,确保权限正确)
+# 创建数据目录(root 用户运行,无需权限设置)
RUN mkdir -p /app/data/storage

# 暴露端口


@@ -25,6 +25,7 @@ func NewAIService(db *gorm.DB, log *logger.Logger) *AIService {
type CreateAIConfigRequest struct {
    ServiceType string            `json:"service_type" binding:"required,oneof=text image video"`
    Name        string            `json:"name" binding:"required,min=1,max=100"`
+   Provider    string            `json:"provider" binding:"required"`
    BaseURL     string            `json:"base_url" binding:"required,url"`
    APIKey      string            `json:"api_key" binding:"required"`
    Model       models.ModelField `json:"model" binding:"required"`
@@ -37,6 +38,7 @@ type CreateAIConfigRequest struct {
type UpdateAIConfigRequest struct {
    Name     string             `json:"name" binding:"omitempty,min=1,max=100"`
+   Provider string             `json:"provider"`
    BaseURL  string             `json:"base_url" binding:"omitempty,url"`
    APIKey   string             `json:"api_key"`
    Model    *models.ModelField `json:"model"`
@@ -52,18 +54,53 @@ type TestConnectionRequest struct {
    BaseURL  string            `json:"base_url" binding:"required,url"`
    APIKey   string            `json:"api_key" binding:"required"`
    Model    models.ModelField `json:"model" binding:"required"`
+   Provider string            `json:"provider"`
    Endpoint string            `json:"endpoint"`
}

func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) {
+   // 根据 provider 和 service_type 自动设置 endpoint
+   endpoint := req.Endpoint
+   queryEndpoint := req.QueryEndpoint
+   if endpoint == "" {
+       switch req.Provider {
+       case "gemini", "google":
+           if req.ServiceType == "text" {
+               endpoint = "/v1beta/models/{model}:generateContent"
+           } else if req.ServiceType == "image" {
+               endpoint = "/v1beta/models/{model}:generateContent"
+           }
+       case "openai", "chatfire":
+           if req.ServiceType == "text" {
+               endpoint = "/chat/completions"
+           } else if req.ServiceType == "image" {
+               endpoint = "/images/generations"
+           } else if req.ServiceType == "video" {
+               endpoint = "/video/generations"
+               if queryEndpoint == "" {
+                   queryEndpoint = "/v1/video/task/{taskId}"
+               }
+           }
+       default:
+           // 默认使用 OpenAI 格式
+           if req.ServiceType == "text" {
+               endpoint = "/chat/completions"
+           } else if req.ServiceType == "image" {
+               endpoint = "/images/generations"
+           }
+       }
+   }

    config := &models.AIServiceConfig{
        ServiceType:   req.ServiceType,
        Name:          req.Name,
+       Provider:      req.Provider,
        BaseURL:       req.BaseURL,
        APIKey:        req.APIKey,
        Model:         req.Model,
-       Endpoint:      req.Endpoint,
-       QueryEndpoint: req.QueryEndpoint,
+       Endpoint:      endpoint,
+       QueryEndpoint: queryEndpoint,
        Priority:      req.Priority,
        IsDefault:     req.IsDefault,
        IsActive:      true,
@@ -75,7 +112,7 @@ func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceC
        return nil, err
    }

-   s.log.Infow("AI config created", "config_id", config.ID)
+   s.log.Infow("AI config created", "config_id", config.ID, "provider", req.Provider, "endpoint", endpoint)
    return config, nil
}
@@ -125,6 +162,9 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo
    if req.Name != "" {
        updates["name"] = req.Name
    }
+   if req.Provider != "" {
+       updates["provider"] = req.Provider
+   }
    if req.BaseURL != "" {
        updates["base_url"] = req.BaseURL
    }
@@ -137,9 +177,30 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo
    if req.Priority != nil {
        updates["priority"] = *req.Priority
    }
-   if req.Endpoint != "" {
+   // 如果提供了 provider根据 provider 和 service_type 自动设置 endpoint
+   if req.Provider != "" && req.Endpoint == "" {
+       provider := req.Provider
+       serviceType := config.ServiceType
+       switch provider {
+       case "gemini", "google":
+           if serviceType == "text" || serviceType == "image" {
+               updates["endpoint"] = "/v1beta/models/{model}:generateContent"
+           }
+       case "openai", "chatfire":
+           if serviceType == "text" {
+               updates["endpoint"] = "/chat/completions"
+           } else if serviceType == "image" {
+               updates["endpoint"] = "/images/generations"
+           } else if serviceType == "video" {
+               updates["endpoint"] = "/video/generations"
+           }
+       }
+   } else if req.Endpoint != "" {
        updates["endpoint"] = req.Endpoint
    }
    // 允许清空query_endpoint所以不检查是否为空
    updates["query_endpoint"] = req.QueryEndpoint
    if req.Settings != "" {
@@ -179,13 +240,51 @@ func (s *AIService) DeleteConfig(configID uint) error {
}

func (s *AIService) TestConnection(req *TestConnectionRequest) error {
+   s.log.Infow("TestConnection called", "baseURL", req.BaseURL, "provider", req.Provider, "endpoint", req.Endpoint, "modelCount", len(req.Model))
    // 使用第一个模型进行测试
    model := ""
    if len(req.Model) > 0 {
        model = req.Model[0]
    }
-   client := ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, req.Endpoint)
-   return client.TestConnection()
+   s.log.Infow("Using model for test", "model", model, "provider", req.Provider)
+
+   // 根据 provider 参数选择客户端
+   var client ai.AIClient
+   var endpoint string
+   switch req.Provider {
+   case "gemini", "google":
+       // Gemini
+       s.log.Infow("Using Gemini client", "baseURL", req.BaseURL)
+       endpoint = "/v1beta/models/{model}:generateContent"
+       client = ai.NewGeminiClient(req.BaseURL, req.APIKey, model, endpoint)
+   case "openai", "chatfire":
+       // OpenAI 格式(包括 chatfire 等)
+       s.log.Infow("Using OpenAI-compatible client", "baseURL", req.BaseURL, "provider", req.Provider)
+       endpoint = req.Endpoint
+       if endpoint == "" {
+           endpoint = "/chat/completions"
+       }
+       client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
+   default:
+       // 默认使用 OpenAI 格式
+       s.log.Infow("Using default OpenAI-compatible client", "baseURL", req.BaseURL)
+       endpoint = req.Endpoint
+       if endpoint == "" {
+           endpoint = "/chat/completions"
+       }
+       client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
+   }
+
+   s.log.Infow("Calling TestConnection on client", "endpoint", endpoint)
+   err := client.TestConnection()
+   if err != nil {
+       s.log.Errorw("TestConnection failed", "error", err)
+   } else {
+       s.log.Infow("TestConnection succeeded")
+   }
+   return err
}

func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) {
@@ -228,7 +327,7 @@ func (s *AIService) GetConfigForModel(serviceType string, modelName string) (*mo
    return nil, errors.New("no config found for model: " + modelName)
}

-func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) {
+func (s *AIService) GetAIClient(serviceType string) (ai.AIClient, error) {
    config, err := s.GetDefaultConfig(serviceType)
    if err != nil {
        return nil, err
@@ -240,7 +339,25 @@ func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) {
        model = config.Model[0]
    }

-   return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, config.Endpoint), nil
+   // 使用数据库配置中的 endpoint如果为空则根据 provider 设置默认值
+   endpoint := config.Endpoint
+   if endpoint == "" {
+       switch config.Provider {
+       case "gemini", "google":
+           endpoint = "/v1beta/models/{model}:generateContent"
+       default:
+           endpoint = "/chat/completions"
+       }
+   }
+
+   // 根据 provider 创建对应的客户端
+   switch config.Provider {
+   case "gemini", "google":
+       return ai.NewGeminiClient(config.BaseURL, config.APIKey, model, endpoint), nil
+   default:
+       // openai, chatfire 等其他厂商都使用 OpenAI 格式
+       return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, endpoint), nil
+   }
}

func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) {


@@ -4,6 +4,7 @@ import (
    "encoding/json"
    "fmt"
    "strconv"
+   "strings"
    "time"

    models "github.com/drama-generator/backend/domain/models"
@@ -23,6 +24,24 @@ type ImageGenerationService struct {
    log *logger.Logger
}

+// truncateImageURL 截断图片 URL避免 base64 格式的 URL 占满日志
+func truncateImageURL(url string) string {
+   if url == "" {
+       return ""
+   }
+   // 如果是 data URI 格式base64只显示前缀
+   if strings.HasPrefix(url, "data:") {
+       if len(url) > 50 {
+           return url[:50] + "...[base64 data]"
+       }
+   }
+   // 普通 URL 如果过长也截断
+   if len(url) > 100 {
+       return url[:100] + "..."
+   }
+   return url
+}
+
func NewImageGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, log *logger.Logger) *ImageGenerationService {
    return &ImageGenerationService{
        db: db,
@@ -246,17 +265,23 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
    now := time.Now()

    // 下载图片到本地存储(仅用于缓存,不更新数据库)
-   if s.localStorage != nil && result.ImageURL != "" {
+   // 仅下载 HTTP/HTTPS URL跳过 data URI
+   if s.localStorage != nil && result.ImageURL != "" &&
+       (strings.HasPrefix(result.ImageURL, "http://") || strings.HasPrefix(result.ImageURL, "https://")) {
        _, err := s.localStorage.DownloadFromURL(result.ImageURL, "images")
        if err != nil {
+           errStr := err.Error()
+           if len(errStr) > 200 {
+               errStr = errStr[:200] + "..."
+           }
            s.log.Warnw("Failed to download image to local storage",
-               "error", err,
+               "error", errStr,
                "id", imageGenID,
-               "original_url", result.ImageURL)
+               "original_url", truncateImageURL(result.ImageURL))
        } else {
            s.log.Infow("Image downloaded to local storage for caching",
                "id", imageGenID,
-               "original_url", result.ImageURL)
+               "original_url", truncateImageURL(result.ImageURL))
        }
    }
@@ -291,7 +316,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
        } else {
            s.log.Infow("Storyboard updated with composed image",
                "storyboard_id", *imageGen.StoryboardID,
-               "composed_image", result.ImageURL)
+               "composed_image", truncateImageURL(result.ImageURL))
        }
    }
@@ -306,7 +331,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
        } else {
            s.log.Infow("Scene updated with generated image",
                "scene_id", *imageGen.SceneID,
-               "image_url", result.ImageURL)
+               "image_url", truncateImageURL(result.ImageURL))
        }
    }
@@ -317,7 +342,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
        } else {
            s.log.Infow("Character updated with generated image",
                "character_id", *imageGen.CharacterID,
-               "image_url", result.ImageURL)
+               "image_url", truncateImageURL(result.ImageURL))
        }
    }
}
@@ -356,13 +381,33 @@ func (s *ImageGenerationService) getImageClient(provider string) (image.ImageCli
        model = config.Model[0]
    }

-   switch provider {
+   // 使用配置中的 provider如果没有则使用传入的 provider
+   actualProvider := config.Provider
+   if actualProvider == "" {
+       actualProvider = provider
+   }
+
+   // 根据 provider 自动设置默认端点
+   var endpoint string
+   var queryEndpoint string
+
+   switch actualProvider {
    case "openai", "dalle":
-       return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
-   case "stable_diffusion", "sd":
-       return image.NewStableDiffusionClient(config.BaseURL, config.APIKey, model), nil
+       endpoint = "/images/generations"
+       return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
+   case "chatfire":
+       endpoint = "/images/generations"
+       return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
+   case "volcengine", "volces", "doubao":
+       endpoint = "/images/generations"
+       queryEndpoint = ""
+       return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
+   case "gemini", "google":
+       endpoint = "/v1beta/models/{model}:generateContent"
+       return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
    default:
-       return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
+       endpoint = "/images/generations"
+       return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
    }
}
@@ -394,13 +439,33 @@ func (s *ImageGenerationService) getImageClientWithModel(provider string, modelN
        model = config.Model[0]
    }

-   switch provider {
+   // 使用配置中的 provider如果没有则使用传入的 provider
+   actualProvider := config.Provider
+   if actualProvider == "" {
+       actualProvider = provider
+   }
+
+   // 根据 provider 自动设置默认端点
+   var endpoint string
+   var queryEndpoint string
+
+   switch actualProvider {
    case "openai", "dalle":
-       return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
-   case "stable_diffusion", "sd":
-       return image.NewStableDiffusionClient(config.BaseURL, config.APIKey, model), nil
+       endpoint = "/images/generations"
+       return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
+   case "chatfire":
+       endpoint = "/images/generations"
+       return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
+   case "volcengine", "volces", "doubao":
+       endpoint = "/images/generations"
+       queryEndpoint = ""
+       return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
+   case "gemini", "google":
+       endpoint = "/v1beta/models/{model}:generateContent"
+       return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
    default:
-       return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
+       endpoint = "/images/generations"
+       return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
    }
}
@@ -716,21 +781,11 @@ func (s *ImageGenerationService) extractBackgroundsFromScript(scriptContent stri
请严格按照JSON格式输出确保所有字段都使用中文。`, scriptContent)

-   messages := []ai.ChatMessage{
-       {Role: "user", Content: prompt},
-   }
-
-   resp, err := client.ChatCompletion(messages, ai.WithTemperature(0.7), ai.WithMaxTokens(8000))
+   response, err := client.GenerateText(prompt, "", ai.WithTemperature(0.7), ai.WithMaxTokens(8000))
    if err != nil {
        s.log.Errorw("Failed to extract backgrounds with AI", "error", err)
        return nil, fmt.Errorf("AI提取场景失败: %w", err)
    }

-   if len(resp.Choices) == 0 {
-       return nil, fmt.Errorf("AI未返回有效响应")
-   }
-
-   response := resp.Choices[0].Message.Content
    s.log.Infow("AI backgrounds extraction response", "length", len(response))

    // 解析JSON响应


@@ -185,17 +185,17 @@ func (s *ScriptGenerationService) GenerateCharacters(req *GenerateCharactersRequ
        count = 5
    }

-   systemPrompt := `你是一个专业的角色设计师,擅长创作立体丰富的剧中角色
-你的任务是根据提供的剧本大纲,创作符合故事需求的角色设定。
+   systemPrompt := `你是一个专业的角色分析师,擅长从剧本中提取和分析角色信息
+你的任务是根据提供的剧本内容,提取并整理剧中出现的所有角色的详细设定。

要求:
-1. 角色必须服务于大纲中的故事情节和冲突
-2. 角色性格鲜明,有辨识度,符合故事类型
-3. 每个角色都有清晰的动机目标,与大纲中的矛盾冲突相关
-4. 角色之间有合理的关系和联系
-5. 外貌描述必须极其详细便于AI绘画生成角色形象
-6. 根据大纲的关键场景合理设置角色数量通常3-6个主要角色
+1. 仔细阅读剧本,识别所有出现的角色
+2. 根据剧本中的对话、行为和描述,总结角色性格特点
+3. 提取角色在剧本中的关键信息:背景、动机目标、关系等
+4. 角色之间的关系必须基于剧本中的实际描述
+5. 外貌描述必须极其详细,如果剧本中有描述则使用,如果没有则根据角色设定合理推断,便于AI绘画生成角色形象
+6. 优先提取主要角色和重要配角,次要角色可以简略

请严格按照以下 JSON 格式输出,不要添加任何其他文字:
@@ -213,21 +213,21 @@ func (s *ScriptGenerationService) GenerateCharacters(req *GenerateCharactersRequ
}

注意:
-- 角色数量根据故事复杂度确定,不要过多
-- 每个角色都要与大纲中的故事线有明确关联
+- 必须基于剧本内容提取角色,不要凭空创作
+- 优先提取主要角色和重要配角,数量根据剧本实际情况确定
- description、personality、appearance、voice_style都必须详细描述字数要充足
- appearance外貌描述是重中之重必须极其详细具体要能让AI准确生成角色形象
-- 避免模糊描述,多用具体的视觉特征和细节`
+- 如果剧本中角色信息不完整,可以根据角色设定合理补充,但要符合剧本整体风格`

    outlineText := req.Outline
    if outlineText == "" {
        outlineText = fmt.Sprintf("剧名:%s\n简介%s\n类型%s", drama.Title, drama.Description, drama.Genre)
    }

-   userPrompt := fmt.Sprintf(`剧本大纲
+   userPrompt := fmt.Sprintf(`剧本内容
%s

-创作 %d 个角色的详细设定。`, outlineText, count)
+从剧本中提取并整理最多 %d 个主要角色的详细设定。`, outlineText, count)

    temperature := req.Temperature
    if temperature == 0 {


@@ -414,24 +414,33 @@ func (s *VideoGenerationService) getVideoClient(provider string, modelName strin
    // 使用配置中的信息创建客户端
    baseURL := config.BaseURL
    apiKey := config.APIKey
-   endpoint := config.Endpoint
-   queryEndpoint := config.QueryEndpoint
    model := modelName
    if model == "" && len(config.Model) > 0 {
        model = config.Model[0]
    }

+   // 根据 provider 自动设置默认端点
+   var endpoint string
+   var queryEndpoint string
+
    switch provider {
-   case "doubao":
+   case "chatfire":
+       endpoint = "/video/generations"
+       queryEndpoint = "/v1/video/task/{taskId}"
+       return video.NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
+   case "doubao", "volcengine", "volces":
+       endpoint = "/contents/generations/tasks"
+       queryEndpoint = "/generations/tasks/{taskId}"
        return video.NewVolcesArkClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
-   case "openai":
-       // OpenAI Sora 使用 /v1/videos 端点
-       return video.NewOpenAISoraClient(baseURL, apiKey, model), nil
    case "runway":
        return video.NewRunwayClient(baseURL, apiKey, model), nil
    case "pika":
        return video.NewPikaClient(baseURL, apiKey, model), nil
    case "minimax":
        return video.NewMinimaxClient(baseURL, apiKey, model), nil
+   case "openai":
+       return video.NewOpenAISoraClient(baseURL, apiKey, model), nil
    default:
        return nil, fmt.Errorf("unsupported video provider: %s", provider)
    }


@@ -297,6 +297,10 @@ func (s *VideoMergeService) getVideoClient(provider string) (video.VideoClient,
        model = config.Model[0]
    }

+   // 根据 provider 自动设置默认端点
+   var endpoint string
+   var queryEndpoint string
+
    switch provider {
    case "runway":
        return video.NewRunwayClient(config.BaseURL, config.APIKey, model), nil
@@ -306,10 +310,18 @@ func (s *VideoMergeService) getVideoClient(provider string) (video.VideoClient,
        return video.NewOpenAISoraClient(config.BaseURL, config.APIKey, model), nil
    case "minimax":
        return video.NewMinimaxClient(config.BaseURL, config.APIKey, model), nil
+   case "chatfire":
+       endpoint = "/video/generations"
+       queryEndpoint = "/v1/video/task/{taskId}"
+       return video.NewChatfireClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
    case "doubao", "volces", "ark":
-       return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, config.Endpoint, config.QueryEndpoint), nil
+       endpoint = "/contents/generations/tasks"
+       queryEndpoint = "/generations/tasks/{taskId}"
+       return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
    default:
-       return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, config.Endpoint, config.QueryEndpoint), nil
+       endpoint = "/contents/generations/tasks"
+       queryEndpoint = "/generations/tasks/{taskId}"
+       return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
    }
}


@@ -8,10 +8,12 @@ services:
    ports:
      - "5678:5678"
    volumes:
-     # 持久化数据目录(使用命名卷)
+     # 持久化数据目录(使用命名卷,容器内以 root 运行)
      - huobao-data:/app/data
      # 挂载配置文件(可选,如需自定义配置请取消注释)
      # - ./configs/config.yaml:/app/configs/config.yaml:ro
+     # 注意:如果使用本地目录挂载,需要确保目录权限正确
+     # 例如:- ./data:/app/data (需要 chmod 777 ./data)
    environment:
      - TZ=Asia/Shanghai
    restart: unless-stopped


@@ -10,6 +10,7 @@ import (
type AIServiceConfig struct {
    ID          uint   `gorm:"primaryKey;autoIncrement" json:"id"`
    ServiceType string `gorm:"type:varchar(50);not null" json:"service_type"` // text, image, video
+   Provider    string `gorm:"type:varchar(50)" json:"provider"` // openai, gemini, volcengine, etc.
    Name        string `gorm:"type:varchar(100);not null" json:"name"`
    BaseURL     string `gorm:"type:varchar(255);not null" json:"base_url"`
    APIKey      string `gorm:"type:varchar(255);not null" json:"api_key"`


@@ -0,0 +1,103 @@
package database
import (
"context"
"strings"
"time"
"gorm.io/gorm/logger"
)
// CustomLogger 自定义 GORM logger截断过长的 SQL 参数(如 base64 数据)
type CustomLogger struct {
logger.Interface
}
// NewCustomLogger 创建自定义 logger
func NewCustomLogger() logger.Interface {
return &CustomLogger{
Interface: logger.Default.LogMode(logger.Silent),
}
}
// Trace 重写 Trace 方法,禁用 SQL 日志输出
func (l *CustomLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
// 不输出任何 SQL 日志
// 如果需要调试,可以临时取消注释下面的代码
/*
sql, rows := fc()
sql = truncateLongValues(sql)
elapsed := time.Since(begin)
if err != nil {
l.Interface.Error(ctx, "SQL error: %v [%v] %s", err, elapsed, sql)
} else {
l.Interface.Info(ctx, "[%.3fms] [rows:%d] %s", float64(elapsed.Nanoseconds())/1e6, rows, sql)
}
*/
}
// truncateLongValues 截断 SQL 中的长字符串值
func truncateLongValues(sql string) string {
// 查找 base64 格式的数据 (data:image/...;base64,...)
if strings.Contains(sql, "data:image/") && strings.Contains(sql, ";base64,") {
parts := strings.Split(sql, "\"")
for i, part := range parts {
if strings.HasPrefix(part, "data:image/") && strings.Contains(part, ";base64,") {
if len(part) > 100 {
// 保留前50字符添加截断标记
parts[i] = part[:50] + "...[base64 data truncated]"
}
}
}
sql = strings.Join(parts, "\"")
}
// 截断其他过长的值
if len(sql) > 5000 {
// 查找 VALUES 或 SET 后的内容
if idx := strings.Index(sql, " VALUES "); idx > 0 && len(sql) > idx+5000 {
sql = sql[:idx+5000] + "...[truncated]"
} else if idx := strings.Index(sql, " SET "); idx > 0 && len(sql) > idx+3000 {
sql = sql[:idx+3000] + "...[truncated]"
} else if len(sql) > 5000 {
sql = sql[:5000] + "...[truncated]"
}
}
return sql
}
// Info 实现 Info 方法
func (l *CustomLogger) Info(ctx context.Context, msg string, data ...interface{}) {
l.Interface.Info(ctx, msg, data...)
}
// Warn 实现 Warn 方法
func (l *CustomLogger) Warn(ctx context.Context, msg string, data ...interface{}) {
l.Interface.Warn(ctx, msg, data...)
}
// Error 实现 Error 方法
func (l *CustomLogger) Error(ctx context.Context, msg string, data ...interface{}) {
// 检查并截断 data 中的长字符串
truncatedData := make([]interface{}, len(data))
for i, d := range data {
if str, ok := d.(string); ok && len(str) > 200 {
if strings.HasPrefix(str, "data:image/") {
truncatedData[i] = str[:50] + "...[base64 data]"
} else {
truncatedData[i] = str[:200] + "..."
}
} else {
truncatedData[i] = d
}
}
l.Interface.Error(ctx, msg, truncatedData...)
}
// LogMode 实现 LogMode 方法
func (l *CustomLogger) LogMode(level logger.LogLevel) logger.Interface {
newLogger := *l
newLogger.Interface = l.Interface.LogMode(level)
return &newLogger
}


@@ -11,7 +11,6 @@ import (
    "gorm.io/driver/mysql"
    "gorm.io/driver/sqlite"
    "gorm.io/gorm"
-   "gorm.io/gorm/logger"
)

func NewDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) {
@@ -25,7 +24,7 @@ func NewDatabase(cfg config.DatabaseConfig) (*gorm.DB, error) {
    }

    gormConfig := &gorm.Config{
-       Logger: logger.Default.LogMode(logger.Info),
+       Logger: NewCustomLogger(),
    }

    var db *gorm.DB


@@ -445,12 +445,14 @@ CREATE INDEX IF NOT EXISTS idx_asset_collection_relations_collection_id ON asset
CREATE TABLE IF NOT EXISTS ai_service_configs (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    service_type TEXT NOT NULL, -- text, image, video
+   provider TEXT, -- openai, gemini, volcengine, etc.
    name TEXT NOT NULL,
    base_url TEXT NOT NULL,
    api_key TEXT NOT NULL,
    model TEXT,
    endpoint TEXT,
    query_endpoint TEXT,
+   priority INTEGER NOT NULL DEFAULT 0,
    is_default INTEGER NOT NULL DEFAULT 0,
    is_active INTEGER NOT NULL DEFAULT 1,
    settings TEXT, -- JSON存储
@@ -489,7 +491,8 @@ INSERT OR IGNORE INTO ai_service_providers (name, display_name, service_type, de
    ('openai-dalle', 'OpenAI DALL-E', 'image', 'https://api.openai.com/v1', 'OpenAI DALL-E图片生成'),
    ('openai-sora', 'OpenAI Sora', 'video', 'https://api.openai.com/v1', 'OpenAI Sora视频生成'),
    ('midjourney', 'Midjourney', 'image', '', 'Midjourney图片生成'),
-   ('stable-diffusion', 'Stable Diffusion', 'image', '', 'Stable Diffusion图片生成'),
+   ('doubao-image', '豆包(火山引擎)', 'image', 'https://ark.cn-beijing.volces.com', '火山引擎豆包图片生成'),
+   ('gemini-image', 'Google Gemini', 'image', 'https://generativelanguage.googleapis.com', 'Google Gemini原生图片生成(base64)'),
    ('runway', 'Runway', 'video', '', 'Runway视频生成'),
    ('pika', 'Pika Labs', 'video', '', 'Pika视频生成'),
    ('doubao', '豆包(火山引擎)', 'video', 'https://ark.cn-beijing.volces.com', '火山引擎豆包视频生成'),

pkg/ai/client.go (new file)

@@ -0,0 +1,7 @@
package ai
// AIClient 定义文本生成客户端接口
type AIClient interface {
GenerateText(prompt string, systemPrompt string, options ...func(*ChatCompletionRequest)) (string, error)
TestConnection() error
}
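Both concrete clients in this commit implement this interface. A compile-time assertion that could live in package ai (illustrative only, not part of the diff) makes the relationship explicit:

package ai

// Static checks that the two clients returned by AIService.GetAIClient and
// used by TestConnection satisfy the AIClient interface.
var (
    _ AIClient = (*OpenAIClient)(nil)
    _ AIClient = (*GeminiClient)(nil)
)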

pkg/ai/gemini_client.go (new file)

@@ -0,0 +1,195 @@
package ai
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
type GeminiClient struct {
BaseURL string
APIKey string
Model string
Endpoint string
HTTPClient *http.Client
}
type GeminiTextRequest struct {
Contents []GeminiContent `json:"contents"`
SystemInstruction *GeminiInstruction `json:"systemInstruction,omitempty"`
}
type GeminiContent struct {
Parts []GeminiPart `json:"parts"`
Role string `json:"role,omitempty"`
}
type GeminiPart struct {
Text string `json:"text"`
}
type GeminiInstruction struct {
Parts []GeminiPart `json:"parts"`
}
type GeminiTextResponse struct {
Candidates []struct {
Content struct {
Parts []struct {
Text string `json:"text"`
} `json:"parts"`
Role string `json:"role"`
} `json:"content"`
FinishReason string `json:"finishReason"`
Index int `json:"index"`
SafetyRatings []struct {
Category string `json:"category"`
Probability string `json:"probability"`
} `json:"safetyRatings"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
} `json:"usageMetadata"`
}
func NewGeminiClient(baseURL, apiKey, model, endpoint string) *GeminiClient {
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
if endpoint == "" {
endpoint = "/v1beta/models/{model}:generateContent"
}
if model == "" {
model = "gemini-3-pro"
}
return &GeminiClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
Endpoint: endpoint,
HTTPClient: &http.Client{
Timeout: 10 * time.Minute,
},
}
}
func (c *GeminiClient) GenerateText(prompt string, systemPrompt string, options ...func(*ChatCompletionRequest)) (string, error) {
model := c.Model
// 构建请求体
reqBody := GeminiTextRequest{
Contents: []GeminiContent{
{
Parts: []GeminiPart{{Text: prompt}},
Role: "user",
},
},
}
// 使用 systemInstruction 字段处理系统提示
if systemPrompt != "" {
reqBody.SystemInstruction = &GeminiInstruction{
Parts: []GeminiPart{{Text: systemPrompt}},
}
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
fmt.Printf("Gemini: Failed to marshal request: %v\n", err)
return "", fmt.Errorf("marshal request: %w", err)
}
// 替换端点中的 {model} 占位符
endpoint := c.BaseURL + c.Endpoint
endpoint = strings.ReplaceAll(endpoint, "{model}", model)
url := fmt.Sprintf("%s?key=%s", endpoint, c.APIKey)
// 打印请求信息(隐藏 API Key
safeURL := strings.Replace(url, c.APIKey, "***", 1)
fmt.Printf("Gemini: Sending request to: %s\n", safeURL)
requestPreview := string(jsonData)
if len(jsonData) > 300 {
requestPreview = string(jsonData[:300]) + "..."
}
fmt.Printf("Gemini: Request body: %s\n", requestPreview)
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
fmt.Printf("Gemini: Failed to create request: %v\n", err)
return "", fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
fmt.Printf("Gemini: Executing HTTP request...\n")
resp, err := c.HTTPClient.Do(req)
if err != nil {
fmt.Printf("Gemini: HTTP request failed: %v\n", err)
return "", fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
fmt.Printf("Gemini: Received response with status: %d\n", resp.StatusCode)
body, err := io.ReadAll(resp.Body)
if err != nil {
fmt.Printf("Gemini: Failed to read response body: %v\n", err)
return "", fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
fmt.Printf("Gemini: API error (status %d): %s\n", resp.StatusCode, string(body))
return "", fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
// 打印响应体用于调试
bodyPreview := string(body)
if len(body) > 500 {
bodyPreview = string(body[:500]) + "..."
}
fmt.Printf("Gemini: Response body: %s\n", bodyPreview)
var result GeminiTextResponse
if err := json.Unmarshal(body, &result); err != nil {
errorPreview := string(body)
if len(body) > 200 {
errorPreview = string(body[:200])
}
fmt.Printf("Gemini: Failed to parse response: %v\n", err)
return "", fmt.Errorf("parse response: %w, body preview: %s", err, errorPreview)
}
fmt.Printf("Gemini: Successfully parsed response, candidates count: %d\n", len(result.Candidates))
if len(result.Candidates) == 0 {
fmt.Printf("Gemini: No candidates in response\n")
return "", fmt.Errorf("no candidates in response")
}
if len(result.Candidates[0].Content.Parts) == 0 {
fmt.Printf("Gemini: No parts in first candidate\n")
return "", fmt.Errorf("no parts in response")
}
responseText := result.Candidates[0].Content.Parts[0].Text
fmt.Printf("Gemini: Generated text: %s\n", responseText)
return responseText, nil
}
func (c *GeminiClient) TestConnection() error {
fmt.Printf("Gemini: TestConnection called with BaseURL=%s, Model=%s, Endpoint=%s\n", c.BaseURL, c.Model, c.Endpoint)
_, err := c.GenerateText("Hello", "")
if err != nil {
fmt.Printf("Gemini: TestConnection failed: %v\n", err)
} else {
fmt.Printf("Gemini: TestConnection succeeded\n")
}
return err
}
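A minimal usage sketch of the new client, assuming the import path github.com/drama-generator/backend/pkg/ai (the module name used elsewhere in this commit); empty baseURL, model, and endpoint fall back to the defaults set in NewGeminiClient.

package main

import (
    "fmt"

    "github.com/drama-generator/backend/pkg/ai" // assumed import path
)

func main() {
    // Defaults apply: https://generativelanguage.googleapis.com, gemini-3-pro,
    // /v1beta/models/{model}:generateContent
    client := ai.NewGeminiClient("", "YOUR_API_KEY", "", "")

    text, err := client.GenerateText("用一句话介绍短剧生成器", "你是一个简洁的助手")
    if err != nil {
        fmt.Println("generate failed:", err)
        return
    }
    fmt.Println(text)
}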


@@ -91,30 +91,48 @@ func (c *OpenAIClient) ChatCompletion(messages []ChatMessage, options ...func(*C
func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatCompletionResponse, error) {
    jsonData, err := json.Marshal(req)
    if err != nil {
+       fmt.Printf("OpenAI: Failed to marshal request: %v\n", err)
        return nil, fmt.Errorf("failed to marshal request: %w", err)
    }

    url := c.BaseURL + c.Endpoint

+   // 打印请求信息
+   fmt.Printf("OpenAI: Sending request to: %s\n", url)
+   fmt.Printf("OpenAI: BaseURL=%s, Endpoint=%s, Model=%s\n", c.BaseURL, c.Endpoint, c.Model)
+   requestPreview := string(jsonData)
+   if len(jsonData) > 300 {
+       requestPreview = string(jsonData[:300]) + "..."
+   }
+   fmt.Printf("OpenAI: Request body: %s\n", requestPreview)
+
    httpReq, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
    if err != nil {
+       fmt.Printf("OpenAI: Failed to create request: %v\n", err)
        return nil, fmt.Errorf("failed to create request: %w", err)
    }

    httpReq.Header.Set("Content-Type", "application/json")
    httpReq.Header.Set("Authorization", "Bearer "+c.APIKey)

+   fmt.Printf("OpenAI: Executing HTTP request...\n")
    resp, err := c.HTTPClient.Do(httpReq)
    if err != nil {
+       fmt.Printf("OpenAI: HTTP request failed: %v\n", err)
        return nil, fmt.Errorf("failed to send request: %w", err)
    }
    defer resp.Body.Close()

+   fmt.Printf("OpenAI: Received response with status: %d\n", resp.StatusCode)
    body, err := io.ReadAll(resp.Body)
    if err != nil {
+       fmt.Printf("OpenAI: Failed to read response body: %v\n", err)
        return nil, fmt.Errorf("failed to read response: %w", err)
    }

    if resp.StatusCode != http.StatusOK {
+       fmt.Printf("OpenAI: API error (status %d): %s\n", resp.StatusCode, string(body))
        var errResp ErrorResponse
        if err := json.Unmarshal(body, &errResp); err != nil {
            return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
@@ -122,11 +140,25 @@ func (c *OpenAIClient) sendChatRequest(req *ChatCompletionRequest) (*ChatComplet
        return nil, fmt.Errorf("API error: %s", errResp.Error.Message)
    }

+   // 打印响应体用于调试
+   bodyPreview := string(body)
+   if len(body) > 500 {
+       bodyPreview = string(body[:500]) + "..."
+   }
+   fmt.Printf("OpenAI: Response body: %s\n", bodyPreview)
+
    var chatResp ChatCompletionResponse
    if err := json.Unmarshal(body, &chatResp); err != nil {
-       return nil, fmt.Errorf("failed to unmarshal response: %w", err)
+       errorPreview := string(body)
+       if len(body) > 200 {
+           errorPreview = string(body[:200])
+       }
+       fmt.Printf("OpenAI: Failed to parse response: %v\n", err)
+       return nil, fmt.Errorf("failed to unmarshal response: %w, body preview: %s", err, errorPreview)
    }

+   fmt.Printf("OpenAI: Successfully parsed response, choices count: %d\n", len(chatResp.Choices))
    return &chatResp, nil
}
@@ -176,6 +208,8 @@ func (c *OpenAIClient) GenerateText(prompt string, systemPrompt string, options
}

func (c *OpenAIClient) TestConnection() error {
+   fmt.Printf("OpenAI: TestConnection called with BaseURL=%s, Endpoint=%s, Model=%s\n", c.BaseURL, c.Endpoint, c.Model)
    messages := []ChatMessage{
        {
            Role: "user",
@@ -184,5 +218,10 @@ func (c *OpenAIClient) TestConnection() error {
    }

    _, err := c.ChatCompletion(messages, WithMaxTokens(10))
+   if err != nil {
+       fmt.Printf("OpenAI: TestConnection failed: %v\n", err)
+   } else {
+       fmt.Printf("OpenAI: TestConnection succeeded\n")
+   }
    return err
}


@@ -0,0 +1,277 @@
package image
import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
type GeminiImageClient struct {
BaseURL string
APIKey string
Model string
Endpoint string
HTTPClient *http.Client
}
type GeminiImageRequest struct {
Contents []struct {
Parts []GeminiPart `json:"parts"`
} `json:"contents"`
GenerationConfig struct {
ResponseModalities []string `json:"responseModalities"`
} `json:"generationConfig"`
}
type GeminiPart struct {
Text string `json:"text,omitempty"`
InlineData *GeminiInlineData `json:"inlineData,omitempty"`
}
type GeminiInlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"` // base64 编码的图片数据
}
type GeminiImageResponse struct {
Candidates []struct {
Content struct {
Parts []struct {
InlineData struct {
MimeType string `json:"mimeType"`
Data string `json:"data"`
} `json:"inlineData,omitempty"`
Text string `json:"text,omitempty"`
} `json:"parts"`
} `json:"content"`
} `json:"candidates"`
UsageMetadata struct {
PromptTokenCount int `json:"promptTokenCount"`
CandidatesTokenCount int `json:"candidatesTokenCount"`
TotalTokenCount int `json:"totalTokenCount"`
} `json:"usageMetadata"`
}
// downloadImageToBase64 下载图片 URL 并转换为 base64
func downloadImageToBase64(imageURL string) (string, string, error) {
resp, err := http.Get(imageURL)
if err != nil {
return "", "", fmt.Errorf("download image: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", "", fmt.Errorf("download image failed with status: %d", resp.StatusCode)
}
imageData, err := io.ReadAll(resp.Body)
if err != nil {
return "", "", fmt.Errorf("read image data: %w", err)
}
// 根据 Content-Type 确定 mimeType
mimeType := resp.Header.Get("Content-Type")
if mimeType == "" {
mimeType = "image/jpeg"
}
base64Data := base64.StdEncoding.EncodeToString(imageData)
return base64Data, mimeType, nil
}
func NewGeminiImageClient(baseURL, apiKey, model, endpoint string) *GeminiImageClient {
if baseURL == "" {
baseURL = "https://generativelanguage.googleapis.com"
}
if endpoint == "" {
endpoint = "/v1beta/models/{model}:generateContent"
}
if model == "" {
model = "gemini-3-pro-image-preview"
}
return &GeminiImageClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
Endpoint: endpoint,
HTTPClient: &http.Client{
Timeout: 10 * time.Minute,
},
}
}
func (c *GeminiImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
options := &ImageOptions{
Size: "1024x1024",
Quality: "standard",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
promptText := prompt
if options.NegativePrompt != "" {
promptText += fmt.Sprintf("\n\nNegative prompt: %s", options.NegativePrompt)
}
if options.Size != "" {
promptText += fmt.Sprintf("\n\nImage size: %s", options.Size)
}
// 构建请求的 parts支持参考图
parts := []GeminiPart{}
// 如果有参考图,先添加参考图
if len(options.ReferenceImages) > 0 {
for _, refImg := range options.ReferenceImages {
var base64Data string
var mimeType string
var err error
// 检查是否是 HTTP/HTTPS URL
if strings.HasPrefix(refImg, "http://") || strings.HasPrefix(refImg, "https://") {
// 下载图片并转换为 base64
base64Data, mimeType, err = downloadImageToBase64(refImg)
if err != nil {
continue
}
} else if strings.HasPrefix(refImg, "data:") {
// 如果是 data URI 格式,需要解析
// 格式: data:image/jpeg;base64,xxxxx
mimeType = "image/jpeg"
parts := []byte(refImg)
for i := 0; i < len(parts); i++ {
if parts[i] == ',' {
base64Data = refImg[i+1:]
// 提取 mime type
if i > 11 {
mimeTypeEnd := i
for j := 5; j < i; j++ {
if parts[j] == ';' {
mimeTypeEnd = j
break
}
}
mimeType = refImg[5:mimeTypeEnd]
}
break
}
}
} else {
// 假设已经是 base64 编码
base64Data = refImg
mimeType = "image/jpeg"
}
if base64Data != "" {
parts = append(parts, GeminiPart{
InlineData: &GeminiInlineData{
MimeType: mimeType,
Data: base64Data,
},
})
}
}
}
// 添加文本提示词
parts = append(parts, GeminiPart{
Text: promptText,
})
reqBody := GeminiImageRequest{
Contents: []struct {
Parts []GeminiPart `json:"parts"`
}{
{
Parts: parts,
},
},
GenerationConfig: struct {
ResponseModalities []string `json:"responseModalities"`
}{
ResponseModalities: []string{"IMAGE"},
},
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + c.Endpoint
endpoint = replaceModelPlaceholder(endpoint, model)
url := fmt.Sprintf("%s?key=%s", endpoint, c.APIKey)
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
bodyStr := string(body)
if len(bodyStr) > 1000 {
bodyStr = fmt.Sprintf("%s ... %s", bodyStr[:500], bodyStr[len(bodyStr)-500:])
}
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, bodyStr)
}
var result GeminiImageResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if len(result.Candidates) == 0 || len(result.Candidates[0].Content.Parts) == 0 {
return nil, fmt.Errorf("no image generated in response")
}
base64Data := result.Candidates[0].Content.Parts[0].InlineData.Data
if base64Data == "" {
return nil, fmt.Errorf("no base64 image data in response")
}
dataURI := fmt.Sprintf("data:image/jpeg;base64,%s", base64Data)
return &ImageResult{
Status: "completed",
ImageURL: dataURI,
Completed: true,
Width: 1024,
Height: 1024,
}, nil
}
func (c *GeminiImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
return nil, fmt.Errorf("not supported for Gemini (synchronous generation)")
}
func replaceModelPlaceholder(endpoint, model string) string {
result := endpoint
if bytes.Contains([]byte(result), []byte("{model}")) {
result = string(bytes.ReplaceAll([]byte(result), []byte("{model}"), []byte(model)))
}
return result
}
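Because GeminiImageClient returns the image as a data URI rather than an HTTP URL, callers that want a file need to decode the base64 payload themselves. A small sketch, again assuming the pkg/image import path under the module name seen elsewhere in this commit:

package main

import (
    "encoding/base64"
    "fmt"
    "os"
    "strings"

    "github.com/drama-generator/backend/pkg/image" // assumed import path
)

func main() {
    // Empty baseURL/model/endpoint fall back to the defaults in NewGeminiImageClient.
    client := image.NewGeminiImageClient("", "YOUR_API_KEY", "", "")

    result, err := client.GenerateImage("一只在竹林里喝茶的熊猫,水墨画风格")
    if err != nil {
        fmt.Println("generate failed:", err)
        return
    }

    // result.ImageURL looks like "data:image/jpeg;base64,...": strip the prefix and decode.
    if idx := strings.Index(result.ImageURL, ";base64,"); idx >= 0 {
        raw, decErr := base64.StdEncoding.DecodeString(result.ImageURL[idx+len(";base64,"):])
        if decErr == nil {
            _ = os.WriteFile("panda.jpg", raw, 0o644)
        }
    }
}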


@@ -1,14 +1,5 @@
package image

-import (
-   "bytes"
-   "encoding/json"
-   "fmt"
-   "io"
-   "net/http"
-   "time"
-)

type ImageClient interface {
    GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error)
    GetTaskStatus(taskID string) (*ImageResult, error)
@@ -100,285 +91,3 @@ func WithReferenceImages(images []string) ImageOption {
        o.ReferenceImages = images
    }
}
type OpenAIImageClient struct {
BaseURL string
APIKey string
Model string
HTTPClient *http.Client
}
type DALLERequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
N int `json:"n"`
Image []string `json:"image,omitempty"` // 参考图片URL列表
}
type DALLEResponse struct {
Created int64 `json:"created"`
Data []struct {
URL string `json:"url"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
} `json:"data"`
}
func NewOpenAIImageClient(baseURL, apiKey, model string) *OpenAIImageClient {
return &OpenAIImageClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
HTTPClient: &http.Client{
Timeout: 10 * time.Minute,
},
}
}
func (c *OpenAIImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
options := &ImageOptions{
Size: "1920x1920",
Quality: "standard",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
reqBody := DALLERequest{
Model: model,
Prompt: prompt,
Size: options.Size,
Quality: options.Quality,
N: 1,
Image: options.ReferenceImages,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + "/v1/images/generations"
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
// 打印原始响应以便调试
fmt.Printf("OpenAI API Response: %s\n", string(body))
var result DALLEResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body))
}
if len(result.Data) == 0 {
return nil, fmt.Errorf("no image generated, response: %s", string(body))
}
return &ImageResult{
Status: "completed",
ImageURL: result.Data[0].URL,
Completed: true,
}, nil
}
func (c *OpenAIImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
return nil, fmt.Errorf("not supported for OpenAI/DALL-E")
}
type StableDiffusionClient struct {
BaseURL string
APIKey string
Model string
HTTPClient *http.Client
}
type SDRequest struct {
Prompt string `json:"prompt"`
NegativePrompt string `json:"negative_prompt,omitempty"`
Model string `json:"model,omitempty"`
Width int `json:"width,omitempty"`
Height int `json:"height,omitempty"`
Steps int `json:"steps,omitempty"`
CfgScale float64 `json:"cfg_scale,omitempty"`
Seed int64 `json:"seed,omitempty"`
Samples int `json:"samples"`
Image []string `json:"image,omitempty"` // 参考图片URL列表
}
type SDResponse struct {
Status string `json:"status"`
TaskID string `json:"task_id,omitempty"`
Output []struct {
URL string `json:"url"`
} `json:"output,omitempty"`
Error string `json:"error,omitempty"`
}
func NewStableDiffusionClient(baseURL, apiKey, model string) *StableDiffusionClient {
return &StableDiffusionClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
HTTPClient: &http.Client{
Timeout: 10 * time.Minute,
},
}
}
func (c *StableDiffusionClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
options := &ImageOptions{
Width: 1024,
Height: 1024,
Steps: 30,
CfgScale: 7.5,
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
reqBody := SDRequest{
Prompt: prompt,
NegativePrompt: options.NegativePrompt,
Model: model,
Width: options.Width,
Height: options.Height,
Steps: options.Steps,
CfgScale: options.CfgScale,
Seed: options.Seed,
Samples: 1,
Image: options.ReferenceImages,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + "/v1/images/generations"
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result SDResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if result.Error != "" {
return nil, fmt.Errorf("SD error: %s", result.Error)
}
if result.Status == "processing" {
return &ImageResult{
TaskID: result.TaskID,
Status: "processing",
Completed: false,
}, nil
}
if len(result.Output) == 0 {
return nil, fmt.Errorf("no image generated")
}
return &ImageResult{
Status: "completed",
ImageURL: result.Output[0].URL,
Width: options.Width,
Height: options.Height,
Completed: true,
}, nil
}
func (c *StableDiffusionClient) GetTaskStatus(taskID string) (*ImageResult, error) {
endpoint := c.BaseURL + "/v1/images/status/" + taskID
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
var result SDResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
imageResult := &ImageResult{
TaskID: taskID,
Status: result.Status,
Completed: result.Status == "completed",
}
if result.Error != "" {
imageResult.Error = result.Error
}
if len(result.Output) > 0 {
imageResult.ImageURL = result.Output[0].URL
}
return imageResult, nil
}


@@ -0,0 +1,128 @@
package image
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
type OpenAIImageClient struct {
BaseURL string
APIKey string
Model string
Endpoint string
HTTPClient *http.Client
}
type DALLERequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
N int `json:"n"`
Image []string `json:"image,omitempty"`
}
type DALLEResponse struct {
Created int64 `json:"created"`
Data []struct {
URL string `json:"url"`
RevisedPrompt string `json:"revised_prompt,omitempty"`
} `json:"data"`
}
func NewOpenAIImageClient(baseURL, apiKey, model, endpoint string) *OpenAIImageClient {
if endpoint == "" {
endpoint = "/v1/images/generations"
}
return &OpenAIImageClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
Endpoint: endpoint,
HTTPClient: &http.Client{
Timeout: 10 * time.Minute,
},
}
}
func (c *OpenAIImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
options := &ImageOptions{
Size: "1920x1920",
Quality: "standard",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
reqBody := DALLERequest{
Model: model,
Prompt: prompt,
Size: options.Size,
Quality: options.Quality,
N: 1,
Image: options.ReferenceImages,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
url := c.BaseURL + c.Endpoint
fmt.Printf("[OpenAI Image] Request URL: %s\n", url)
fmt.Printf("[OpenAI Image] Request Body: %s\n", string(jsonData))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
fmt.Printf("OpenAI API Response: %s\n", string(body))
var result DALLEResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body))
}
if len(result.Data) == 0 {
return nil, fmt.Errorf("no image generated, response: %s", string(body))
}
return &ImageResult{
Status: "completed",
ImageURL: result.Data[0].URL,
Completed: true,
}, nil
}
func (c *OpenAIImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
return nil, fmt.Errorf("not supported for OpenAI/DALL-E")
}


@@ -0,0 +1,158 @@
package image
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"time"
)
type VolcEngineImageClient struct {
BaseURL string
APIKey string
Model string
Endpoint string
QueryEndpoint string
HTTPClient *http.Client
}
type VolcEngineImageRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
Image []string `json:"image,omitempty"`
SequentialImageGeneration string `json:"sequential_image_generation,omitempty"`
Size string `json:"size,omitempty"`
Watermark bool `json:"watermark,omitempty"`
}
type VolcEngineImageResponse struct {
Model string `json:"model"`
Created int64 `json:"created"`
Data []struct {
URL string `json:"url"`
Size string `json:"size"`
} `json:"data"`
Usage struct {
GeneratedImages int `json:"generated_images"`
OutputTokens int `json:"output_tokens"`
TotalTokens int `json:"total_tokens"`
} `json:"usage"`
Error interface{} `json:"error,omitempty"`
}
func NewVolcEngineImageClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *VolcEngineImageClient {
if endpoint == "" {
endpoint = "/api/v3/images/generations"
}
if queryEndpoint == "" {
queryEndpoint = endpoint
}
return &VolcEngineImageClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
Endpoint: endpoint,
QueryEndpoint: queryEndpoint,
HTTPClient: &http.Client{
Timeout: 10 * time.Minute,
},
}
}
func (c *VolcEngineImageClient) GenerateImage(prompt string, opts ...ImageOption) (*ImageResult, error) {
options := &ImageOptions{
Size: "1024x1024",
Quality: "standard",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
promptText := prompt
if options.NegativePrompt != "" {
promptText += fmt.Sprintf(". Negative: %s", options.NegativePrompt)
}
size := options.Size
if size == "" {
if model == "doubao-seedream-4-5-251128" {
size = "2K"
} else {
size = "1K"
}
}
reqBody := VolcEngineImageRequest{
Model: model,
Prompt: promptText,
Image: options.ReferenceImages,
SequentialImageGeneration: "disabled",
Size: size,
Watermark: false,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
url := c.BaseURL + c.Endpoint
fmt.Printf("[VolcEngine Image] Request URL: %s\n", url)
fmt.Printf("[VolcEngine Image] Request Body: %s\n", string(jsonData))
req, err := http.NewRequest("POST", url, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
fmt.Printf("VolcEngine Image API Response: %s\n", string(body))
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result VolcEngineImageResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if result.Error != nil {
return nil, fmt.Errorf("volcengine error: %v", result.Error)
}
if len(result.Data) == 0 {
return nil, fmt.Errorf("no image generated")
}
return &ImageResult{
Status: "completed",
ImageURL: result.Data[0].URL,
Completed: true,
}, nil
}
func (c *VolcEngineImageClient) GetTaskStatus(taskID string) (*ImageResult, error) {
return nil, fmt.Errorf("not supported for VolcEngine Seedream (synchronous generation)")
}
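
A similar hedged sketch for the VolcEngine client; the constructor is the one defined above, while the base URL and API-key environment variable are assumptions to be verified against the actual VolcEngine Ark deployment.

// Hypothetical example; imports assumed: fmt, log, os.
func exampleVolcEngineImage() {
	// Empty endpoint arguments fall back to /api/v3/images/generations.
	client := NewVolcEngineImageClient(
		"https://ark.cn-beijing.volces.com", // assumed Ark base URL
		os.Getenv("ARK_API_KEY"),
		"doubao-seedream-4-5-251128",
		"", "",
	)
	// With no size option, seedream-4-5 falls back to "2K" and other models to "1K".
	result, err := client.GenerateImage("a minimalist poster of a mountain range at dawn")
	if err != nil {
		log.Fatalf("generate image: %v", err)
	}
	fmt.Println("image URL:", result.ImageURL)
}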

View File

@@ -0,0 +1,184 @@
package video
import (
"bytes"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"time"
)
// ChatfireClient is the Chatfire video generation client.
type ChatfireClient struct {
BaseURL string
APIKey string
Model string
Endpoint string
QueryEndpoint string
HTTPClient *http.Client
}
type ChatfireRequest struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
ImageURL string `json:"image_url,omitempty"`
Duration int `json:"duration,omitempty"`
Size string `json:"size,omitempty"`
}
type ChatfireResponse struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
Error string `json:"error,omitempty"`
}
type ChatfireTaskResponse struct {
TaskID string `json:"task_id"`
Status string `json:"status"`
VideoURL string `json:"video_url,omitempty"`
Error string `json:"error,omitempty"`
}
func NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *ChatfireClient {
if endpoint == "" {
endpoint = "/video/generations"
}
if queryEndpoint == "" {
queryEndpoint = "/v1/video/task/{taskId}"
}
return &ChatfireClient{
BaseURL: baseURL,
APIKey: apiKey,
Model: model,
Endpoint: endpoint,
QueryEndpoint: queryEndpoint,
HTTPClient: &http.Client{
Timeout: 300 * time.Second,
},
}
}
// GenerateVideo submits an asynchronous video generation task and returns its task ID and initial status.
func (c *ChatfireClient) GenerateVideo(imageURL, prompt string, opts ...VideoOption) (*VideoResult, error) {
options := &VideoOptions{
Duration: 5,
AspectRatio: "16:9",
}
for _, opt := range opts {
opt(options)
}
model := c.Model
if options.Model != "" {
model = options.Model
}
reqBody := ChatfireRequest{
Model: model,
Prompt: prompt,
ImageURL: imageURL,
Duration: options.Duration,
Size: options.AspectRatio,
}
jsonData, err := json.Marshal(reqBody)
if err != nil {
return nil, fmt.Errorf("marshal request: %w", err)
}
endpoint := c.BaseURL + c.Endpoint
req, err := http.NewRequest("POST", endpoint, bytes.NewBuffer(jsonData))
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body))
}
var result ChatfireResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
if result.Error != "" {
return nil, fmt.Errorf("chatfire error: %s", result.Error)
}
videoResult := &VideoResult{
TaskID: result.TaskID,
Status: result.Status,
Completed: result.Status == "completed" || result.Status == "succeeded",
Duration: options.Duration,
}
return videoResult, nil
}
// GetTaskStatus polls the query endpoint for an asynchronous task, substituting the task ID into the {taskId} or {task_id} placeholder.
func (c *ChatfireClient) GetTaskStatus(taskID string) (*VideoResult, error) {
queryPath := c.QueryEndpoint
if strings.Contains(queryPath, "{taskId}") {
queryPath = strings.ReplaceAll(queryPath, "{taskId}", taskID)
} else if strings.Contains(queryPath, "{task_id}") {
queryPath = strings.ReplaceAll(queryPath, "{task_id}", taskID)
} else {
queryPath = queryPath + "/" + taskID
}
endpoint := c.BaseURL + queryPath
req, err := http.NewRequest("GET", endpoint, nil)
if err != nil {
return nil, fmt.Errorf("create request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+c.APIKey)
resp, err := c.HTTPClient.Do(req)
if err != nil {
return nil, fmt.Errorf("send request: %w", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("read response: %w", err)
}
var result ChatfireTaskResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, fmt.Errorf("parse response: %w", err)
}
videoResult := &VideoResult{
TaskID: result.TaskID,
Status: result.Status,
Completed: result.Status == "completed" || result.Status == "succeeded",
}
if result.Error != "" {
videoResult.Error = result.Error
}
if result.VideoURL != "" {
videoResult.VideoURL = result.VideoURL
videoResult.Completed = true
}
return videoResult, nil
}
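
Because Chatfire video generation is asynchronous, the intended flow is submit-then-poll. The sketch below illustrates it; the base URL, model, polling interval, and the inline VideoOption literal (assuming VideoOption is a func(*VideoOptions)) are illustrative assumptions rather than part of the committed code.

// Hypothetical example; imports assumed: fmt, log, time.
func exampleChatfireVideo() {
	// Empty endpoint arguments fall back to the defaults set in NewChatfireClient.
	client := NewChatfireClient("https://api.chatfire.site/v1", "sk-...", "sora", "", "")
	task, err := client.GenerateVideo(
		"https://example.com/first-frame.jpg", // optional reference image
		"a slow dolly shot across a rainy street at night",
		func(o *VideoOptions) { o.Duration = 5 },
	)
	if err != nil {
		log.Fatalf("generate video: %v", err)
	}
	// Poll until the task reports completion; {taskId} in QueryEndpoint is replaced with the task ID.
	for !task.Completed {
		time.Sleep(10 * time.Second)
		task, err = client.GetTaskStatus(task.TaskID)
		if err != nil {
			log.Fatalf("query task: %v", err)
		}
		if task.Error != "" {
			log.Fatalf("task failed: %s", task.Error)
		}
	}
	fmt.Println("video URL:", task.VideoURL)
}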

View File

@@ -83,7 +83,7 @@ func (c *OpenAISoraClient) GenerateVideo(imageURL, prompt string, opts ...VideoO
 	writer.Close()
-	endpoint := c.BaseURL + "/v1/videos"
+	endpoint := c.BaseURL + "/videos"
 	req, err := http.NewRequest("POST", endpoint, body)
 	if err != nil {
 		return nil, fmt.Errorf("create request: %w", err)
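
This path change pairs with base URLs that already end in /v1 (the frontend form later in this commit defaults the Base URL to https://api.chatfire.site/v1), so keeping the old "/v1/videos" suffix would have duplicated the prefix. A trivial illustration of how the two pieces compose, purely as a sketch:

// Illustrative only: composition of BaseURL and the endpoint path after this change.
base := "https://api.chatfire.site/v1"
fmt.Println(base + "/videos")    // https://api.chatfire.site/v1/videos
fmt.Println(base + "/v1/videos") // https://api.chatfire.site/v1/v1/videos (old behaviour, duplicated /v1)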

View File

@@ -9,7 +9,7 @@
         </template>
       </el-page-header>
-      <el-tabs v-model="activeTab" @tab-change="loadConfigs">
+      <el-tabs v-model="activeTab" @tab-change="handleTabChange">
         <el-tab-pane label="文本生成" name="text">
           <ConfigList
             :configs="configs"
@@ -114,7 +114,11 @@
<el-form-item label="Base URL" prop="base_url"> <el-form-item label="Base URL" prop="base_url">
<el-input v-model="form.base_url" placeholder="https://api.openai.com" /> <el-input v-model="form.base_url" placeholder="https://api.openai.com" />
<div class="form-tip">API 服务的基础地址</div> <div class="form-tip">
API 服务的基础地址 Chatfire: https://api.chatfire.site/v1Gemini: https://generativelanguage.googleapis.com无需 /v1
<br>
完整调用路径: {{ fullEndpointExample }}
</div>
</el-form-item> </el-form-item>
<el-form-item label="API Key" prop="api_key"> <el-form-item label="API Key" prop="api_key">
@@ -127,16 +131,6 @@
<div class="form-tip">您的 API 密钥</div> <div class="form-tip">您的 API 密钥</div>
</el-form-item> </el-form-item>
<el-form-item label="端点路径" prop="endpoint">
<el-input v-model="form.endpoint" placeholder="/v1/chat/completions" />
<div class="form-tip">API 端点路径默认为 /v1/chat/completions</div>
</el-form-item>
<el-form-item v-if="form.service_type === 'video'" label="查询端点" prop="query_endpoint">
<el-input v-model="form.query_endpoint" placeholder="/v1/video/task/{taskId}" />
<div class="form-tip">异步任务查询端点仅视频服务需要支持 {taskId} 占位符</div>
</el-form-item>
<el-form-item v-if="isEdit" label="启用状态"> <el-form-item v-if="isEdit" label="启用状态">
<el-switch v-model="form.is_active" /> <el-switch v-model="form.is_active" />
</el-form-item> </el-form-item>
@@ -181,8 +175,6 @@ const form = reactive<CreateAIConfigRequest & { is_active?: boolean, provider?:
   base_url: '',
   api_key: '',
   model: [], // changed to an array to support multi-select
-  endpoint: '/v1/chat/completions',
-  query_endpoint: '', // async task query endpoint
   priority: 0, // default priority is 0
   is_active: true
 })
@@ -197,10 +189,54 @@ interface ProviderConfig {
 const providerConfigs: Record<AIServiceType, ProviderConfig[]> = {
   text: [
-    { id: 'openai', name: 'OpenAI', models: ['gpt-5.2', 'gemini-3-pro-preview'], disabled: true }
+    { id: 'openai', name: 'OpenAI', models: ['gpt-5.2', 'gemini-3-pro-preview'] },
+    {
+      id: 'chatfire',
+      name: 'Chatfire',
+      models: [
+        'gpt-4o',
+        'claude-sonnet-4-5-20250929',
+        'doubao-seed-1-8-251228',
+        'kimi-k2-thinking',
+        'gemini-3-pro',
+        'gemini-2.5-pro',
+        'gemini-3-pro-preview'
+      ]
+    },
+    {
+      id: 'gemini',
+      name: 'Google Gemini',
+      models: [
+        'gemini-2.5-pro',
+        'gemini-3-pro-preview'
+      ]
+    }
   ],
   image: [
-    { id: 'openai', name: 'OpenAI', models: ['nano-banana-pro', 'doubao-seedream-4-5-251128'] }
+    {
+      id: 'volcengine',
+      name: '火山引擎',
+      models: [
+        'doubao-seedream-4-5-251128',
+        'doubao-seedream-4-0-250828',
+      ]
+    },
+    {
+      id: 'chatfire',
+      name: 'Chatfire',
+      models: [
+        'doubao-seedream-4-5-251128',
+        'nano-banana-pro',
+      ]
+    },
+    {
+      id: 'gemini',
+      name: 'Google Gemini',
+      models: [
+        'gemini-3-pro-image-preview',
+      ]
+    },
+    { id: 'openai', name: 'OpenAI', models: ['dall-e-3', 'dall-e-2'] }
   ],
   video: [
     {
@@ -214,6 +250,19 @@ const providerConfigs: Record<AIServiceType, ProviderConfig[]> = {
         'doubao-seedance-1-0-pro-fast-251015'
       ]
     },
+    {
+      id: 'chatfire',
+      name: 'Chatfire',
+      models: [
+        'doubao-seedance-1-5-pro-251215',
+        'doubao-seedance-1-0-lite-i2v-250428',
+        'doubao-seedance-1-0-lite-t2v-250428',
+        'doubao-seedance-1-0-pro-250528',
+        'doubao-seedance-1-0-pro-fast-251015',
+        'sora',
+        'sora-pro'
+      ]
+    },
     { id: 'openai', name: 'OpenAI', models: ['sora-2', 'sora-2-pro'] },
     // { id: 'minimax', name: 'MiniMax', models: ['MiniMax-Hailuo-2.3', 'MiniMax-Hailuo-2.3-Fast', 'MiniMax-Hailuo-02'] }
   ]
@@ -231,6 +280,41 @@ const availableModels = computed(() => {
   return provider?.models || []
 })
+
+// Full endpoint example shown below the Base URL input
+const fullEndpointExample = computed(() => {
+  const baseUrl = form.base_url || 'https://api.example.com'
+  const provider = form.provider
+  const serviceType = form.service_type
+
+  let endpoint = ''
+  if (serviceType === 'text') {
+    if (provider === 'gemini' || provider === 'google') {
+      endpoint = '/v1beta/models/{model}:generateContent'
+    } else {
+      endpoint = '/chat/completions'
+    }
+  } else if (serviceType === 'image') {
+    if (provider === 'gemini' || provider === 'google') {
+      endpoint = '/v1beta/models/{model}:generateContent'
+    } else {
+      endpoint = '/images/generations'
+    }
+  } else if (serviceType === 'video') {
+    if (provider === 'chatfire') {
+      endpoint = '/video/generations'
+    } else if (provider === 'doubao' || provider === 'volcengine' || provider === 'volces') {
+      endpoint = '/contents/generations/tasks'
+    } else if (provider === 'openai') {
+      endpoint = '/videos'
+    } else {
+      endpoint = '/video/generations'
+    }
+  }
+
+  return baseUrl + endpoint
+})
+
 const rules: FormRules = {
   name: [
     { required: true, message: '请输入配置名称', trigger: 'blur' }
@@ -274,17 +358,39 @@ const loadConfigs = async () => {
   }
 }
+
+// Generate a random config name
+const generateConfigName = (provider: string, serviceType: AIServiceType): string => {
+  const providerNames: Record<string, string> = {
+    'chatfire': 'ChatFire',
+    'openai': 'OpenAI',
+    'gemini': 'Gemini',
+    'google': 'Google'
+  }
+  const serviceNames: Record<AIServiceType, string> = {
+    'text': '文本',
+    'image': '图片',
+    'video': '视频'
+  }
+  const randomNum = Math.floor(Math.random() * 10000).toString().padStart(4, '0')
+  const providerName = providerNames[provider] || provider
+  const serviceName = serviceNames[serviceType] || serviceType
+  return `${providerName}-${serviceName}-${randomNum}`
+}
+
 const showCreateDialog = () => {
   isEdit.value = false
   editingId.value = undefined
   resetForm()
   form.service_type = activeTab.value
-  // set the default endpoint path based on the service type
-  form.endpoint = getDefaultEndpoint(activeTab.value)
-  // text generation defaults to openai
-  if (activeTab.value === 'text') {
-    form.provider = 'openai'
-  }
+  // default to the chatfire provider
+  form.provider = 'chatfire'
+  // set the default base_url
+  form.base_url = 'https://api.chatfire.site/v1'
+  // auto-generate a random config name
+  form.name = generateConfigName('chatfire', activeTab.value)
   dialogVisible.value = true
 }
@@ -292,26 +398,13 @@ const handleEdit = (config: AIServiceConfig) => {
   isEdit.value = true
   editingId.value = config.id
-
-  // infer the provider from the model name
-  const inferProvider = (model: string, serviceType: AIServiceType): string => {
-    const providers = providerConfigs[serviceType]
-    for (const provider of providers) {
-      if (provider.models.includes(model)) {
-        return provider.id
-      }
-    }
-    return providers[0]?.id || ''
-  }
-
   Object.assign(form, {
     service_type: config.service_type,
-    provider: inferProvider(Array.isArray(config.model) ? config.model[0] : config.model, config.service_type),
+    provider: config.provider || 'chatfire', // use the provider stored in the config, defaulting to chatfire
     name: config.name,
     base_url: config.base_url,
     api_key: config.api_key,
     model: Array.isArray(config.model) ? config.model : [config.model], // normalize to an array
-    endpoint: config.endpoint,
-    query_endpoint: config.query_endpoint || '',
     priority: config.priority || 0,
     is_active: config.is_active
   })
@@ -359,7 +452,7 @@ const testConnection = async () => {
       base_url: form.base_url,
       api_key: form.api_key,
       model: form.model,
-      endpoint: form.endpoint
+      provider: form.provider
     })
     ElMessage.success('连接测试成功!')
   } catch (error: any) {
@@ -376,7 +469,7 @@ const handleTest = async (config: AIServiceConfig) => {
       base_url: config.base_url,
       api_key: config.api_key,
       model: config.model,
-      endpoint: config.endpoint
+      provider: config.provider
     })
     ElMessage.success('连接测试成功!')
   } catch (error: any) {
@@ -397,11 +490,10 @@ const handleSubmit = async () => {
     if (isEdit.value && editingId.value) {
       const updateData: UpdateAIConfigRequest = {
         name: form.name,
+        provider: form.provider,
         base_url: form.base_url,
         api_key: form.api_key,
         model: form.model,
-        endpoint: form.endpoint,
-        query_endpoint: form.query_endpoint,
         priority: form.priority,
         is_active: form.is_active
       }
@@ -422,16 +514,36 @@ const handleSubmit = async () => {
   })
 }
+
+const handleTabChange = (tabName: string | number) => {
+  // reload the configs for the selected service type when switching tabs
+  activeTab.value = tabName as AIServiceType
+  loadConfigs()
+}
+
 const handleProviderChange = () => {
   // clear the selected models when switching providers
   form.model = []
+  // set the default base_url according to the provider
+  if (form.provider === 'gemini' || form.provider === 'google') {
+    form.base_url = 'https://api.chatfire.site'
+  } else {
+    // openai, chatfire and other providers
+    form.base_url = 'https://api.chatfire.site/v1'
+  }
+  // only auto-update the name when creating a new config
+  if (!isEdit.value) {
+    form.name = generateConfigName(form.provider, form.service_type)
+  }
 }

-// get the default endpoint path based on the service type
+// the endpoint form field has been removed; the backend now sets the endpoint based on the provider
+// the function definition is kept to avoid compile errors
 const getDefaultEndpoint = (serviceType: AIServiceType): string => {
   switch (serviceType) {
     case 'text':
-      return '/v1/chat/completions'
+      return ''
     case 'image':
       return '/v1/images/generations'
     case 'video':
@@ -450,8 +562,6 @@ const resetForm = () => {
     base_url: '',
     api_key: '',
     model: [], // changed to an empty array
-    endpoint: getDefaultEndpoint(serviceType),
-    query_endpoint: '',
     priority: 0,
     is_active: true
   })
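
The removed endpoint fields above rely on the backend deriving the call path from provider and service_type, mirroring the fullEndpointExample mapping added earlier in this file. As a non-authoritative sketch only (function name, package, and placement are assumptions, not the committed backend code), that mapping could look like this on the Go side:

// Hypothetical sketch of a provider/service_type -> endpoint mapping; not the committed implementation.
func defaultEndpointFor(provider, serviceType string) string {
	switch serviceType {
	case "text", "image":
		if provider == "gemini" || provider == "google" {
			return "/v1beta/models/{model}:generateContent"
		}
		if serviceType == "text" {
			return "/chat/completions"
		}
		return "/images/generations"
	case "video":
		switch provider {
		case "doubao", "volcengine", "volces":
			return "/contents/generations/tasks"
		case "openai":
			return "/videos"
		default: // chatfire and others
			return "/video/generations"
		}
	}
	return ""
}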