添加chat gemini、chatfire端点、 图片生成 gemini、chatfire 更轻松的AI配置
This commit is contained in:
@@ -25,6 +25,7 @@ func NewAIService(db *gorm.DB, log *logger.Logger) *AIService {
|
||||
type CreateAIConfigRequest struct {
|
||||
ServiceType string `json:"service_type" binding:"required,oneof=text image video"`
|
||||
Name string `json:"name" binding:"required,min=1,max=100"`
|
||||
Provider string `json:"provider" binding:"required"`
|
||||
BaseURL string `json:"base_url" binding:"required,url"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
Model models.ModelField `json:"model" binding:"required"`
|
||||
@@ -37,6 +38,7 @@ type CreateAIConfigRequest struct {
|
||||
|
||||
type UpdateAIConfigRequest struct {
|
||||
Name string `json:"name" binding:"omitempty,min=1,max=100"`
|
||||
Provider string `json:"provider"`
|
||||
BaseURL string `json:"base_url" binding:"omitempty,url"`
|
||||
APIKey string `json:"api_key"`
|
||||
Model *models.ModelField `json:"model"`
|
||||
@@ -52,18 +54,53 @@ type TestConnectionRequest struct {
|
||||
BaseURL string `json:"base_url" binding:"required,url"`
|
||||
APIKey string `json:"api_key" binding:"required"`
|
||||
Model models.ModelField `json:"model" binding:"required"`
|
||||
Provider string `json:"provider"`
|
||||
Endpoint string `json:"endpoint"`
|
||||
}
|
||||
|
||||
func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceConfig, error) {
|
||||
// 根据 provider 和 service_type 自动设置 endpoint
|
||||
endpoint := req.Endpoint
|
||||
queryEndpoint := req.QueryEndpoint
|
||||
|
||||
if endpoint == "" {
|
||||
switch req.Provider {
|
||||
case "gemini", "google":
|
||||
if req.ServiceType == "text" {
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
} else if req.ServiceType == "image" {
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
}
|
||||
case "openai", "chatfire":
|
||||
if req.ServiceType == "text" {
|
||||
endpoint = "/chat/completions"
|
||||
} else if req.ServiceType == "image" {
|
||||
endpoint = "/images/generations"
|
||||
} else if req.ServiceType == "video" {
|
||||
endpoint = "/video/generations"
|
||||
if queryEndpoint == "" {
|
||||
queryEndpoint = "/v1/video/task/{taskId}"
|
||||
}
|
||||
}
|
||||
default:
|
||||
// 默认使用 OpenAI 格式
|
||||
if req.ServiceType == "text" {
|
||||
endpoint = "/chat/completions"
|
||||
} else if req.ServiceType == "image" {
|
||||
endpoint = "/images/generations"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
config := &models.AIServiceConfig{
|
||||
ServiceType: req.ServiceType,
|
||||
Name: req.Name,
|
||||
Provider: req.Provider,
|
||||
BaseURL: req.BaseURL,
|
||||
APIKey: req.APIKey,
|
||||
Model: req.Model,
|
||||
Endpoint: req.Endpoint,
|
||||
QueryEndpoint: req.QueryEndpoint,
|
||||
Endpoint: endpoint,
|
||||
QueryEndpoint: queryEndpoint,
|
||||
Priority: req.Priority,
|
||||
IsDefault: req.IsDefault,
|
||||
IsActive: true,
|
||||
@@ -75,7 +112,7 @@ func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceC
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.log.Infow("AI config created", "config_id", config.ID)
|
||||
s.log.Infow("AI config created", "config_id", config.ID, "provider", req.Provider, "endpoint", endpoint)
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -125,6 +162,9 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo
|
||||
if req.Name != "" {
|
||||
updates["name"] = req.Name
|
||||
}
|
||||
if req.Provider != "" {
|
||||
updates["provider"] = req.Provider
|
||||
}
|
||||
if req.BaseURL != "" {
|
||||
updates["base_url"] = req.BaseURL
|
||||
}
|
||||
@@ -137,9 +177,30 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo
|
||||
if req.Priority != nil {
|
||||
updates["priority"] = *req.Priority
|
||||
}
|
||||
if req.Endpoint != "" {
|
||||
|
||||
// 如果提供了 provider,根据 provider 和 service_type 自动设置 endpoint
|
||||
if req.Provider != "" && req.Endpoint == "" {
|
||||
provider := req.Provider
|
||||
serviceType := config.ServiceType
|
||||
|
||||
switch provider {
|
||||
case "gemini", "google":
|
||||
if serviceType == "text" || serviceType == "image" {
|
||||
updates["endpoint"] = "/v1beta/models/{model}:generateContent"
|
||||
}
|
||||
case "openai", "chatfire":
|
||||
if serviceType == "text" {
|
||||
updates["endpoint"] = "/chat/completions"
|
||||
} else if serviceType == "image" {
|
||||
updates["endpoint"] = "/images/generations"
|
||||
} else if serviceType == "video" {
|
||||
updates["endpoint"] = "/video/generations"
|
||||
}
|
||||
}
|
||||
} else if req.Endpoint != "" {
|
||||
updates["endpoint"] = req.Endpoint
|
||||
}
|
||||
|
||||
// 允许清空query_endpoint,所以不检查是否为空
|
||||
updates["query_endpoint"] = req.QueryEndpoint
|
||||
if req.Settings != "" {
|
||||
@@ -179,13 +240,51 @@ func (s *AIService) DeleteConfig(configID uint) error {
|
||||
}
|
||||
|
||||
func (s *AIService) TestConnection(req *TestConnectionRequest) error {
|
||||
s.log.Infow("TestConnection called", "baseURL", req.BaseURL, "provider", req.Provider, "endpoint", req.Endpoint, "modelCount", len(req.Model))
|
||||
|
||||
// 使用第一个模型进行测试
|
||||
model := ""
|
||||
if len(req.Model) > 0 {
|
||||
model = req.Model[0]
|
||||
}
|
||||
client := ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, req.Endpoint)
|
||||
return client.TestConnection()
|
||||
s.log.Infow("Using model for test", "model", model, "provider", req.Provider)
|
||||
|
||||
// 根据 provider 参数选择客户端
|
||||
var client ai.AIClient
|
||||
var endpoint string
|
||||
|
||||
switch req.Provider {
|
||||
case "gemini", "google":
|
||||
// Gemini
|
||||
s.log.Infow("Using Gemini client", "baseURL", req.BaseURL)
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
client = ai.NewGeminiClient(req.BaseURL, req.APIKey, model, endpoint)
|
||||
case "openai", "chatfire":
|
||||
// OpenAI 格式(包括 chatfire 等)
|
||||
s.log.Infow("Using OpenAI-compatible client", "baseURL", req.BaseURL, "provider", req.Provider)
|
||||
endpoint = req.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = "/chat/completions"
|
||||
}
|
||||
client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
|
||||
default:
|
||||
// 默认使用 OpenAI 格式
|
||||
s.log.Infow("Using default OpenAI-compatible client", "baseURL", req.BaseURL)
|
||||
endpoint = req.Endpoint
|
||||
if endpoint == "" {
|
||||
endpoint = "/chat/completions"
|
||||
}
|
||||
client = ai.NewOpenAIClient(req.BaseURL, req.APIKey, model, endpoint)
|
||||
}
|
||||
|
||||
s.log.Infow("Calling TestConnection on client", "endpoint", endpoint)
|
||||
err := client.TestConnection()
|
||||
if err != nil {
|
||||
s.log.Errorw("TestConnection failed", "error", err)
|
||||
} else {
|
||||
s.log.Infow("TestConnection succeeded")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *AIService) GetDefaultConfig(serviceType string) (*models.AIServiceConfig, error) {
|
||||
@@ -228,7 +327,7 @@ func (s *AIService) GetConfigForModel(serviceType string, modelName string) (*mo
|
||||
return nil, errors.New("no config found for model: " + modelName)
|
||||
}
|
||||
|
||||
func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) {
|
||||
func (s *AIService) GetAIClient(serviceType string) (ai.AIClient, error) {
|
||||
config, err := s.GetDefaultConfig(serviceType)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -240,7 +339,25 @@ func (s *AIService) GetAIClient(serviceType string) (*ai.OpenAIClient, error) {
|
||||
model = config.Model[0]
|
||||
}
|
||||
|
||||
return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, config.Endpoint), nil
|
||||
// 使用数据库配置中的 endpoint,如果为空则根据 provider 设置默认值
|
||||
endpoint := config.Endpoint
|
||||
if endpoint == "" {
|
||||
switch config.Provider {
|
||||
case "gemini", "google":
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
default:
|
||||
endpoint = "/chat/completions"
|
||||
}
|
||||
}
|
||||
|
||||
// 根据 provider 创建对应的客户端
|
||||
switch config.Provider {
|
||||
case "gemini", "google":
|
||||
return ai.NewGeminiClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
default:
|
||||
// openai, chatfire 等其他厂商都使用 OpenAI 格式
|
||||
return ai.NewOpenAIClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AIService) GenerateText(prompt string, systemPrompt string, options ...func(*ai.ChatCompletionRequest)) (string, error) {
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
models "github.com/drama-generator/backend/domain/models"
|
||||
@@ -23,6 +24,24 @@ type ImageGenerationService struct {
|
||||
log *logger.Logger
|
||||
}
|
||||
|
||||
// truncateImageURL 截断图片 URL,避免 base64 格式的 URL 占满日志
|
||||
func truncateImageURL(url string) string {
|
||||
if url == "" {
|
||||
return ""
|
||||
}
|
||||
// 如果是 data URI 格式(base64),只显示前缀
|
||||
if strings.HasPrefix(url, "data:") {
|
||||
if len(url) > 50 {
|
||||
return url[:50] + "...[base64 data]"
|
||||
}
|
||||
}
|
||||
// 普通 URL 如果过长也截断
|
||||
if len(url) > 100 {
|
||||
return url[:100] + "..."
|
||||
}
|
||||
return url
|
||||
}
|
||||
|
||||
func NewImageGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, log *logger.Logger) *ImageGenerationService {
|
||||
return &ImageGenerationService{
|
||||
db: db,
|
||||
@@ -246,17 +265,23 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
|
||||
now := time.Now()
|
||||
|
||||
// 下载图片到本地存储(仅用于缓存,不更新数据库)
|
||||
if s.localStorage != nil && result.ImageURL != "" {
|
||||
// 仅下载 HTTP/HTTPS URL,跳过 data URI
|
||||
if s.localStorage != nil && result.ImageURL != "" &&
|
||||
(strings.HasPrefix(result.ImageURL, "http://") || strings.HasPrefix(result.ImageURL, "https://")) {
|
||||
_, err := s.localStorage.DownloadFromURL(result.ImageURL, "images")
|
||||
if err != nil {
|
||||
errStr := err.Error()
|
||||
if len(errStr) > 200 {
|
||||
errStr = errStr[:200] + "..."
|
||||
}
|
||||
s.log.Warnw("Failed to download image to local storage",
|
||||
"error", err,
|
||||
"error", errStr,
|
||||
"id", imageGenID,
|
||||
"original_url", result.ImageURL)
|
||||
"original_url", truncateImageURL(result.ImageURL))
|
||||
} else {
|
||||
s.log.Infow("Image downloaded to local storage for caching",
|
||||
"id", imageGenID,
|
||||
"original_url", result.ImageURL)
|
||||
"original_url", truncateImageURL(result.ImageURL))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -291,7 +316,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
|
||||
} else {
|
||||
s.log.Infow("Storyboard updated with composed image",
|
||||
"storyboard_id", *imageGen.StoryboardID,
|
||||
"composed_image", result.ImageURL)
|
||||
"composed_image", truncateImageURL(result.ImageURL))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -306,7 +331,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
|
||||
} else {
|
||||
s.log.Infow("Scene updated with generated image",
|
||||
"scene_id", *imageGen.SceneID,
|
||||
"image_url", result.ImageURL)
|
||||
"image_url", truncateImageURL(result.ImageURL))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -317,7 +342,7 @@ func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result
|
||||
} else {
|
||||
s.log.Infow("Character updated with generated image",
|
||||
"character_id", *imageGen.CharacterID,
|
||||
"image_url", result.ImageURL)
|
||||
"image_url", truncateImageURL(result.ImageURL))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -356,13 +381,33 @@ func (s *ImageGenerationService) getImageClient(provider string) (image.ImageCli
|
||||
model = config.Model[0]
|
||||
}
|
||||
|
||||
switch provider {
|
||||
// 使用配置中的 provider,如果没有则使用传入的 provider
|
||||
actualProvider := config.Provider
|
||||
if actualProvider == "" {
|
||||
actualProvider = provider
|
||||
}
|
||||
|
||||
// 根据 provider 自动设置默认端点
|
||||
var endpoint string
|
||||
var queryEndpoint string
|
||||
|
||||
switch actualProvider {
|
||||
case "openai", "dalle":
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
|
||||
case "stable_diffusion", "sd":
|
||||
return image.NewStableDiffusionClient(config.BaseURL, config.APIKey, model), nil
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
case "chatfire":
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
case "volcengine", "volces", "doubao":
|
||||
endpoint = "/images/generations"
|
||||
queryEndpoint = ""
|
||||
return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||
case "gemini", "google":
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
default:
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -394,13 +439,33 @@ func (s *ImageGenerationService) getImageClientWithModel(provider string, modelN
|
||||
model = config.Model[0]
|
||||
}
|
||||
|
||||
switch provider {
|
||||
// 使用配置中的 provider,如果没有则使用传入的 provider
|
||||
actualProvider := config.Provider
|
||||
if actualProvider == "" {
|
||||
actualProvider = provider
|
||||
}
|
||||
|
||||
// 根据 provider 自动设置默认端点
|
||||
var endpoint string
|
||||
var queryEndpoint string
|
||||
|
||||
switch actualProvider {
|
||||
case "openai", "dalle":
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
|
||||
case "stable_diffusion", "sd":
|
||||
return image.NewStableDiffusionClient(config.BaseURL, config.APIKey, model), nil
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
case "chatfire":
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
case "volcengine", "volces", "doubao":
|
||||
endpoint = "/images/generations"
|
||||
queryEndpoint = ""
|
||||
return image.NewVolcEngineImageClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||
case "gemini", "google":
|
||||
endpoint = "/v1beta/models/{model}:generateContent"
|
||||
return image.NewGeminiImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
default:
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model), nil
|
||||
endpoint = "/images/generations"
|
||||
return image.NewOpenAIImageClient(config.BaseURL, config.APIKey, model, endpoint), nil
|
||||
}
|
||||
}
|
||||
|
||||
@@ -716,21 +781,11 @@ func (s *ImageGenerationService) extractBackgroundsFromScript(scriptContent stri
|
||||
|
||||
请严格按照JSON格式输出,确保所有字段都使用中文。`, scriptContent)
|
||||
|
||||
messages := []ai.ChatMessage{
|
||||
{Role: "user", Content: prompt},
|
||||
}
|
||||
|
||||
resp, err := client.ChatCompletion(messages, ai.WithTemperature(0.7), ai.WithMaxTokens(8000))
|
||||
response, err := client.GenerateText(prompt, "", ai.WithTemperature(0.7), ai.WithMaxTokens(8000))
|
||||
if err != nil {
|
||||
s.log.Errorw("Failed to extract backgrounds with AI", "error", err)
|
||||
return nil, fmt.Errorf("AI提取场景失败: %w", err)
|
||||
}
|
||||
|
||||
if len(resp.Choices) == 0 {
|
||||
return nil, fmt.Errorf("AI未返回有效响应")
|
||||
}
|
||||
|
||||
response := resp.Choices[0].Message.Content
|
||||
s.log.Infow("AI backgrounds extraction response", "length", len(response))
|
||||
|
||||
// 解析JSON响应
|
||||
|
||||
@@ -185,17 +185,17 @@ func (s *ScriptGenerationService) GenerateCharacters(req *GenerateCharactersRequ
|
||||
count = 5
|
||||
}
|
||||
|
||||
systemPrompt := `你是一个专业的角色设计师,擅长创作立体丰富的剧中角色。
|
||||
systemPrompt := `你是一个专业的角色分析师,擅长从剧本中提取和分析角色信息。
|
||||
|
||||
你的任务是根据提供的剧本大纲,创作符合故事需求的角色设定。
|
||||
你的任务是根据提供的剧本内容,提取并整理剧中出现的所有角色的详细设定。
|
||||
|
||||
要求:
|
||||
1. 角色必须服务于大纲中的故事情节和冲突
|
||||
2. 角色性格鲜明,有辨识度,符合故事类型
|
||||
3. 每个角色都有清晰的动机和目标,与大纲中的矛盾冲突相关
|
||||
4. 角色之间有合理的关系和联系
|
||||
5. 外貌描述必须极其详细,便于AI绘画生成角色形象
|
||||
6. 根据大纲的关键场景,合理设置角色数量(通常3-6个主要角色)
|
||||
1. 仔细阅读剧本,识别所有出现的角色
|
||||
2. 根据剧本中的对话、行为和描述,总结角色的性格特点
|
||||
3. 提取角色在剧本中的关键信息:背景、动机、目标、关系等
|
||||
4. 角色之间的关系必须基于剧本中的实际描述
|
||||
5. 外貌描述必须极其详细,如果剧本中有描述则使用,如果没有则根据角色设定合理推断,便于AI绘画生成角色形象
|
||||
6. 优先提取主要角色和重要配角,次要角色可以简略
|
||||
|
||||
请严格按照以下 JSON 格式输出,不要添加任何其他文字:
|
||||
|
||||
@@ -213,21 +213,21 @@ func (s *ScriptGenerationService) GenerateCharacters(req *GenerateCharactersRequ
|
||||
}
|
||||
|
||||
注意:
|
||||
- 角色数量根据故事复杂度确定,不要过多
|
||||
- 每个角色都要与大纲中的故事线有明确关联
|
||||
- 必须基于剧本内容提取角色,不要凭空创作
|
||||
- 优先提取主要角色和重要配角,数量根据剧本实际情况确定
|
||||
- description、personality、appearance、voice_style都必须详细描述,字数要充足
|
||||
- appearance外貌描述是重中之重,必须极其详细具体,要能让AI准确生成角色形象
|
||||
- 避免模糊描述,多用具体的视觉特征和细节`
|
||||
- 如果剧本中角色信息不完整,可以根据角色设定合理补充,但要符合剧本整体风格`
|
||||
|
||||
outlineText := req.Outline
|
||||
if outlineText == "" {
|
||||
outlineText = fmt.Sprintf("剧名:%s\n简介:%s\n类型:%s", drama.Title, drama.Description, drama.Genre)
|
||||
}
|
||||
|
||||
userPrompt := fmt.Sprintf(`剧本大纲:
|
||||
userPrompt := fmt.Sprintf(`剧本内容:
|
||||
%s
|
||||
|
||||
请创作 %d 个角色的详细设定。`, outlineText, count)
|
||||
请从剧本中提取并整理最多 %d 个主要角色的详细设定。`, outlineText, count)
|
||||
|
||||
temperature := req.Temperature
|
||||
if temperature == 0 {
|
||||
|
||||
@@ -414,24 +414,33 @@ func (s *VideoGenerationService) getVideoClient(provider string, modelName strin
|
||||
// 使用配置中的信息创建客户端
|
||||
baseURL := config.BaseURL
|
||||
apiKey := config.APIKey
|
||||
endpoint := config.Endpoint
|
||||
queryEndpoint := config.QueryEndpoint
|
||||
model := modelName
|
||||
if model == "" && len(config.Model) > 0 {
|
||||
model = config.Model[0]
|
||||
}
|
||||
|
||||
// 根据 provider 自动设置默认端点
|
||||
var endpoint string
|
||||
var queryEndpoint string
|
||||
|
||||
switch provider {
|
||||
case "doubao":
|
||||
case "chatfire":
|
||||
endpoint = "/video/generations"
|
||||
queryEndpoint = "/v1/video/task/{taskId}"
|
||||
return video.NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
|
||||
case "doubao", "volcengine", "volces":
|
||||
endpoint = "/contents/generations/tasks"
|
||||
queryEndpoint = "/generations/tasks/{taskId}"
|
||||
return video.NewVolcesArkClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
|
||||
case "openai":
|
||||
// OpenAI Sora 使用 /v1/videos 端点
|
||||
return video.NewOpenAISoraClient(baseURL, apiKey, model), nil
|
||||
case "runway":
|
||||
return video.NewRunwayClient(baseURL, apiKey, model), nil
|
||||
case "pika":
|
||||
return video.NewPikaClient(baseURL, apiKey, model), nil
|
||||
case "minimax":
|
||||
return video.NewMinimaxClient(baseURL, apiKey, model), nil
|
||||
case "openai":
|
||||
return video.NewOpenAISoraClient(baseURL, apiKey, model), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported video provider: %s", provider)
|
||||
}
|
||||
|
||||
@@ -297,6 +297,10 @@ func (s *VideoMergeService) getVideoClient(provider string) (video.VideoClient,
|
||||
model = config.Model[0]
|
||||
}
|
||||
|
||||
// 根据 provider 自动设置默认端点
|
||||
var endpoint string
|
||||
var queryEndpoint string
|
||||
|
||||
switch provider {
|
||||
case "runway":
|
||||
return video.NewRunwayClient(config.BaseURL, config.APIKey, model), nil
|
||||
@@ -306,10 +310,18 @@ func (s *VideoMergeService) getVideoClient(provider string) (video.VideoClient,
|
||||
return video.NewOpenAISoraClient(config.BaseURL, config.APIKey, model), nil
|
||||
case "minimax":
|
||||
return video.NewMinimaxClient(config.BaseURL, config.APIKey, model), nil
|
||||
case "chatfire":
|
||||
endpoint = "/video/generations"
|
||||
queryEndpoint = "/v1/video/task/{taskId}"
|
||||
return video.NewChatfireClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||
case "doubao", "volces", "ark":
|
||||
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, config.Endpoint, config.QueryEndpoint), nil
|
||||
endpoint = "/contents/generations/tasks"
|
||||
queryEndpoint = "/generations/tasks/{taskId}"
|
||||
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||
default:
|
||||
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, config.Endpoint, config.QueryEndpoint), nil
|
||||
endpoint = "/contents/generations/tasks"
|
||||
queryEndpoint = "/generations/tasks/{taskId}"
|
||||
return video.NewVolcesArkClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user