Files
huobao-drama/application/services/video_generation_service.go
2026-01-14 22:57:03 +08:00

567 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package services
import (
"encoding/json"
"fmt"
"strconv"
"time"
models "github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/infrastructure/storage"
"github.com/drama-generator/backend/pkg/logger"
"github.com/drama-generator/backend/pkg/video"
"gorm.io/gorm"
)
// VideoGenerationService creates video-generation records, submits them to the
// configured video provider, and polls asynchronous provider tasks until they
// complete, fail, or time out.
type VideoGenerationService struct {
	// db persists VideoGeneration, Storyboard and ImageGeneration rows.
	db *gorm.DB
	// NOTE(review): transferService is not referenced by any method visible in
	// this file — confirm whether it is still needed.
	transferService *ResourceTransferService
	log             *logger.Logger
	// localStorage is optional; when nil, completed results are not cached locally.
	localStorage *storage.LocalStorage
	// aiService supplies the per-model or default "video" AI configuration
	// used to build provider clients.
	aiService *AIService
}
// NewVideoGenerationService wires the service to its dependencies and then
// launches RecoverPendingTasks in a background goroutine. That goroutine
// resumes polling for provider tasks that were still processing, for example
// before a restart.
func NewVideoGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, aiService *AIService, log *logger.Logger) *VideoGenerationService {
	svc := &VideoGenerationService{
		log:             log,
		aiService:       aiService,
		transferService: transferService,
		localStorage:    localStorage,
		db:              db,
	}
	// Recovery queries the database, so keep it off the constructor's path.
	go svc.RecoverPendingTasks()
	return svc
}
// GenerateVideoRequest is the input to GenerateVideo. It carries JSON and
// `binding` tags, so it is presumably bound from an HTTP request body —
// confirm with the handler.
type GenerateVideoRequest struct {
	StoryboardID *uint `json:"storyboard_id"`
	// DramaID is the drama's decimal ID. When StoryboardID is set, the
	// storyboard's episode must belong to this drama.
	DramaID string `json:"drama_id" binding:"required"`
	// ImageGenID, when set, must reference an existing ImageGeneration.
	ImageGenID *uint `json:"image_gen_id"`
	// Reference-image mode: single, first_last, multiple, none.
	// When empty (or not one of these), GenerateVideo infers the mode from
	// which image fields below are populated.
	ReferenceMode string `json:"reference_mode"`
	// Single-image mode.
	ImageURL string `json:"image_url"`
	// First/last-frame mode.
	FirstFrameURL *string `json:"first_frame_url"`
	LastFrameURL  *string `json:"last_frame_url"`
	// Multi-image mode; stored on the record as a JSON array string.
	ReferenceImageURLs []string `json:"reference_image_urls"`
	Prompt             string   `json:"prompt" binding:"required,min=5,max=2000"`
	// Provider defaults to "doubao" in GenerateVideo when empty.
	Provider string `json:"provider"`
	// Model, when set, selects the AI config used to build the provider client.
	Model string `json:"model"`
	// The remaining optional parameters are forwarded to the provider as
	// options only when non-nil.
	Duration     *int    `json:"duration"`
	FPS          *int    `json:"fps"`
	AspectRatio  *string `json:"aspect_ratio"`
	Style        *string `json:"style"`
	MotionLevel  *int    `json:"motion_level"`
	CameraMotion *string `json:"camera_motion"`
	Seed         *int64  `json:"seed"`
}
// GenerateVideo validates request, inserts a pending VideoGeneration record,
// and starts ProcessVideoGeneration for it in a background goroutine.
//
// Validation:
//   - DramaID must be a base-10 unsigned integer that fits in 32 bits. A
//     malformed value is rejected rather than silently stored as drama 0.
//   - If StoryboardID is set, the storyboard must exist and its episode must
//     belong to DramaID.
//   - If ImageGenID is set, the image-generation record must exist.
//
// Reference images are recorded according to ReferenceMode ("single",
// "first_last", "multiple", "none"). Any other value, including empty, falls
// back to inferring the mode from whichever image fields are populated.
//
// The returned record is the freshly inserted row; its status is updated
// asynchronously afterwards.
func (s *VideoGenerationService) GenerateVideo(request *GenerateVideoRequest) (*models.VideoGeneration, error) {
	// Parse the drama ID before anything else. This rejects malformed IDs, and
	// the ownership check below then compares canonical numbers instead of
	// raw strings (so "07" and "7" are treated the same).
	dramaID, err := strconv.ParseUint(request.DramaID, 10, 32)
	if err != nil {
		return nil, fmt.Errorf("invalid drama_id %q: %w", request.DramaID, err)
	}

	if request.StoryboardID != nil {
		var storyboard models.Storyboard
		if err := s.db.Preload("Episode").Where("id = ?", *request.StoryboardID).First(&storyboard).Error; err != nil {
			return nil, fmt.Errorf("storyboard not found")
		}
		if fmt.Sprintf("%d", storyboard.Episode.DramaID) != strconv.FormatUint(dramaID, 10) {
			return nil, fmt.Errorf("storyboard does not belong to drama")
		}
	}

	if request.ImageGenID != nil {
		var imageGen models.ImageGeneration
		if err := s.db.Where("id = ?", *request.ImageGenID).First(&imageGen).Error; err != nil {
			return nil, fmt.Errorf("image generation not found")
		}
	}

	provider := request.Provider
	if provider == "" {
		provider = "doubao"
	}

	videoGen := &models.VideoGeneration{
		StoryboardID: request.StoryboardID,
		DramaID:      uint(dramaID),
		ImageGenID:   request.ImageGenID,
		Provider:     provider,
		Prompt:       request.Prompt,
		Model:        request.Model,
		Duration:     request.Duration,
		FPS:          request.FPS,
		AspectRatio:  request.AspectRatio,
		Style:        request.Style,
		MotionLevel:  request.MotionLevel,
		CameraMotion: request.CameraMotion,
		Seed:         request.Seed,
		Status:       models.VideoStatusPending,
	}

	// strPtr returns a pointer to a private copy, so the stored record does
	// not alias string fields of the caller's request.
	strPtr := func(v string) *string { return &v }

	if request.ReferenceMode != "" {
		videoGen.ReferenceMode = strPtr(request.ReferenceMode)
	}

	switch request.ReferenceMode {
	case "single":
		if request.ImageURL != "" {
			videoGen.ImageURL = strPtr(request.ImageURL)
		}
	case "first_last":
		videoGen.FirstFrameURL = request.FirstFrameURL
		videoGen.LastFrameURL = request.LastFrameURL
	case "multiple":
		if len(request.ReferenceImageURLs) > 0 {
			if encoded, err := json.Marshal(request.ReferenceImageURLs); err == nil {
				videoGen.ReferenceImageURLs = strPtr(string(encoded))
			}
		}
	case "none":
		// Text-only generation: no reference images are recorded.
	default:
		// Backward compatibility: no (or an unrecognized) mode was given, so
		// infer one from the first populated image field.
		switch {
		case request.ImageURL != "":
			videoGen.ImageURL = strPtr(request.ImageURL)
			videoGen.ReferenceMode = strPtr("single")
		case request.FirstFrameURL != nil || request.LastFrameURL != nil:
			videoGen.FirstFrameURL = request.FirstFrameURL
			videoGen.LastFrameURL = request.LastFrameURL
			videoGen.ReferenceMode = strPtr("first_last")
		case len(request.ReferenceImageURLs) > 0:
			if encoded, err := json.Marshal(request.ReferenceImageURLs); err == nil {
				videoGen.ReferenceImageURLs = strPtr(string(encoded))
				videoGen.ReferenceMode = strPtr("multiple")
			}
		}
	}

	if err := s.db.Create(videoGen).Error; err != nil {
		return nil, fmt.Errorf("failed to create record: %w", err)
	}

	go s.ProcessVideoGeneration(videoGen.ID)
	return videoGen, nil
}
// ProcessVideoGeneration submits the stored request identified by videoGenID
// to its provider. It marks the record as processing, builds provider options
// from the stored parameters, and then does one of the following:
//   - starts pollTaskStatus when the provider returns an asynchronous task ID;
//   - completes the record immediately when the provider returns a video URL;
//   - marks the record failed otherwise.
//
// GenerateVideo runs it in its own goroutine. Failures are recorded on the
// record and in the log; nothing is returned.
func (s *VideoGenerationService) ProcessVideoGeneration(videoGenID uint) {
	var videoGen models.VideoGeneration
	if err := s.db.First(&videoGen, videoGenID).Error; err != nil {
		s.log.Errorw("Failed to load video generation", "error", err, "id", videoGenID)
		return
	}

	if err := s.db.Model(&videoGen).Update("status", models.VideoStatusProcessing).Error; err != nil {
		// Not fatal: the provider call can still proceed, and the task_id
		// update below writes the processing status again.
		s.log.Warnw("Failed to mark video generation as processing", "error", err, "id", videoGenID)
	}

	client, err := s.getVideoClient(videoGen.Provider, videoGen.Model)
	if err != nil {
		s.log.Errorw("Failed to get video client", "error", err, "provider", videoGen.Provider, "model", videoGen.Model)
		s.updateVideoGenError(videoGenID, err.Error())
		return
	}

	s.log.Infow("Starting video generation", "id", videoGenID, "prompt", videoGen.Prompt, "provider", videoGen.Provider)

	opts := s.buildVideoOptions(&videoGen)

	// The image is passed positionally only when ImageURL is set. GenerateVideo
	// populates it for single-image mode, so other modes send an empty string.
	imageURL := ""
	if videoGen.ImageURL != nil {
		imageURL = *videoGen.ImageURL
	}

	result, err := client.GenerateVideo(imageURL, videoGen.Prompt, opts...)
	if err != nil {
		s.log.Errorw("Video generation API call failed", "error", err, "id", videoGenID)
		s.updateVideoGenError(videoGenID, err.Error())
		return
	}

	if result.TaskID != "" {
		if err := s.db.Model(&videoGen).Updates(map[string]interface{}{
			"task_id": result.TaskID,
			"status":  models.VideoStatusProcessing,
		}).Error; err != nil {
			// Without a persisted task_id, RecoverPendingTasks cannot resume
			// this task after a restart; still poll it in-process.
			s.log.Errorw("Failed to save video task ID", "error", err, "id", videoGenID, "task_id", result.TaskID)
		}
		go s.pollTaskStatus(videoGenID, result.TaskID, videoGen.Provider, videoGen.Model)
		return
	}

	if result.VideoURL != "" {
		s.completeVideoGeneration(videoGenID, result.VideoURL, &result.Duration, &result.Width, &result.Height, nil)
		return
	}

	s.updateVideoGenError(videoGenID, "no task ID or video URL returned")
}

// buildVideoOptions converts the optional generation parameters stored on
// videoGen into provider options. Reference images for "first_last" and
// "multiple" modes are passed as options. Single-image mode is handled by the
// caller through the positional image URL.
func (s *VideoGenerationService) buildVideoOptions(videoGen *models.VideoGeneration) []video.VideoOption {
	var opts []video.VideoOption
	if videoGen.Model != "" {
		opts = append(opts, video.WithModel(videoGen.Model))
	}
	if videoGen.Duration != nil {
		opts = append(opts, video.WithDuration(*videoGen.Duration))
	}
	if videoGen.FPS != nil {
		opts = append(opts, video.WithFPS(*videoGen.FPS))
	}
	if videoGen.AspectRatio != nil {
		opts = append(opts, video.WithAspectRatio(*videoGen.AspectRatio))
	}
	if videoGen.Style != nil {
		opts = append(opts, video.WithStyle(*videoGen.Style))
	}
	if videoGen.MotionLevel != nil {
		opts = append(opts, video.WithMotionLevel(*videoGen.MotionLevel))
	}
	if videoGen.CameraMotion != nil {
		opts = append(opts, video.WithCameraMotion(*videoGen.CameraMotion))
	}
	if videoGen.Seed != nil {
		opts = append(opts, video.WithSeed(*videoGen.Seed))
	}

	if videoGen.ReferenceMode == nil {
		return opts
	}
	switch *videoGen.ReferenceMode {
	case "first_last":
		if videoGen.FirstFrameURL != nil {
			opts = append(opts, video.WithFirstFrame(*videoGen.FirstFrameURL))
		}
		if videoGen.LastFrameURL != nil {
			opts = append(opts, video.WithLastFrame(*videoGen.LastFrameURL))
		}
	case "multiple":
		if videoGen.ReferenceImageURLs != nil {
			var imageURLs []string
			if err := json.Unmarshal([]byte(*videoGen.ReferenceImageURLs), &imageURLs); err != nil {
				// Previously dropped silently; the generation still proceeds
				// without reference images, but the problem is now visible.
				s.log.Warnw("Failed to decode reference image URLs, continuing without them", "error", err, "id", videoGen.ID)
			} else {
				opts = append(opts, video.WithReferenceImages(imageURLs))
			}
		}
	}
	return opts
}
// pollTaskStatus polls the provider for taskID every 10 seconds, up to 300
// times (about 50 minutes), and then finalizes the record:
//   - completed with a video URL: completeVideoGeneration;
//   - completed without a URL, or provider-reported error: marked failed;
//   - attempts exhausted: marked failed with "polling timeout".
//
// Polling stops early if the record no longer exists or its status is no
// longer processing (for example, it was already finalized). Transient
// database or provider errors are logged and retried on the next tick.
func (s *VideoGenerationService) pollTaskStatus(videoGenID uint, taskID string, provider string, model string) {
	client, err := s.getVideoClient(provider, model)
	if err != nil {
		s.log.Errorw("Failed to get video client for polling", "error", err)
		s.updateVideoGenError(videoGenID, "failed to get video client")
		return
	}

	const (
		maxAttempts = 300
		interval    = 10 * time.Second
	)

	for attempt := 0; attempt < maxAttempts; attempt++ {
		time.Sleep(interval)

		// Find + RowsAffected separates "record gone" (stop polling) from a
		// transient DB failure. Previously a transient failure aborted the
		// loop and left the record stuck in processing.
		var videoGen models.VideoGeneration
		res := s.db.Where("id = ?", videoGenID).Limit(1).Find(&videoGen)
		if res.Error != nil {
			s.log.Errorw("Failed to load video generation", "error", res.Error, "id", videoGenID)
			continue
		}
		if res.RowsAffected == 0 {
			s.log.Infow("Video generation no longer exists, stopping poll", "id", videoGenID)
			return
		}

		if videoGen.Status != models.VideoStatusProcessing {
			s.log.Infow("Video generation status changed, stopping poll", "id", videoGenID, "status", videoGen.Status)
			return
		}

		result, err := client.GetTaskStatus(taskID)
		if err != nil {
			s.log.Errorw("Failed to get task status", "error", err, "task_id", taskID)
			continue
		}

		if result.Completed {
			if result.VideoURL != "" {
				s.completeVideoGeneration(videoGenID, result.VideoURL, &result.Duration, &result.Width, &result.Height, nil)
				return
			}
			// Keep the provider's explanation when it gave one.
			msg := "task completed but no video URL"
			if result.Error != "" {
				msg = result.Error
			}
			s.updateVideoGenError(videoGenID, msg)
			return
		}

		if result.Error != "" {
			s.updateVideoGenError(videoGenID, result.Error)
			return
		}

		s.log.Infow("Video generation in progress", "id", videoGenID, "attempt", attempt+1)
	}

	s.updateVideoGenError(videoGenID, "polling timeout")
}
// completeVideoGeneration marks the record completed with videoURL and the
// reported metadata, then copies the URL onto the linked storyboard, if any.
//
// Before the database update it makes a best-effort local copy of the video,
// and of the first frame when one is given, if localStorage is configured.
// These copies are a cache only; the database keeps the original remote URLs.
//
// duration, width and height are written only when non-nil and positive.
// Callers pass the provider result's fields, which are presumably zero when
// the provider did not report them. Writing those zeros would overwrite the
// duration stored when the record was created.
func (s *VideoGenerationService) completeVideoGeneration(videoGenID uint, videoURL string, duration *int, width *int, height *int, firstFrameURL *string) {
	if videoURL != "" {
		s.cacheToLocalStorage(videoGenID, videoURL, "videos",
			"Failed to download video to local storage",
			"Video downloaded to local storage for caching")
	}
	if firstFrameURL != nil && *firstFrameURL != "" {
		s.cacheToLocalStorage(videoGenID, *firstFrameURL, "video_frames",
			"Failed to download first frame to local storage",
			"First frame downloaded to local storage for caching")
	}

	// The database keeps the original (remote) URLs.
	updates := map[string]interface{}{
		"status":    models.VideoStatusCompleted,
		"video_url": videoURL,
	}
	setIfPositive := func(column string, v *int) {
		if v != nil && *v > 0 {
			updates[column] = *v
		}
	}
	setIfPositive("duration", duration)
	setIfPositive("width", width)
	setIfPositive("height", height)
	if firstFrameURL != nil {
		updates["first_frame_url"] = *firstFrameURL
	}

	if err := s.db.Model(&models.VideoGeneration{}).Where("id = ?", videoGenID).Updates(updates).Error; err != nil {
		s.log.Errorw("Failed to update video generation", "error", err, "id", videoGenID)
		return
	}

	// Mirror the final URL onto the linked storyboard.
	var videoGen models.VideoGeneration
	if err := s.db.First(&videoGen, videoGenID).Error; err != nil {
		s.log.Warnw("Failed to reload video generation for storyboard sync", "error", err, "id", videoGenID)
	} else if videoGen.StoryboardID != nil {
		if err := s.db.Model(&models.Storyboard{}).Where("id = ?", *videoGen.StoryboardID).Update("video_url", videoURL).Error; err != nil {
			s.log.Warnw("Failed to update storyboard video_url", "storyboard_id", *videoGen.StoryboardID, "error", err)
		}
	}

	s.log.Infow("Video generation completed", "id", videoGenID, "url", videoURL)
}

// cacheToLocalStorage downloads sourceURL into the local-storage directory dir
// as a best-effort cache. It does nothing when localStorage is not configured.
// Download failures are logged with failMsg and success with okMsg; neither
// outcome is returned to the caller.
func (s *VideoGenerationService) cacheToLocalStorage(videoGenID uint, sourceURL, dir, failMsg, okMsg string) {
	if s.localStorage == nil {
		return
	}
	if _, err := s.localStorage.DownloadFromURL(sourceURL, dir); err != nil {
		s.log.Warnw(failMsg,
			"error", err,
			"id", videoGenID,
			"original_url", sourceURL)
		return
	}
	s.log.Infow(okMsg,
		"id", videoGenID,
		"original_url", sourceURL)
}
// updateVideoGenError marks the record as failed and stores errorMsg on it.
// A failure to persist is only logged, never returned.
func (s *VideoGenerationService) updateVideoGenError(videoGenID uint, errorMsg string) {
	failure := map[string]interface{}{
		"status":    models.VideoStatusFailed,
		"error_msg": errorMsg,
	}
	err := s.db.Model(&models.VideoGeneration{}).Where("id = ?", videoGenID).Updates(failure).Error
	if err == nil {
		return
	}
	s.log.Errorw("Failed to update video generation error", "error", err, "id", videoGenID)
}
// getVideoClient builds a provider client for modelName.
//
// The "video" AI config registered for modelName is used when it can be
// loaded. Otherwise, or when modelName is empty, the default "video" config
// is used. The client implementation is chosen from the config's Provider,
// not from the provider argument; provider appears only in the error for an
// unsupported config. When modelName is empty, the config's first model is
// used.
func (s *VideoGenerationService) getVideoClient(provider string, modelName string) (video.VideoClient, error) {
	var (
		config *models.AIServiceConfig
		err    error
	)
	if modelName != "" {
		config, err = s.aiService.GetConfigForModel("video", modelName)
		if err != nil {
			s.log.Warnw("Failed to get config for model, using default", "model", modelName, "error", err)
			config = nil
		}
	}
	// Single fallback path for "no model requested" and "model lookup failed".
	// It also covers a nil config returned without an error, which previously
	// caused a nil dereference below.
	if config == nil {
		config, err = s.aiService.GetDefaultConfig("video")
		if err != nil {
			return nil, fmt.Errorf("no video AI config found: %w", err)
		}
		if config == nil {
			return nil, fmt.Errorf("no video AI config found")
		}
	}

	baseURL := config.BaseURL
	apiKey := config.APIKey
	model := modelName
	if model == "" && len(config.Model) > 0 {
		model = config.Model[0]
	}

	switch config.Provider {
	case "chatfire":
		return video.NewChatfireClient(baseURL, apiKey, model, "/video/generations", "/video/task/{taskId}"), nil
	case "doubao", "volcengine", "volces":
		return video.NewVolcesArkClient(baseURL, apiKey, model, "/contents/generations/tasks", "/contents/generations/tasks/{taskId}"), nil
	case "openai":
		// OpenAI Sora uses the /v1/videos endpoint.
		return video.NewOpenAISoraClient(baseURL, apiKey, model), nil
	case "runway":
		return video.NewRunwayClient(baseURL, apiKey, model), nil
	case "pika":
		return video.NewPikaClient(baseURL, apiKey, model), nil
	case "minimax":
		return video.NewMinimaxClient(baseURL, apiKey, model), nil
	default:
		// Report the provider actually switched on (from the config); the
		// caller's requested provider is included for context.
		return nil, fmt.Errorf("unsupported video provider: %s (requested provider %q)", config.Provider, provider)
	}
}
// RecoverPendingTasks resumes polling for every record that is still in the
// processing state and has a provider task ID. NewVideoGenerationService runs
// it at startup. Each task is polled in its own goroutine.
func (s *VideoGenerationService) RecoverPendingTasks() {
	var inFlight []models.VideoGeneration
	// Under SQL comparison semantics, `task_id != ''` also excludes NULL task
	// IDs, so TaskID is non-nil for every row returned.
	err := s.db.Where("status = ? AND task_id != ''", models.VideoStatusProcessing).Find(&inFlight).Error
	if err != nil {
		s.log.Errorw("Failed to load pending video tasks", "error", err)
		return
	}

	s.log.Infow("Recovering pending video generation tasks", "count", len(inFlight))
	for i := range inFlight {
		task := &inFlight[i]
		go s.pollTaskStatus(task.ID, *task.TaskID, task.Provider, task.Model)
	}
}
// GetVideoGeneration loads one record by primary key. The database error
// (for example, record not found) is returned unwrapped.
func (s *VideoGenerationService) GetVideoGeneration(id uint) (*models.VideoGeneration, error) {
	videoGen := new(models.VideoGeneration)
	if err := s.db.First(videoGen, id).Error; err != nil {
		return nil, err
	}
	return videoGen, nil
}
// ListVideoGenerations returns one page of records, newest first, together
// with the total number of records that match the filters. Nil filter
// pointers and an empty status are ignored.
func (s *VideoGenerationService) ListVideoGenerations(dramaID *uint, storyboardID *uint, status string, limit int, offset int) ([]*models.VideoGeneration, int64, error) {
	q := s.db.Model(&models.VideoGeneration{})
	if dramaID != nil {
		q = q.Where("drama_id = ?", *dramaID)
	}
	if storyboardID != nil {
		q = q.Where("storyboard_id = ?", *storyboardID)
	}
	if status != "" {
		q = q.Where("status = ?", status)
	}

	var total int64
	if err := q.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	var page []*models.VideoGeneration
	if err := q.Order("created_at DESC").Limit(limit).Offset(offset).Find(&page).Error; err != nil {
		return nil, 0, err
	}
	return page, total, nil
}
// GenerateVideoFromImage starts a single-image video generation from a
// completed image generation. It reuses the image's prompt and its
// drama/storyboard links, and uses the "doubao" provider. When the image is
// linked to a storyboard that can be loaded, that storyboard's Duration is
// requested as the video duration.
func (s *VideoGenerationService) GenerateVideoFromImage(imageGenID uint) (*models.VideoGeneration, error) {
	var imageGen models.ImageGeneration
	if err := s.db.First(&imageGen, imageGenID).Error; err != nil {
		return nil, fmt.Errorf("image generation not found")
	}
	ready := imageGen.Status == models.ImageStatusCompleted && imageGen.ImageURL != nil
	if !ready {
		return nil, fmt.Errorf("image is not ready")
	}

	req := &GenerateVideoRequest{
		DramaID:      fmt.Sprintf("%d", imageGen.DramaID),
		StoryboardID: imageGen.StoryboardID,
		ImageGenID:   &imageGenID,
		ImageURL:     *imageGen.ImageURL,
		Prompt:       imageGen.Prompt,
		Provider:     "doubao",
	}

	// Take the duration from the linked storyboard, if there is one.
	if sbID := imageGen.StoryboardID; sbID != nil {
		var storyboard models.Storyboard
		if err := s.db.Where("id = ?", *sbID).First(&storyboard).Error; err == nil {
			req.Duration = &storyboard.Duration
			s.log.Infow("Using storyboard duration for video generation",
				"storyboard_id", *sbID,
				"duration", storyboard.Duration)
		}
	}

	return s.GenerateVideo(req)
}
// BatchGenerateVideosForEpisode starts video generation for each storyboard
// in the episode that has an image prompt, using that storyboard's most
// recently created completed image.
//
// Storyboards without a completed image, and those whose generation fails to
// start, are logged and skipped. Only the successfully started generations
// are returned. The error is non-nil only when the episode cannot be loaded.
func (s *VideoGenerationService) BatchGenerateVideosForEpisode(episodeID string) ([]*models.VideoGeneration, error) {
	var episode models.Episode
	if err := s.db.Preload("Storyboards").Where("id = ?", episodeID).First(&episode).Error; err != nil {
		return nil, fmt.Errorf("episode not found")
	}

	var started []*models.VideoGeneration
	for _, sb := range episode.Storyboards {
		if sb.ImagePrompt == nil {
			continue
		}

		var latest models.ImageGeneration
		err := s.db.Where("storyboard_id = ? AND status = ?", sb.ID, models.ImageStatusCompleted).
			Order("created_at DESC").
			First(&latest).Error
		if err != nil {
			s.log.Warnw("No completed image for storyboard", "storyboard_id", sb.ID)
			continue
		}

		videoGen, err := s.GenerateVideoFromImage(latest.ID)
		if err != nil {
			s.log.Errorw("Failed to generate video", "storyboard_id", sb.ID, "error", err)
			continue
		}
		started = append(started, videoGen)
	}
	return started, nil
}
// DeleteVideoGeneration deletes the record with the given primary key using
// gorm's Delete. This is a soft delete if the model declares a gorm.DeletedAt
// field, which is not visible here — confirm. It does not cancel any provider
// task that may still be in flight.
func (s *VideoGenerationService) DeleteVideoGeneration(id uint) error {
	result := s.db.Delete(&models.VideoGeneration{}, id)
	return result.Error
}