- 新增从视频素材提取首帧/尾帧的功能,支持画面连续性编辑 - 添加阿里云OSS存储支持,可配置本地或OSS存储方式 - 导入视频素材时自动探测并更新视频时长信息 - 前端添加从素材提取尾帧的UI界面 - 添加FramePrompt模型的数据库迁移 Co-Authored-By: Claude <noreply@anthropic.com>
446 lines
13 KiB
Go
446 lines
13 KiB
Go
package services
|
||
|
||
import (
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"strconv"
|
||
"strings"
|
||
"time"
|
||
|
||
models "github.com/drama-generator/backend/domain/models"
|
||
"github.com/drama-generator/backend/infrastructure/external/ffmpeg"
|
||
"github.com/drama-generator/backend/infrastructure/storage"
|
||
"github.com/drama-generator/backend/pkg/config"
|
||
"github.com/drama-generator/backend/pkg/logger"
|
||
"gorm.io/gorm"
|
||
)
|
||
|
||
type AssetService struct {
|
||
db *gorm.DB
|
||
log *logger.Logger
|
||
ffmpeg *ffmpeg.FFmpeg
|
||
cfg *config.Config
|
||
ossStorage *storage.OssStorage
|
||
}
|
||
|
||
func NewAssetService(db *gorm.DB, log *logger.Logger, cfg ...*config.Config) *AssetService {
|
||
service := &AssetService{
|
||
db: db,
|
||
log: log,
|
||
ffmpeg: ffmpeg.NewFFmpeg(log),
|
||
}
|
||
if len(cfg) > 0 {
|
||
service.cfg = cfg[0]
|
||
// 如果配置了 OSS,初始化 OSS 存储
|
||
if cfg[0].Storage.Type == "oss" && storage.IsOssConfigured(&cfg[0].Storage.Oss) {
|
||
ossStorage, err := storage.NewOssStorage(&cfg[0].Storage.Oss)
|
||
if err != nil {
|
||
log.Warnw("Failed to initialize OSS storage, falling back to local", "error", err)
|
||
} else {
|
||
service.ossStorage = ossStorage
|
||
log.Infow("OSS storage initialized", "bucket", cfg[0].Storage.Oss.BucketName)
|
||
}
|
||
}
|
||
}
|
||
return service
|
||
}
|
||
|
||
type CreateAssetRequest struct {
|
||
DramaID *string `json:"drama_id"`
|
||
Name string `json:"name" binding:"required"`
|
||
Description *string `json:"description"`
|
||
Type models.AssetType `json:"type" binding:"required"`
|
||
Category *string `json:"category"`
|
||
URL string `json:"url" binding:"required"`
|
||
ThumbnailURL *string `json:"thumbnail_url"`
|
||
LocalPath *string `json:"local_path"`
|
||
FileSize *int64 `json:"file_size"`
|
||
MimeType *string `json:"mime_type"`
|
||
Width *int `json:"width"`
|
||
Height *int `json:"height"`
|
||
Duration *int `json:"duration"`
|
||
Format *string `json:"format"`
|
||
ImageGenID *uint `json:"image_gen_id"`
|
||
VideoGenID *uint `json:"video_gen_id"`
|
||
TagIDs []uint `json:"tag_ids"`
|
||
}
|
||
|
||
// UpdateAssetRequest is the JSON payload accepted by UpdateAsset.
// Every field is optional: a nil pointer leaves the corresponding column untouched.
type UpdateAssetRequest struct {
	Name         *string `json:"name"`
	Description  *string `json:"description"`
	Category     *string `json:"category"`
	ThumbnailURL *string `json:"thumbnail_url"`
	// NOTE(review): UpdateAsset in this file does not read TagIDs.
	TagIDs     []uint `json:"tag_ids"`
	IsFavorite *bool  `json:"is_favorite"`
}
|
||
|
||
type ListAssetsRequest struct {
|
||
DramaID *string `json:"drama_id"`
|
||
EpisodeID *uint `json:"episode_id"`
|
||
StoryboardID *uint `json:"storyboard_id"`
|
||
Type *models.AssetType `json:"type"`
|
||
Category string `json:"category"`
|
||
TagIDs []uint `json:"tag_ids"`
|
||
IsFavorite *bool `json:"is_favorite"`
|
||
Search string `json:"search"`
|
||
Page int `json:"page"`
|
||
PageSize int `json:"page_size"`
|
||
}
|
||
|
||
func (s *AssetService) CreateAsset(req *CreateAssetRequest) (*models.Asset, error) {
|
||
var dramaID *uint
|
||
if req.DramaID != nil && *req.DramaID != "" {
|
||
id, err := strconv.ParseUint(*req.DramaID, 10, 32)
|
||
if err == nil {
|
||
uid := uint(id)
|
||
dramaID = &uid
|
||
}
|
||
}
|
||
|
||
if dramaID != nil {
|
||
var drama models.Drama
|
||
if err := s.db.Where("id = ?", *dramaID).First(&drama).Error; err != nil {
|
||
return nil, fmt.Errorf("drama not found")
|
||
}
|
||
}
|
||
|
||
asset := &models.Asset{
|
||
DramaID: dramaID,
|
||
Name: req.Name,
|
||
Description: req.Description,
|
||
Type: req.Type,
|
||
Category: req.Category,
|
||
URL: req.URL,
|
||
ThumbnailURL: req.ThumbnailURL,
|
||
LocalPath: req.LocalPath,
|
||
FileSize: req.FileSize,
|
||
MimeType: req.MimeType,
|
||
Width: req.Width,
|
||
Height: req.Height,
|
||
Duration: req.Duration,
|
||
Format: req.Format,
|
||
ImageGenID: req.ImageGenID,
|
||
VideoGenID: req.VideoGenID,
|
||
}
|
||
|
||
if err := s.db.Create(asset).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to create asset: %w", err)
|
||
}
|
||
|
||
return asset, nil
|
||
}
|
||
|
||
func (s *AssetService) UpdateAsset(assetID uint, req *UpdateAssetRequest) (*models.Asset, error) {
|
||
var asset models.Asset
|
||
if err := s.db.Where("id = ?", assetID).First(&asset).Error; err != nil {
|
||
return nil, fmt.Errorf("asset not found")
|
||
}
|
||
|
||
updates := make(map[string]interface{})
|
||
if req.Name != nil {
|
||
updates["name"] = *req.Name
|
||
}
|
||
if req.Description != nil {
|
||
updates["description"] = *req.Description
|
||
}
|
||
if req.Category != nil {
|
||
updates["category"] = *req.Category
|
||
}
|
||
if req.ThumbnailURL != nil {
|
||
updates["thumbnail_url"] = *req.ThumbnailURL
|
||
}
|
||
if req.IsFavorite != nil {
|
||
updates["is_favorite"] = *req.IsFavorite
|
||
}
|
||
|
||
if len(updates) > 0 {
|
||
if err := s.db.Model(&asset).Updates(updates).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to update asset: %w", err)
|
||
}
|
||
}
|
||
|
||
if err := s.db.First(&asset, assetID).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
return &asset, nil
|
||
}
|
||
|
||
func (s *AssetService) GetAsset(assetID uint) (*models.Asset, error) {
|
||
var asset models.Asset
|
||
if err := s.db.Where("id = ? ", assetID).First(&asset).Error; err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
s.db.Model(&asset).UpdateColumn("view_count", gorm.Expr("view_count + ?", 1))
|
||
|
||
return &asset, nil
|
||
}
|
||
|
||
func (s *AssetService) ListAssets(req *ListAssetsRequest) ([]models.Asset, int64, error) {
|
||
query := s.db.Model(&models.Asset{})
|
||
|
||
if req.DramaID != nil {
|
||
var dramaID uint64
|
||
dramaID, _ = strconv.ParseUint(*req.DramaID, 10, 32)
|
||
query = query.Where("drama_id = ?", uint(dramaID))
|
||
}
|
||
|
||
if req.EpisodeID != nil {
|
||
query = query.Where("episode_id = ?", *req.EpisodeID)
|
||
}
|
||
|
||
if req.StoryboardID != nil {
|
||
query = query.Where("storyboard_id = ?", *req.StoryboardID)
|
||
}
|
||
|
||
if req.Type != nil {
|
||
query = query.Where("type = ?", *req.Type)
|
||
}
|
||
|
||
if req.Category != "" {
|
||
query = query.Where("category = ?", req.Category)
|
||
}
|
||
|
||
if req.IsFavorite != nil {
|
||
query = query.Where("is_favorite = ?", *req.IsFavorite)
|
||
}
|
||
|
||
if req.Search != "" {
|
||
searchTerm := "%" + strings.ToLower(req.Search) + "%"
|
||
query = query.Where("LOWER(name) LIKE ? OR LOWER(description) LIKE ?", searchTerm, searchTerm)
|
||
}
|
||
|
||
var total int64
|
||
if err := query.Count(&total).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
var assets []models.Asset
|
||
offset := (req.Page - 1) * req.PageSize
|
||
if err := query.Order("created_at DESC").
|
||
Offset(offset).Limit(req.PageSize).Find(&assets).Error; err != nil {
|
||
return nil, 0, err
|
||
}
|
||
|
||
return assets, total, nil
|
||
}
|
||
|
||
func (s *AssetService) DeleteAsset(assetID uint) error {
|
||
result := s.db.Where("id = ?", assetID).Delete(&models.Asset{})
|
||
if result.Error != nil {
|
||
return result.Error
|
||
}
|
||
if result.RowsAffected == 0 {
|
||
return fmt.Errorf("asset not found")
|
||
}
|
||
return nil
|
||
}
|
||
|
||
func (s *AssetService) ImportFromImageGen(imageGenID uint) (*models.Asset, error) {
|
||
var imageGen models.ImageGeneration
|
||
if err := s.db.Where("id = ? ", imageGenID).First(&imageGen).Error; err != nil {
|
||
return nil, fmt.Errorf("image generation not found")
|
||
}
|
||
|
||
if imageGen.Status != models.ImageStatusCompleted || imageGen.ImageURL == nil {
|
||
return nil, fmt.Errorf("image is not ready")
|
||
}
|
||
|
||
dramaID := imageGen.DramaID
|
||
asset := &models.Asset{
|
||
Name: fmt.Sprintf("Image_%d", imageGen.ID),
|
||
Type: models.AssetTypeImage,
|
||
URL: *imageGen.ImageURL,
|
||
DramaID: &dramaID,
|
||
ImageGenID: &imageGenID,
|
||
Width: imageGen.Width,
|
||
Height: imageGen.Height,
|
||
}
|
||
|
||
if err := s.db.Create(asset).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to create asset: %w", err)
|
||
}
|
||
|
||
return asset, nil
|
||
}
|
||
|
||
func (s *AssetService) ImportFromVideoGen(videoGenID uint) (*models.Asset, error) {
|
||
var videoGen models.VideoGeneration
|
||
if err := s.db.Preload("Storyboard.Episode").Where("id = ? ", videoGenID).First(&videoGen).Error; err != nil {
|
||
return nil, fmt.Errorf("video generation not found")
|
||
}
|
||
|
||
if videoGen.Status != models.VideoStatusCompleted || videoGen.VideoURL == nil {
|
||
return nil, fmt.Errorf("video is not ready")
|
||
}
|
||
|
||
dramaID := videoGen.DramaID
|
||
|
||
var episodeID *uint
|
||
var storyboardNum *int
|
||
if videoGen.Storyboard != nil {
|
||
episodeID = &videoGen.Storyboard.Episode.ID
|
||
storyboardNum = &videoGen.Storyboard.StoryboardNumber
|
||
}
|
||
|
||
// 如果 duration 为空,尝试使用 FFmpeg 探测
|
||
duration := videoGen.Duration
|
||
if duration == nil || *duration == 0 {
|
||
s.log.Infow("Duration is empty, probing video duration", "video_gen_id", videoGenID)
|
||
probedDuration, err := s.ffmpeg.GetVideoDuration(*videoGen.VideoURL)
|
||
if err == nil && probedDuration > 0 {
|
||
durationInt := int(probedDuration + 0.5) // 四舍五入
|
||
duration = &durationInt
|
||
s.log.Infow("Probed video duration", "video_gen_id", videoGenID, "duration", durationInt)
|
||
|
||
// 同时更新 VideoGeneration 表
|
||
s.db.Model(&videoGen).Update("duration", durationInt)
|
||
} else {
|
||
s.log.Warnw("Failed to probe video duration", "video_gen_id", videoGenID, "error", err)
|
||
}
|
||
}
|
||
|
||
asset := &models.Asset{
|
||
Name: fmt.Sprintf("Video_%d", videoGen.ID),
|
||
Type: models.AssetTypeVideo,
|
||
URL: *videoGen.VideoURL,
|
||
DramaID: &dramaID,
|
||
EpisodeID: episodeID,
|
||
StoryboardID: videoGen.StoryboardID,
|
||
StoryboardNum: storyboardNum,
|
||
VideoGenID: &videoGenID,
|
||
Duration: duration,
|
||
Width: videoGen.Width,
|
||
Height: videoGen.Height,
|
||
}
|
||
|
||
if videoGen.FirstFrameURL != nil {
|
||
asset.ThumbnailURL = videoGen.FirstFrameURL
|
||
}
|
||
|
||
if err := s.db.Create(asset).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to create asset: %w\n", err)
|
||
}
|
||
|
||
return asset, nil
|
||
}
|
||
|
||
// ExtractFrameRequest is the payload for extracting a single frame from a video asset.
type ExtractFrameRequest struct {
	// Position selects which frame to extract: "first" or "last".
	// ExtractFrameFromAsset uses "last" when it is empty.
	Position string `json:"position"`
	// StoryboardID is the storyboard the extracted frame is associated with.
	StoryboardID uint `json:"storyboard_id"`
	// FrameType is the frame type recorded on the image_generations row; defaults to "first".
	FrameType string `json:"frame_type"`
}
|
||
|
||
// ExtractFrameFromAsset 从视频素材中提取帧并保存到 image_generations 表
|
||
func (s *AssetService) ExtractFrameFromAsset(assetID uint, req *ExtractFrameRequest) (*models.ImageGeneration, error) {
|
||
// 获取素材
|
||
var asset models.Asset
|
||
if err := s.db.First(&asset, assetID).Error; err != nil {
|
||
return nil, fmt.Errorf("asset not found: %w", err)
|
||
}
|
||
|
||
// 验证是否为视频素材
|
||
if asset.Type != models.AssetTypeVideo {
|
||
return nil, fmt.Errorf("asset is not a video")
|
||
}
|
||
|
||
// 默认值
|
||
position := req.Position
|
||
if position == "" {
|
||
position = "last"
|
||
}
|
||
frameType := req.FrameType
|
||
if frameType == "" {
|
||
frameType = "first" // 提取的尾帧用作下一个分镜的首帧
|
||
}
|
||
|
||
// 生成唯一文件名
|
||
timestamp := time.Now().Format("20060102_150405")
|
||
fileName := fmt.Sprintf("frame_%s_%d_%s.jpg", position, assetID, timestamp)
|
||
|
||
// 创建临时目录
|
||
tmpDir := "./data/tmp"
|
||
if err := os.MkdirAll(tmpDir, 0755); err != nil {
|
||
return nil, fmt.Errorf("failed to create temp directory: %w", err)
|
||
}
|
||
outputPath := filepath.Join(tmpDir, fileName)
|
||
|
||
// 使用 FFmpeg 提取帧
|
||
_, err := s.ffmpeg.ExtractFrame(asset.URL, outputPath, position)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("failed to extract frame: %w", err)
|
||
}
|
||
|
||
var imageURL string
|
||
|
||
// 根据存储类型上传
|
||
if s.ossStorage != nil {
|
||
// 上传到 OSS
|
||
url, err := s.ossStorage.UploadWithFilename(outputPath, "extracted_frames", fileName)
|
||
if err != nil {
|
||
s.log.Errorw("Failed to upload to OSS, falling back to local", "error", err)
|
||
} else {
|
||
imageURL = url
|
||
s.log.Infow("Frame uploaded to OSS", "url", url)
|
||
// 删除临时文件
|
||
os.Remove(outputPath)
|
||
}
|
||
}
|
||
|
||
// 如果 OSS 上传失败或未配置,使用本地存储
|
||
if imageURL == "" {
|
||
localPath := "./data/storage"
|
||
baseURL := "/static"
|
||
if s.cfg != nil && s.cfg.Storage.LocalPath != "" {
|
||
localPath = s.cfg.Storage.LocalPath
|
||
}
|
||
if s.cfg != nil && s.cfg.Storage.BaseURL != "" {
|
||
baseURL = s.cfg.Storage.BaseURL
|
||
}
|
||
|
||
// 移动文件到最终位置
|
||
finalDir := filepath.Join(localPath, "extracted_frames")
|
||
if err := os.MkdirAll(finalDir, 0755); err != nil {
|
||
return nil, fmt.Errorf("failed to create output directory: %w", err)
|
||
}
|
||
finalPath := filepath.Join(finalDir, fileName)
|
||
if err := os.Rename(outputPath, finalPath); err != nil {
|
||
return nil, fmt.Errorf("failed to move file: %w", err)
|
||
}
|
||
|
||
imageURL = fmt.Sprintf("%s/extracted_frames/%s", baseURL, fileName)
|
||
}
|
||
|
||
// 获取 DramaID
|
||
var dramaID uint
|
||
if asset.DramaID != nil {
|
||
dramaID = *asset.DramaID
|
||
}
|
||
|
||
// 创建 image_generation 记录
|
||
imageGen := &models.ImageGeneration{
|
||
DramaID: dramaID,
|
||
StoryboardID: &req.StoryboardID,
|
||
Prompt: fmt.Sprintf("Extracted %s frame from video asset #%d", position, assetID),
|
||
Status: models.ImageStatusCompleted,
|
||
ImageURL: &imageURL,
|
||
FrameType: &frameType,
|
||
}
|
||
|
||
if err := s.db.Create(imageGen).Error; err != nil {
|
||
return nil, fmt.Errorf("failed to create image generation record: %w", err)
|
||
}
|
||
|
||
s.log.Infow("Frame extracted and saved",
|
||
"asset_id", assetID,
|
||
"position", position,
|
||
"image_gen_id", imageGen.ID,
|
||
"frame_type", frameType,
|
||
"image_url", imageURL)
|
||
|
||
return imageGen, nil
|
||
}
|