diff --git a/application/services/ai_service.go b/application/services/ai_service.go index 2140c2d..c1802c3 100644 --- a/application/services/ai_service.go +++ b/application/services/ai_service.go @@ -71,7 +71,18 @@ func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceC } else if req.ServiceType == "image" { endpoint = "/v1beta/models/{model}:generateContent" } - case "openai", "chatfire": + case "openai": + if req.ServiceType == "text" { + endpoint = "/chat/completions" + } else if req.ServiceType == "image" { + endpoint = "/images/generations" + } else if req.ServiceType == "video" { + endpoint = "/videos" + if queryEndpoint == "" { + queryEndpoint = "/videos/{taskId}" + } + } + case "chatfire": if req.ServiceType == "text" { endpoint = "/chat/completions" } else if req.ServiceType == "image" { @@ -79,7 +90,14 @@ func (s *AIService) CreateConfig(req *CreateAIConfigRequest) (*models.AIServiceC } else if req.ServiceType == "video" { endpoint = "/video/generations" if queryEndpoint == "" { - queryEndpoint = "/v1/video/task/{taskId}" + queryEndpoint = "/video/task/{taskId}" + } + } + case "doubao", "volcengine", "volces": + if req.ServiceType == "video" { + endpoint = "/contents/generations/tasks" + if queryEndpoint == "" { + queryEndpoint = "/generations/tasks/{taskId}" } } default: @@ -188,13 +206,23 @@ func (s *AIService) UpdateConfig(configID uint, req *UpdateAIConfigRequest) (*mo if serviceType == "text" || serviceType == "image" { updates["endpoint"] = "/v1beta/models/{model}:generateContent" } - case "openai", "chatfire": + case "openai": + if serviceType == "text" { + updates["endpoint"] = "/chat/completions" + } else if serviceType == "image" { + updates["endpoint"] = "/images/generations" + } else if serviceType == "video" { + updates["endpoint"] = "/videos" + updates["query_endpoint"] = "/videos/{taskId}" + } + case "chatfire": if serviceType == "text" { updates["endpoint"] = "/chat/completions" } else if serviceType == "image" { updates["endpoint"] = "/images/generations" } else if serviceType == "video" { updates["endpoint"] = "/video/generations" + updates["query_endpoint"] = "/video/task/{taskId}" } } } else if req.Endpoint != "" { diff --git a/application/services/video_generation_service.go b/application/services/video_generation_service.go index ac99469..91ec201 100644 --- a/application/services/video_generation_service.go +++ b/application/services/video_generation_service.go @@ -419,14 +419,14 @@ func (s *VideoGenerationService) getVideoClient(provider string, modelName strin model = config.Model[0] } - // 根据 provider 自动设置默认端点 + // 根据配置中的 provider 创建对应的客户端 var endpoint string var queryEndpoint string - switch provider { + switch config.Provider { case "chatfire": endpoint = "/video/generations" - queryEndpoint = "/v1/video/task/{taskId}" + queryEndpoint = "/video/task/{taskId}" return video.NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil case "doubao", "volcengine", "volces": endpoint = "/contents/generations/tasks" diff --git a/application/services/video_merge_service.go b/application/services/video_merge_service.go index 3b8b21e..1a14f2f 100644 --- a/application/services/video_merge_service.go +++ b/application/services/video_merge_service.go @@ -297,11 +297,11 @@ func (s *VideoMergeService) getVideoClient(provider string) (video.VideoClient, model = config.Model[0] } - // 根据 provider 自动设置默认端点 + // 根据配置中的 provider 创建对应的客户端 var endpoint string var queryEndpoint string - switch provider { + switch config.Provider { case "runway": return video.NewRunwayClient(config.BaseURL, config.APIKey, model), nil case "pika": @@ -312,7 +312,7 @@ func (s *VideoMergeService) getVideoClient(provider string) (video.VideoClient, return video.NewMinimaxClient(config.BaseURL, config.APIKey, model), nil case "chatfire": endpoint = "/video/generations" - queryEndpoint = "/v1/video/task/{taskId}" + queryEndpoint = "/video/task/{taskId}" return video.NewChatfireClient(config.BaseURL, config.APIKey, model, endpoint, queryEndpoint), nil case "doubao", "volces", "ark": endpoint = "/contents/generations/tasks" diff --git a/pkg/video/chatfire_client.go b/pkg/video/chatfire_client.go index f9b6a91..38ae403 100644 --- a/pkg/video/chatfire_client.go +++ b/pkg/video/chatfire_client.go @@ -28,17 +28,75 @@ type ChatfireRequest struct { Size string `json:"size,omitempty"` } +// ChatfireSoraRequest Sora 模型请求格式 +type ChatfireSoraRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Seconds string `json:"seconds,omitempty"` + Size string `json:"size,omitempty"` + InputReference string `json:"input_reference,omitempty"` +} + +// ChatfireDoubaoRequest 豆包/火山模型请求格式 +type ChatfireDoubaoRequest struct { + Model string `json:"model"` + Content []struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + URL string `json:"url,omitempty"` + } `json:"content"` +} + type ChatfireResponse struct { - TaskID string `json:"task_id"` - Status string `json:"status"` - Error string `json:"error,omitempty"` + ID string `json:"id"` + TaskID string `json:"task_id,omitempty"` + Status string `json:"status,omitempty"` + Error json.RawMessage `json:"error,omitempty"` + Data struct { + ID string `json:"id,omitempty"` + Status string `json:"status,omitempty"` + VideoURL string `json:"video_url,omitempty"` + } `json:"data,omitempty"` } type ChatfireTaskResponse struct { - TaskID string `json:"task_id"` - Status string `json:"status"` - VideoURL string `json:"video_url,omitempty"` - Error string `json:"error,omitempty"` + ID string `json:"id,omitempty"` + TaskID string `json:"task_id,omitempty"` + Status string `json:"status,omitempty"` + VideoURL string `json:"video_url,omitempty"` + Error json.RawMessage `json:"error,omitempty"` + Data struct { + ID string `json:"id,omitempty"` + Status string `json:"status,omitempty"` + VideoURL string `json:"video_url,omitempty"` + } `json:"data,omitempty"` +} + +// getErrorMessage 从 error 字段提取错误信息(支持字符串或对象) +func getErrorMessage(errorData json.RawMessage) string { + if len(errorData) == 0 { + return "" + } + + // 尝试解析为字符串 + var errStr string + if err := json.Unmarshal(errorData, &errStr); err == nil { + return errStr + } + + // 尝试解析为对象 + var errObj struct { + Message string `json:"message"` + Code string `json:"code"` + } + if err := json.Unmarshal(errorData, &errObj); err == nil { + if errObj.Message != "" { + return errObj.Message + } + } + + // 返回原始 JSON 字符串 + return string(errorData) } func NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint string) *ChatfireClient { @@ -46,7 +104,7 @@ func NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint string) * endpoint = "/video/generations" } if queryEndpoint == "" { - queryEndpoint = "/v1/video/task/{taskId}" + queryEndpoint = "/video/task/{taskId}" } return &ChatfireClient{ BaseURL: baseURL, @@ -75,15 +133,62 @@ func (c *ChatfireClient) GenerateVideo(imageURL, prompt string, opts ...VideoOpt model = options.Model } - reqBody := ChatfireRequest{ - Model: model, - Prompt: prompt, - ImageURL: imageURL, - Duration: options.Duration, - Size: options.AspectRatio, + // 根据模型名称选择请求格式 + var jsonData []byte + var err error + + if strings.Contains(model, "doubao") || strings.Contains(model, "seedance") { + // 豆包/火山格式 + reqBody := ChatfireDoubaoRequest{ + Model: model, + } + // 添加文本内容 + reqBody.Content = append(reqBody.Content, struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + URL string `json:"url,omitempty"` + }{Type: "text", Text: prompt}) + + // 如果有图片URL,添加图片内容 + if imageURL != "" { + reqBody.Content = append(reqBody.Content, struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` + URL string `json:"url,omitempty"` + }{Type: "image", URL: imageURL}) + } + + jsonData, err = json.Marshal(reqBody) + } else if strings.Contains(model, "sora") { + // Sora 格式 + seconds := fmt.Sprintf("%d", options.Duration) + size := options.AspectRatio + if size == "16:9" { + size = "1280x720" + } else if size == "9:16" { + size = "720x1280" + } + + reqBody := ChatfireSoraRequest{ + Model: model, + Prompt: prompt, + Seconds: seconds, + Size: size, + InputReference: imageURL, + } + jsonData, err = json.Marshal(reqBody) + } else { + // 默认格式 + reqBody := ChatfireRequest{ + Model: model, + Prompt: prompt, + ImageURL: imageURL, + Duration: options.Duration, + Size: options.AspectRatio, + } + jsonData, err = json.Marshal(reqBody) } - jsonData, err := json.Marshal(reqBody) if err != nil { return nil, fmt.Errorf("marshal request: %w", err) } @@ -112,19 +217,40 @@ func (c *ChatfireClient) GenerateVideo(imageURL, prompt string, opts ...VideoOpt return nil, fmt.Errorf("API error (status %d): %s", resp.StatusCode, string(body)) } + // 调试日志:打印响应内容 + fmt.Printf("[Chatfire] Response body: %s\n", string(body)) + var result ChatfireResponse if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("parse response: %w", err) + return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body)) } - if result.Error != "" { - return nil, fmt.Errorf("chatfire error: %s", result.Error) + // 优先使用 id 字段,其次使用 task_id + taskID := result.ID + if taskID == "" { + taskID = result.TaskID + } + + // 如果有 data 嵌套,优先使用 data 中的值 + if result.Data.ID != "" { + taskID = result.Data.ID + } + + status := result.Status + if status == "" && result.Data.Status != "" { + status = result.Data.Status + } + + fmt.Printf("[Chatfire] Parsed result - TaskID: %s, Status: %s\n", taskID, status) + + if errMsg := getErrorMessage(result.Error); errMsg != "" { + return nil, fmt.Errorf("chatfire error: %s", errMsg) } videoResult := &VideoResult{ - TaskID: result.TaskID, - Status: result.Status, - Completed: result.Status == "completed" || result.Status == "succeeded", + TaskID: taskID, + Status: status, + Completed: status == "completed" || status == "succeeded", Duration: options.Duration, } @@ -162,21 +288,42 @@ func (c *ChatfireClient) GetTaskStatus(taskID string) (*VideoResult, error) { var result ChatfireTaskResponse if err := json.Unmarshal(body, &result); err != nil { - return nil, fmt.Errorf("parse response: %w", err) + return nil, fmt.Errorf("parse response: %w, body: %s", err, string(body)) + } + + // 优先使用 id 字段,其次使用 task_id + responseTaskID := result.ID + if responseTaskID == "" { + responseTaskID = result.TaskID + } + + // 如果有 data 嵌套,优先使用 data 中的值 + if result.Data.ID != "" { + responseTaskID = result.Data.ID + } + + status := result.Status + if status == "" && result.Data.Status != "" { + status = result.Data.Status + } + + videoURL := result.VideoURL + if videoURL == "" && result.Data.VideoURL != "" { + videoURL = result.Data.VideoURL } videoResult := &VideoResult{ - TaskID: result.TaskID, - Status: result.Status, - Completed: result.Status == "completed" || result.Status == "succeeded", + TaskID: responseTaskID, + Status: status, + Completed: status == "completed" || status == "succeeded", } - if result.Error != "" { - videoResult.Error = result.Error + if errMsg := getErrorMessage(result.Error); errMsg != "" { + videoResult.Error = errMsg } - if result.VideoURL != "" { - videoResult.VideoURL = result.VideoURL + if videoURL != "" { + videoResult.VideoURL = videoURL videoResult.Completed = true }