567 lines
18 KiB
Go
567 lines
18 KiB
Go
package services
|
||
|
||
import (
|
||
"encoding/json"
|
||
"fmt"
|
||
"strconv"
|
||
"time"
|
||
|
||
models "github.com/drama-generator/backend/domain/models"
|
||
"github.com/drama-generator/backend/infrastructure/storage"
|
||
"github.com/drama-generator/backend/pkg/logger"
|
||
"github.com/drama-generator/backend/pkg/video"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type VideoGenerationService struct {
|
||
db *gorm.DB
|
||
transferService *ResourceTransferService
|
||
log *logger.Logger
|
||
localStorage *storage.LocalStorage
|
||
aiService *AIService
|
||
}
|
||
|
||
func NewVideoGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, aiService *AIService, log *logger.Logger) *VideoGenerationService {
|
||
service := &VideoGenerationService{
|
||
db: db,
|
||
localStorage: localStorage,
|
||
transferService: transferService,
|
||
aiService: aiService,
|
||
log: log,
|
||
}
|
||
|
||
go service.RecoverPendingTasks()
|
||
|
||
return service
|
||
}
|
||
|
||
// GenerateVideoRequest is the payload for starting a video generation.
// Pointer fields are optional; nil means "not provided".
//
// ReferenceMode selects which reference-image fields are used: "single"
// (ImageURL), "first_last" (FirstFrameURL / LastFrameURL), "multiple"
// (ReferenceImageURLs) or "none" (text only). When it is empty, GenerateVideo
// infers the mode from whichever reference fields are populated. An empty
// Provider is treated as "doubao" by GenerateVideo.
type GenerateVideoRequest struct {
	StoryboardID *uint  `json:"storyboard_id"`
	DramaID      string `json:"drama_id" binding:"required"`
	ImageGenID   *uint  `json:"image_gen_id"`

	// Reference image mode: "single", "first_last", "multiple" or "none".
	ReferenceMode string `json:"reference_mode"`

	// Single-image mode.
	ImageURL string `json:"image_url"`

	// First/last-frame mode.
	FirstFrameURL *string `json:"first_frame_url"`
	LastFrameURL  *string `json:"last_frame_url"`

	// Multi-image mode.
	ReferenceImageURLs []string `json:"reference_image_urls"`

	Prompt       string  `json:"prompt" binding:"required,min=5,max=2000"`
	Provider     string  `json:"provider"`
	Model        string  `json:"model"`
	Duration     *int    `json:"duration"`
	FPS          *int    `json:"fps"`
	AspectRatio  *string `json:"aspect_ratio"`
	Style        *string `json:"style"`
	MotionLevel  *int    `json:"motion_level"`
	CameraMotion *string `json:"camera_motion"`
	Seed         *int64  `json:"seed"`
}
|
||
|
||
func (s *VideoGenerationService) GenerateVideo(request *GenerateVideoRequest) (*models.VideoGeneration, error) {
|
||
if request.StoryboardID != nil {
|
||
var storyboard models.Storyboard
|
||
if err := s.db.Preload("Episode").Where("id = ?", *request.StoryboardID).First(&storyboard).Error; err != nil {
|
||
return nil, fmt.Errorf("storyboard not found")
|
||
}
|
||
if fmt.Sprintf("%d", storyboard.Episode.DramaID) != request.DramaID {
|
||
return nil, fmt.Errorf("storyboard does not belong to drama")
|
||
}
|
||
}
|
||
|
||
if request.ImageGenID != nil {
|
||
var imageGen models.ImageGeneration
|
||
if err := s.db.Where("id = ?", *request.ImageGenID).First(&imageGen).Error; err != nil {
|
||
return nil, fmt.Errorf("image generation not found")
|
||
}
|
||
}
|
||
|
||
provider := request.Provider
|
||
if provider == "" {
|
||
provider = "doubao"
|
||
}
|
||
|
||
dramaID, _ := strconv.ParseUint(request.DramaID, 10, 32)
|
||
|
||
videoGen := &models.VideoGeneration{
|
||
StoryboardID: request.StoryboardID,
|
||
DramaID: uint(dramaID),
|
||
ImageGenID: request.ImageGenID,
|
||
Provider: provider,
|
||
Prompt: request.Prompt,
|
||
Model: request.Model,
|
||
Duration: request.Duration,
|
||
FPS: request.FPS,
|
||
AspectRatio: request.AspectRatio,
|
||
Style: request.Style,
|
||
MotionLevel: request.MotionLevel,
|
||
CameraMotion: request.CameraMotion,
|
||
Seed: request.Seed,
|
||
Status: models.VideoStatusPending,
|
||
}
|
||
|
||
// 根据参考图模式处理不同的参数
|
||
if request.ReferenceMode != "" {
|
||
videoGen.ReferenceMode = &request.ReferenceMode
|
||
}
|
||
|
||
switch request.ReferenceMode {
|
||
case "single":
|
||
// 单图模式
|
||
if request.ImageURL != "" {
|
||
videoGen.ImageURL = &request.ImageURL
|
||
}
|
||
case "first_last":
|
||
// 首尾帧模式
|
||
if request.FirstFrameURL != nil {
|
||
videoGen.FirstFrameURL = request.FirstFrameURL
|
||
}
|
||
if request.LastFrameURL != nil {
|
||
videoGen.LastFrameURL = request.LastFrameURL
|
||
}
|
||
case "multiple":
|
||
// 多图模式
|
||
if len(request.ReferenceImageURLs) > 0 {
|
||
referenceImagesJSON, err := json.Marshal(request.ReferenceImageURLs)
|
||
if err == nil {
|
||
referenceImagesStr := string(referenceImagesJSON)
|
||
videoGen.ReferenceImageURLs = &referenceImagesStr
|
||
}
|
||
}
|
||
case "none":
|
||
// 无参考图,纯文本生成
|
||
default:
|
||
// 向后兼容:如果没有指定模式,根据提供的参数自动判断
|
||
if request.ImageURL != "" {
|
||
videoGen.ImageURL = &request.ImageURL
|
||
mode := "single"
|
||
videoGen.ReferenceMode = &mode
|
||
} else if request.FirstFrameURL != nil || request.LastFrameURL != nil {
|
||
videoGen.FirstFrameURL = request.FirstFrameURL
|
||
videoGen.LastFrameURL = request.LastFrameURL
|
||
mode := "first_last"
|
||
videoGen.ReferenceMode = &mode
|
||
} else if len(request.ReferenceImageURLs) > 0 {
|
||
referenceImagesJSON, err := json.Marshal(request.ReferenceImageURLs)
|
||
if err == nil {
|
||
referenceImagesStr := string(referenceImagesJSON)
|
||
videoGen.ReferenceImageURLs = &referenceImagesStr
|
||
mode := "multiple"
|
||
videoGen.ReferenceMode = &mode
|
||
}
|
||
}
|
||
}
|
||
|
||
if err := s.db.Create(videoGen).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to create record: %w", err)
|
||
}
|
||
|
||
go s.ProcessVideoGeneration(videoGen.ID)
|
||
|
||
return videoGen, nil
|
||
}
|
||
|
||
func (s *VideoGenerationService) ProcessVideoGeneration(videoGenID uint) {
|
||
var videoGen models.VideoGeneration
|
||
if err := s.db.First(&videoGen, videoGenID).Error; err != nil {
|
||
s.log.Errorw("Failed to load video generation", "error", err, "id", videoGenID)
|
||
return
|
||
}
|
||
|
||
s.db.Model(&videoGen).Update("status", models.VideoStatusProcessing)
|
||
|
||
client, err := s.getVideoClient(videoGen.Provider, videoGen.Model)
|
||
if err != nil {
|
||
s.log.Errorw("Failed to get video client", "error", err, "provider", videoGen.Provider, "model", videoGen.Model)
|
||
s.updateVideoGenError(videoGenID, err.Error())
|
||
return
|
||
}
|
||
|
||
s.log.Infow("Starting video generation", "id", videoGenID, "prompt", videoGen.Prompt, "provider", videoGen.Provider)
|
||
|
||
var opts []video.VideoOption
|
||
if videoGen.Model != "" {
|
||
opts = append(opts, video.WithModel(videoGen.Model))
|
||
}
|
||
if videoGen.Duration != nil {
|
||
opts = append(opts, video.WithDuration(*videoGen.Duration))
|
||
}
|
||
if videoGen.FPS != nil {
|
||
opts = append(opts, video.WithFPS(*videoGen.FPS))
|
||
}
|
||
if videoGen.AspectRatio != nil {
|
||
opts = append(opts, video.WithAspectRatio(*videoGen.AspectRatio))
|
||
}
|
||
if videoGen.Style != nil {
|
||
opts = append(opts, video.WithStyle(*videoGen.Style))
|
||
}
|
||
if videoGen.MotionLevel != nil {
|
||
opts = append(opts, video.WithMotionLevel(*videoGen.MotionLevel))
|
||
}
|
||
if videoGen.CameraMotion != nil {
|
||
opts = append(opts, video.WithCameraMotion(*videoGen.CameraMotion))
|
||
}
|
||
if videoGen.Seed != nil {
|
||
opts = append(opts, video.WithSeed(*videoGen.Seed))
|
||
}
|
||
|
||
// 根据参考图模式添加相应的选项
|
||
if videoGen.ReferenceMode != nil {
|
||
switch *videoGen.ReferenceMode {
|
||
case "first_last":
|
||
// 首尾帧模式
|
||
if videoGen.FirstFrameURL != nil {
|
||
opts = append(opts, video.WithFirstFrame(*videoGen.FirstFrameURL))
|
||
}
|
||
if videoGen.LastFrameURL != nil {
|
||
opts = append(opts, video.WithLastFrame(*videoGen.LastFrameURL))
|
||
}
|
||
case "multiple":
|
||
// 多图模式
|
||
if videoGen.ReferenceImageURLs != nil {
|
||
var imageURLs []string
|
||
if err := json.Unmarshal([]byte(*videoGen.ReferenceImageURLs), &imageURLs); err == nil {
|
||
opts = append(opts, video.WithReferenceImages(imageURLs))
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
// 构造imageURL参数(单图模式使用,其他模式传空字符串)
|
||
imageURL := ""
|
||
if videoGen.ImageURL != nil {
|
||
imageURL = *videoGen.ImageURL
|
||
}
|
||
|
||
result, err := client.GenerateVideo(imageURL, videoGen.Prompt, opts...)
|
||
if err != nil {
|
||
s.log.Errorw("Video generation API call failed", "error", err, "id", videoGenID)
|
||
s.updateVideoGenError(videoGenID, err.Error())
|
||
return
|
||
}
|
||
|
||
if result.TaskID != "" {
|
||
s.db.Model(&videoGen).Updates(map[string]interface{}{
|
||
"task_id": result.TaskID,
|
||
"status": models.VideoStatusProcessing,
|
||
})
|
||
go s.pollTaskStatus(videoGenID, result.TaskID, videoGen.Provider, videoGen.Model)
|
||
return
|
||
}
|
||
|
||
if result.VideoURL != "" {
|
||
s.completeVideoGeneration(videoGenID, result.VideoURL, &result.Duration, &result.Width, &result.Height, nil)
|
||
return
|
||
}
|
||
|
||
s.updateVideoGenError(videoGenID, "no task ID or video URL returned")
|
||
}
|
||
|
||
func (s *VideoGenerationService) pollTaskStatus(videoGenID uint, taskID string, provider string, model string) {
|
||
client, err := s.getVideoClient(provider, model)
|
||
if err != nil {
|
||
s.log.Errorw("Failed to get video client for polling", "error", err)
|
||
s.updateVideoGenError(videoGenID, "failed to get video client")
|
||
return
|
||
}
|
||
|
||
maxAttempts := 300
|
||
interval := 10 * time.Second
|
||
|
||
for attempt := 0; attempt < maxAttempts; attempt++ {
|
||
time.Sleep(interval)
|
||
|
||
var videoGen models.VideoGeneration
|
||
if err := s.db.First(&videoGen, videoGenID).Error; err != nil {
|
||
s.log.Errorw("Failed to load video generation", "error", err, "id", videoGenID)
|
||
return
|
||
}
|
||
|
||
if videoGen.Status != models.VideoStatusProcessing {
|
||
s.log.Infow("Video generation status changed, stopping poll", "id", videoGenID, "status", videoGen.Status)
|
||
return
|
||
}
|
||
|
||
result, err := client.GetTaskStatus(taskID)
|
||
if err != nil {
|
||
s.log.Errorw("Failed to get task status", "error", err, "task_id", taskID)
|
||
continue
|
||
}
|
||
|
||
if result.Completed {
|
||
if result.VideoURL != "" {
|
||
s.completeVideoGeneration(videoGenID, result.VideoURL, &result.Duration, &result.Width, &result.Height, nil)
|
||
return
|
||
}
|
||
s.updateVideoGenError(videoGenID, "task completed but no video URL")
|
||
return
|
||
}
|
||
|
||
if result.Error != "" {
|
||
s.updateVideoGenError(videoGenID, result.Error)
|
||
return
|
||
}
|
||
|
||
s.log.Infow("Video generation in progress", "id", videoGenID, "attempt", attempt+1)
|
||
}
|
||
|
||
s.updateVideoGenError(videoGenID, "polling timeout")
|
||
}
|
||
|
||
func (s *VideoGenerationService) completeVideoGeneration(videoGenID uint, videoURL string, duration *int, width *int, height *int, firstFrameURL *string) {
|
||
// 下载视频到本地存储(仅用于缓存,不更新数据库)
|
||
if s.localStorage != nil && videoURL != "" {
|
||
_, err := s.localStorage.DownloadFromURL(videoURL, "videos")
|
||
if err != nil {
|
||
s.log.Warnw("Failed to download video to local storage",
|
||
"error", err,
|
||
"id", videoGenID,
|
||
"original_url", videoURL)
|
||
} else {
|
||
s.log.Infow("Video downloaded to local storage for caching",
|
||
"id", videoGenID,
|
||
"original_url", videoURL)
|
||
}
|
||
}
|
||
|
||
// 下载首帧图片到本地存储(仅用于缓存,不更新数据库)
|
||
if firstFrameURL != nil && *firstFrameURL != "" && s.localStorage != nil {
|
||
_, err := s.localStorage.DownloadFromURL(*firstFrameURL, "video_frames")
|
||
if err != nil {
|
||
s.log.Warnw("Failed to download first frame to local storage",
|
||
"error", err,
|
||
"id", videoGenID,
|
||
"original_url", *firstFrameURL)
|
||
} else {
|
||
s.log.Infow("First frame downloaded to local storage for caching",
|
||
"id", videoGenID,
|
||
"original_url", *firstFrameURL)
|
||
}
|
||
}
|
||
|
||
// 数据库中保持使用原始URL
|
||
updates := map[string]interface{}{
|
||
"status": models.VideoStatusCompleted,
|
||
"video_url": videoURL,
|
||
}
|
||
if duration != nil {
|
||
updates["duration"] = *duration
|
||
}
|
||
if width != nil {
|
||
updates["width"] = *width
|
||
}
|
||
if height != nil {
|
||
updates["height"] = *height
|
||
}
|
||
if firstFrameURL != nil {
|
||
updates["first_frame_url"] = *firstFrameURL
|
||
}
|
||
|
||
if err := s.db.Model(&models.VideoGeneration{}).Where("id = ?", videoGenID).Updates(updates).Error; err != nil {
|
||
s.log.Errorw("Failed to update video generation", "error", err, "id", videoGenID)
|
||
return
|
||
}
|
||
|
||
var videoGen models.VideoGeneration
|
||
if err := s.db.First(&videoGen, videoGenID).Error; err == nil {
|
||
if videoGen.StoryboardID != nil {
|
||
if err := s.db.Model(&models.Storyboard{}).Where("id = ?", *videoGen.StoryboardID).Update("video_url", videoURL).Error; err != nil {
|
||
s.log.Warnw("Failed to update storyboard video_url", "storyboard_id", *videoGen.StoryboardID, "error", err)
|
||
}
|
||
}
|
||
}
|
||
|
||
s.log.Infow("Video generation completed", "id", videoGenID, "url", videoURL)
|
||
}
|
||
|
||
func (s *VideoGenerationService) updateVideoGenError(videoGenID uint, errorMsg string) {
|
||
if err := s.db.Model(&models.VideoGeneration{}).Where("id = ?", videoGenID).Updates(map[string]interface{}{
|
||
"status": models.VideoStatusFailed,
|
||
"error_msg": errorMsg,
|
||
}).Error; err != nil {
|
||
s.log.Errorw("Failed to update video generation error", "error", err, "id", videoGenID)
|
||
}
|
||
}
|
||
|
||
func (s *VideoGenerationService) getVideoClient(provider string, modelName string) (video.VideoClient, error) {
|
||
// 根据模型名称获取AI配置
|
||
var config *models.AIServiceConfig
|
||
var err error
|
||
|
||
if modelName != "" {
|
||
config, err = s.aiService.GetConfigForModel("video", modelName)
|
||
if err != nil {
|
||
s.log.Warnw("Failed to get config for model, using default", "model", modelName, "error", err)
|
||
config, err = s.aiService.GetDefaultConfig("video")
|
||
if err != nil {
|
||
return nil, fmt.Errorf("no video AI config found: %w", err)
|
||
}
|
||
}
|
||
} else {
|
||
config, err = s.aiService.GetDefaultConfig("video")
|
||
if err != nil {
|
||
return nil, fmt.Errorf("no video AI config found: %w", err)
|
||
}
|
||
}
|
||
|
||
// 使用配置中的信息创建客户端
|
||
baseURL := config.BaseURL
|
||
apiKey := config.APIKey
|
||
model := modelName
|
||
if model == "" && len(config.Model) > 0 {
|
||
model = config.Model[0]
|
||
}
|
||
|
||
// 根据配置中的 provider 创建对应的客户端
|
||
var endpoint string
|
||
var queryEndpoint string
|
||
|
||
switch config.Provider {
|
||
case "chatfire":
|
||
endpoint = "/video/generations"
|
||
queryEndpoint = "/video/task/{taskId}"
|
||
return video.NewChatfireClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
|
||
case "doubao", "volcengine", "volces":
|
||
endpoint = "/contents/generations/tasks"
|
||
queryEndpoint = "/contents/generations/tasks/{taskId}"
|
||
return video.NewVolcesArkClient(baseURL, apiKey, model, endpoint, queryEndpoint), nil
|
||
case "openai":
|
||
// OpenAI Sora 使用 /v1/videos 端点
|
||
return video.NewOpenAISoraClient(baseURL, apiKey, model), nil
|
||
case "runway":
|
||
return video.NewRunwayClient(baseURL, apiKey, model), nil
|
||
case "pika":
|
||
return video.NewPikaClient(baseURL, apiKey, model), nil
|
||
case "minimax":
|
||
return video.NewMinimaxClient(baseURL, apiKey, model), nil
|
||
default:
|
||
return nil, fmt.Errorf("unsupported video provider: %s", provider)
|
||
}
|
||
}
|
||
|
||
func (s *VideoGenerationService) RecoverPendingTasks() {
|
||
var pendingVideos []models.VideoGeneration
|
||
if err := s.db.Where("status = ? AND task_id != ''", models.VideoStatusProcessing).Find(&pendingVideos).Error; err != nil {
|
||
s.log.Errorw("Failed to load pending video tasks", "error", err)
|
||
return
|
||
}
|
||
|
||
s.log.Infow("Recovering pending video generation tasks", "count", len(pendingVideos))
|
||
|
||
for _, videoGen := range pendingVideos {
|
||
go s.pollTaskStatus(videoGen.ID, *videoGen.TaskID, videoGen.Provider, videoGen.Model)
|
||
}
|
||
}
|
||
|
||
func (s *VideoGenerationService) GetVideoGeneration(id uint) (*models.VideoGeneration, error) {
|
||
var videoGen models.VideoGeneration
|
||
if err := s.db.First(&videoGen, id).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
return &videoGen, nil
|
||
}
|
||
|
||
func (s *VideoGenerationService) ListVideoGenerations(dramaID *uint, storyboardID *uint, status string, limit int, offset int) ([]*models.VideoGeneration, int64, error) {
|
||
var videos []*models.VideoGeneration
|
||
var total int64
|
||
|
||
query := s.db.Model(&models.VideoGeneration{})
|
||
|
||
if dramaID != nil {
|
||
query = query.Where("drama_id = ?", *dramaID)
|
||
}
|
||
if storyboardID != nil {
|
||
query = query.Where("storyboard_id = ?", *storyboardID)
|
||
}
|
||
if status != "" {
|
||
query = query.Where("status = ?", status)
|
||
}
|
||
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
if err := query.Order("created_at DESC").Limit(limit).Offset(offset).Find(&videos).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return videos, total, nil
|
||
}
|
||
|
||
func (s *VideoGenerationService) GenerateVideoFromImage(imageGenID uint) (*models.VideoGeneration, error) {
|
||
var imageGen models.ImageGeneration
|
||
if err := s.db.First(&imageGen, imageGenID).Error; err != nil {
|
||
return nil, fmt.Errorf("image generation not found")
|
||
}
|
||
|
||
if imageGen.Status != models.ImageStatusCompleted || imageGen.ImageURL == nil {
|
||
return nil, fmt.Errorf("image is not ready")
|
||
}
|
||
|
||
// 获取关联的Storyboard以获取时长
|
||
var duration *int
|
||
if imageGen.StoryboardID != nil {
|
||
var storyboard models.Storyboard
|
||
if err := s.db.Where("id = ?", *imageGen.StoryboardID).First(&storyboard).Error; err == nil {
|
||
duration = &storyboard.Duration
|
||
s.log.Infow("Using storyboard duration for video generation",
|
||
"storyboard_id", *imageGen.StoryboardID,
|
||
"duration", storyboard.Duration)
|
||
}
|
||
}
|
||
|
||
req := &GenerateVideoRequest{
|
||
DramaID: fmt.Sprintf("%d", imageGen.DramaID),
|
||
StoryboardID: imageGen.StoryboardID,
|
||
ImageGenID: &imageGenID,
|
||
ImageURL: *imageGen.ImageURL,
|
||
Prompt: imageGen.Prompt,
|
||
Provider: "doubao",
|
||
Duration: duration,
|
||
}
|
||
|
||
return s.GenerateVideo(req)
|
||
}
|
||
|
||
func (s *VideoGenerationService) BatchGenerateVideosForEpisode(episodeID string) ([]*models.VideoGeneration, error) {
|
||
var episode models.Episode
|
||
if err := s.db.Preload("Storyboards").Where("id = ?", episodeID).First(&episode).Error; err != nil {
|
||
return nil, fmt.Errorf("episode not found")
|
||
}
|
||
|
||
var results []*models.VideoGeneration
|
||
for _, storyboard := range episode.Storyboards {
|
||
if storyboard.ImagePrompt == nil {
|
||
continue
|
||
}
|
||
|
||
var imageGen models.ImageGeneration
|
||
if err := s.db.Where("storyboard_id = ? AND status = ?", storyboard.ID, models.ImageStatusCompleted).
|
||
Order("created_at DESC").First(&imageGen).Error; err != nil {
|
||
s.log.Warnw("No completed image for storyboard", "storyboard_id", storyboard.ID)
|
||
continue
|
||
}
|
||
|
||
videoGen, err := s.GenerateVideoFromImage(imageGen.ID)
|
||
if err != nil {
|
||
s.log.Errorw("Failed to generate video", "storyboard_id", storyboard.ID, "error", err)
|
||
continue
|
||
}
|
||
|
||
results = append(results, videoGen)
|
||
}
|
||
|
||
return results, nil
|
||
}
|
||
|
||
func (s *VideoGenerationService) DeleteVideoGeneration(id uint) error {
|
||
return s.db.Delete(&models.VideoGeneration{}, id).Error
|
||
}
|