添加chat gemini、chatfire端点、 图片生成 gemini、chatfire 更轻松的AI配置

This commit is contained in:
Connor
2026-01-14 02:25:41 +08:00
parent 4d38357ff6
commit 23b45efae9
22 changed files with 1512 additions and 405 deletions

View File

@@ -25,6 +25,7 @@ func NewAIService(db *gorm.DB, log *logger.Logger) *AIService {
type CreateAIConfigRequest struct {
ServiceType string `json:"service_type" binding:"required,oneof=text image video"`
Name string `json:"name" binding:"required,min=1,max=100"`
Provider string `json:"provider" binding:"required"`
BaseURL string `json:"base_url" binding:"required,url"`
APIKey string `json:"api_key" binding:"required"`
Model models.ModelField `json:"model" binding:"required"`
@@ -37,6 +38,7 @@ type CreateAIConfigRequest struct {
type UpdateAIConfigRequest struct {
Name string `json:"name" binding:"omitempty,min=1,max=100"`
Provider string `json:"provider"`
BaseURL string `json:"base_url" binding:"omitempty,url"`
APIKey string `json:"api_key"`
Model *models.ModelField `json:"model"`
@@ -52,18 +54,53 @@ type TestConnectionRequest struct {
BaseURL string `json:"base_url" binding:"required,url"`
APIKey string `json:"api_key" binding:"required"`
Model models.ModelField `json:"model" binding:"required"`
Provider string `json:"provider"`
Endpoint string `json:"endpoint"`
}
func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) {
// 根据 provider 和 service_type 自动设置 endpoint
endpoint := req.Endpoint
queryEndpoint := req.QueryEndpoint
if endpoint == "" {
switch req.Provider {
case "gemini", "google":
if req.ServiceType == "text" {
endpoint = "/v1beta/models/{model}:generateContent"
} else if req.ServiceType == "image" {
endpoint = "/v1beta/models/{model}:generateContent"
}
case "openai", "chatfire":
if req.ServiceType == "text" {
endpoint = "/chat/completions"
} else if req.ServiceType == "image" {
endpoint = "/images/generations"
} else if req.ServiceType == "video" {
endpoint = "/video/generations"
if queryEndpoint == "" {
queryEndpoint = "/v1/video/task/{taskId}"
}
}
default:
// 默认使用 OpenAI 格式
if req.ServiceType == "text" {
endpoint = "/chat/completions"
} else if req.ServiceType == "image" {
endpoint = "/images/generations"
}
}
}
config := &models.AIServiceConfig{
ServiceType: req.ServiceType,
Name: req.Name,
Provider: req.Provider,
BaseURL: req.BaseURL,
APIKey: req.APIKey,
Model: req.Model,
Endpoint: req.Endpoint,
QueryEndpoint: req.QueryEndpoint,
Endpoint: endpoint,
QueryEndpoint: queryEndpoint,
Priority: req.Priority,
IsDefault: req.IsDefault,
IsActive: true,
@@ -75,7 +112,7 @@ func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceC
return nil, err
}
s.log.Infow("AI config created", "config_id", config.ID)
s.log.Infow("AI config created", "config_id", config.ID, "provider", req.Provider, "endpoint", endpoint)
return config, nil
}
@@ -125,6 +162,9 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo
if req.Name != "" {
updates["name"] = req.Name
}
if req.Provider != "" {
updates["provider"] = req.Provider
}
if req.BaseURL != "" {
updates["base_url"] = req.BaseURL
}
@@ -137,9 +177,30 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo
if req.Priority != nil {
updates["priority"] = *req.Priority
}
if req.Endpoint != "" {
// 如果提供了 provider根据 provider 和 service_type 自动设置 endpoint
if req.Provider != "" && req.Endpoint == "" {
provider := req.Provider
serviceType := config.ServiceType
switch provider {
case "gemini", "google":
if serviceType == "text" || serviceType == "image" {
updates["endpoint"] = "/v1beta/models/{model}:generateContent"
}
case "openai", "chatfire":
if serviceType == "text" {
updates["endpoint"] = "/chat/completions"
} else if serviceType == "image" {
updates["endpoint"] = "/images/generations"
} else if serviceType == "video" {
updates["endpoint"] = "/video/generations"
}
}
} else if req.Endpoint != "" {
updates["endpoint"] = req.Endpoint
}
// 允许清空query_endpoint所以不检查是否为空
updates["query_endpoint"] = req.QueryEndpoint
if req.Settings != "" {
@@ -179,13 +240,51 @@ func (s *AIService) DeleteConfig(configID uint) error {
}
// TestConnection verifies that an AI provider endpoint is reachable with the
// supplied credentials. It uses the first configured model for the probe,
// selects a client implementation based on req.Provider (Gemini vs.
// OpenAI-compatible), and delegates to that client's TestConnection call.
// NOTE(review): the rendered diff had left the pre-change client creation and
// early return interleaved here (redeclared `client`, unreachable code); this
// is the reconstructed post-change body.
func (s *AIService) TestConnection(req *TestConnectionRequest) error {
	s.log.Infow("TestConnection called", "baseURL", req.BaseURL, "provider", req.Provider, "endpoint", req.Endpoint, "modelCount", len(req.Model))
	// 使用第一个模型进行测试
	model := ""
	if len(req.Model) > 0 {
		model = req.Model[0]
	}
	s.log.Infow("Using model for test", "model", model, "provider", req.Provider)
	// 根据 provider 参数选择客户端
	var client ai.AIClient
	var endpoint string
	switch req.Provider {
	case "gemini", "google":
		// Gemini
		s.log.Infow("Using Gemini client", "baseURL", req.BaseURL)
		endpoint = "/v1beta/models/{model}:generateContent"
		client = ai.NewGeminiClient(req.BaseURL, req.APIKey, model, endpoint)
	case "openai", "chatfire":
		// OpenAI 格式(包括 chatfire 等)
		s.log.Infow("Using OpenAI-compatible client", "baseURL", req.BaseURL, "provider", req.Provider)
		endpoint = req.Endpoint
		if endpoint == "" {
			endpoint = "/chat/completions"
		}
		client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
	default:
		// 默认使用 OpenAI 格式
		s.log.Infow("Using default OpenAI-compatible client", "baseURL", req.BaseURL)
		endpoint = req.Endpoint
		if endpoint == "" {
			endpoint = "/chat/completions"
		}
		client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
	}
	s.log.Infow("Calling TestConnection on client", "endpoint", endpoint)
	err := client.TestConnection()
	if err != nil {
		s.log.Errorw("TestConnection failed", "error", err)
	} else {
		s.log.Infow("TestConnection succeeded")
	}
	return err
}
func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) {
@@ -228,7 +327,7 @@ func (s *AIService) GetConfigForModel(serviceType string, modelName string) (*mo
return nil, errors.New("no config found for model: " + modelName)
}
func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) {
func (s *AIService) GetAIClient(serviceType string) (ai.AIClient, error) {
config, err := s.GetDefaultConfig(serviceType)
if err != nil {
return nil, err
@@ -240,7 +339,25 @@ func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) {
model = config.Model[0]
}
return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, config.Endpoint), nil
// 使用数据库配置中的 endpoint如果为空则根据 provider 设置默认值
endpoint := config.Endpoint
if endpoint == "" {
switch config.Provider {
case "gemini", "google":
endpoint = "/v1beta/models/{model}:generateContent"
default:
endpoint = "/chat/completions"
}
}
// 根据 provider 创建对应的客户端
switch config.Provider {
case "gemini", "google":
return ai.NewGeminiClient(config.BaseURL, config.APIKey, model, endpoint), nil
default:
// openai, chatfire 等其他厂商都使用 OpenAI 格式
return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, endpoint), nil
}
}
func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) {

View File

@@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
models "github.com/drama-generator/backend/domain/models"
@@ -23,6 +24,24 @@ type ImageGenerationService struct {
log *logger.Logger
}
// truncateImageURL 截断图片 URL避免 base64 格式的 URL 占满日志
func truncateImageURL(url string) string {
if url == "" {
return ""
}
// 如果是 data URI 格式base64只显示前缀
if strings.HasPrefix(url, "data:") {
if len(url) > 50 {
return url[:50] + "...[base64 data]"
}
}
// 普通 URL 如果过长也截断
if len(url) > 100 {
return url[:100] + "..."
}
return url
}
func NewImageGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, log *logger.Logger) *ImageGenerationService {
return &ImageGenerationService{
db: db,
@@ -246,17 +265,23 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
now := time.Now()
// 下载图片到本地存储(仅用于缓存,不更新数据库)
if s.localStorage != nil && result.ImageURL != "" {
// 仅下载 HTTP/HTTPS URL跳过 data URI
if s.localStorage != nil && result.ImageURL != "" &&
(strings.HasPrefix(result.ImageURL, "http://") || strings.HasPrefix(result.ImageURL, "https://")) {
_, err := s.localStorage.DownloadFromURL(result.ImageURL, "images")
if err != nil {
errStr := err.Error()
if len(errStr) > 200 {
errStr = errStr[:200] + "..."
}
s.log.Warnw("Failed to download image to local storage",
"error", err,
"error", errStr,
"id", imageGenID,
"original_url", result.ImageURL)
"original_url", truncateImageURL(result.ImageURL))
} else {
s.log.Infow("Image downloaded to local storage for caching",
"id", imageGenID,
"original_url", result.ImageURL)
"original_url", truncateImageURL(result.ImageURL))
}
}
@@ -291,7 +316,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
} else {
s.log.Infow("Storyboard updated with composed image",
"storyboard_id", *imageGen.StoryboardID,
"composed_image", result.ImageURL)
"composed_image", truncateImageURL(result.ImageURL))
}
}
@@ -306,7 +331,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
} else {
s.log.Infow("Scene updated with generated image",
"scene_id", *imageGen.SceneID,
"image_url", result.ImageURL)
"image_url", truncateImageURL(result.ImageURL))
}
}
@@ -317,7 +342,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
} else {
s.log.Infow("Character updated with generated image",
"character_id", *imageGen.CharacterID,
"image_url", result.ImageURL)
"image_url", truncateImageURL(result.ImageURL))
}
}
}
@@ -356,13 +381,33 @@ func (s *ImageGenerationService) getImageClient(provider string) (image.ImageCli
model = config.Model[0]
}
switch provider {
// 使用配置中的 provider如果没有则使用传入的 provider
actualProvider := config.Provider
if actualProvider == "" {
actualProvider = provider
}
// 根据 provider 自动设置默认端点
var endpoint string
var queryEndpoint string
switch actualProvider {
case "openai", "dalle":
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
case "stable_diffusion", "sd":
return image.NewStableDiffusionClient(config.BaseURL, config.APIKey, model), nil
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
case "chatfire":
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
case "volcengine", "volces", "doubao":
endpoint = "/images/generations"
queryEndpoint = ""
return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
case "gemini", "google":
endpoint = "/v1beta/models/{model}:generateContent"
return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
default:
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
}
}
@@ -394,13 +439,33 @@ func (s *ImageGenerationService) getImageClientWithModel(provider string, modelN
model = config.Model[0]
}
switch provider {
// 使用配置中的 provider如果没有则使用传入的 provider
actualProvider := config.Provider
if actualProvider == "" {
actualProvider = provider
}
// 根据 provider 自动设置默认端点
var endpoint string
var queryEndpoint string
switch actualProvider {
case "openai", "dalle":
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
case "stable_diffusion", "sd":
return image.NewStableDiffusionClient(config.BaseURL, config.APIKey, model), nil
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
case "chatfire":
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
case "volcengine", "volces", "doubao":
endpoint = "/images/generations"
queryEndpoint = ""
return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
case "gemini", "google":
endpoint = "/v1beta/models/{model}:generateContent"
return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
default:
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
endpoint = "/images/generations"
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
}
}
@@ -716,21 +781,11 @@ func (s *ImageGenerationService) extractBackgroundsFromScript(scriptContent stri
请严格按照JSON格式输出确保所有字段都使用中文。`, scriptContent)
messages := []ai.ChatMessage{
{Role: "user", Content: prompt},
}
resp, err := client.ChatCompletion(messages, ai.WithTemperature(0.7), ai.WithMaxTokens(8000))
response, err := client.GenerateText(prompt, "", ai.WithTemperature(0.7), ai.WithMaxTokens(8000))
if err != nil {
s.log.Errorw("Failed to extract backgrounds with AI", "error", err)
return nil, fmt.Errorf("AI提取场景失败: %w", err)
}
if len(resp.Choices) == 0 {
return nil, fmt.Errorf("AI未返回有效响应")
}
response := resp.Choices[0].Message.Content
s.log.Infow("AI backgrounds extraction response", "length", len(response))
// 解析JSON响应

View File

@@ -185,17 +185,17 @@ func (s *ScriptGenerationService) GenerateCharacters(req *GenerateCharactersRequ
count = 5
}
systemPrompt := `你是一个专业的角色设计师,擅长创作立体丰富的剧中角色
systemPrompt := `你是一个专业的角色分析师,擅长从剧本中提取和分析角色信息
你的任务是根据提供的剧本大纲,创作符合故事需求的角色设定。
你的任务是根据提供的剧本内容,提取并整理剧中出现的所有角色的详细设定。
要求:
1. 角色必须服务于大纲中的故事情节和冲突
2. 角色性格鲜明,有辨识度,符合故事类型
3. 每个角色都有清晰的动机目标,与大纲中的矛盾冲突相关
4. 角色之间有合理的关系和联系
5. 外貌描述必须极其详细便于AI绘画生成角色形象
6. 根据大纲的关键场景合理设置角色数量通常3-6个主要角色
1. 仔细阅读剧本,识别所有出现的角色
2. 根据剧本中的对话、行为和描述,总结角色性格特点
3. 提取角色在剧本中的关键信息:背景、动机目标、关系等
4. 角色之间的关系必须基于剧本中的实际描述
5. 外貌描述必须极其详细,如果剧本中有描述则使用,如果没有则根据角色设定合理推断,便于AI绘画生成角色形象
6. 优先提取主要角色和重要配角,次要角色可以简略
请严格按照以下 JSON 格式输出,不要添加任何其他文字:
@@ -213,21 +213,21 @@ func (s *ScriptGenerationService) GenerateCharacters(req *GenerateCharactersRequ
}
注意:
- 角色数量根据故事复杂度确定,不要过多
- 每个角色都要与大纲中的故事线有明确关联
- 必须基于剧本内容提取角色,不要凭空创作
- 优先提取主要角色和重要配角,数量根据剧本实际情况确定
- description、personality、appearance、voice_style都必须详细描述字数要充足
- appearance外貌描述是重中之重必须极其详细具体要能让AI准确生成角色形象
- 避免模糊描述,多用具体的视觉特征和细节`
- 如果剧本中角色信息不完整,可以根据角色设定合理补充,但要符合剧本整体风格`
outlineText := req.Outline
if outlineText == "" {
outlineText = fmt.Sprintf("剧名:%s\n简介%s\n类型%s", drama.Title, drama.Description, drama.Genre)
}
userPrompt := fmt.Sprintf(`剧本大纲
userPrompt := fmt.Sprintf(`剧本内容
%s
创作 %d 个角色的详细设定。`, outlineText, count)
从剧本中提取并整理最多 %d 个主要角色的详细设定。`, outlineText, count)
temperature := req.Temperature
if temperature == 0 {

View File

@@ -414,24 +414,33 @@ func (s *VideoGenerationService) getVideoClient(provider string, modelName strin
// 使用配置中的信息创建客户端
baseURL := config.BaseURL
apiKey := config.APIKey
endpoint := config.Endpoint
queryEndpoint := config.QueryEndpoint
model := modelName
if model == "" && len(config.Model) > 0 {
model = config.Model[0]
}
// 根据 provider 自动设置默认端点
var endpoint string
var queryEndpoint string
switch provider {
case "doubao":
case "chatfire":
endpoint = "/video/generations"
queryEndpoint = "/v1/video/task/{taskId}"
return video.NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
case "doubao", "volcengine", "volces":
endpoint = "/contents/generations/tasks"
queryEndpoint = "/generations/tasks/{taskId}"
return video.NewVolcesArkClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
case "openai":
// OpenAI Sora 使用 /v1/videos 端点
return video.NewOpenAISoraClient(baseURL, apiKey, model), nil
case "runway":
return video.NewRunwayClient(baseURL, apiKey, model), nil
case "pika":
return video.NewPikaClient(baseURL, apiKey, model), nil
case "minimax":
return video.NewMinimaxClient(baseURL, apiKey, model), nil
case "openai":
return video.NewOpenAISoraClient(baseURL, apiKey, model), nil
default:
return nil, fmt.Errorf("unsupported video provider: %s", provider)
}

View File

@@ -297,6 +297,10 @@ func (s *VideoMergeService) getVideoClient(provider string) (video.VideoClient,
model = config.Model[0]
}
// 根据 provider 自动设置默认端点
var endpoint string
var queryEndpoint string
switch provider {
case "runway":
return video.NewRunwayClient(config.BaseURL, config.APIKey, model), nil
@@ -306,10 +310,18 @@ func (s *VideoMergeService) getVideoClient(provider string) (video.VideoClient,
return video.NewOpenAISoraClient(config.BaseURL, config.APIKey, model), nil
case "minimax":
return video.NewMinimaxClient(config.BaseURL, config.APIKey, model), nil
case "chatfire":
endpoint = "/video/generations"
queryEndpoint = "/v1/video/task/{taskId}"
return video.NewChatfireClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
case "doubao", "volces", "ark":
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, config.Endpoint, config.QueryEndpoint), nil
endpoint = "/contents/generations/tasks"
queryEndpoint = "/generations/tasks/{taskId}"
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
default:
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, config.Endpoint, config.QueryEndpoint), nil
endpoint = "/contents/generations/tasks"
queryEndpoint = "/generations/tasks/{taskId}"
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
}
}