Files
huobao-drama/application/services/asset_service.go
empty d970107a34 添加视频帧提取功能和阿里云OSS存储支持
- 新增从视频素材提取首帧/尾帧的功能,支持画面连续性编辑
- 添加阿里云OSS存储支持,可配置本地或OSS存储方式
- 导入视频素材时自动探测并更新视频时长信息
- 前端添加从素材提取尾帧的UI界面
- 添加FramePrompt模型的数据库迁移

Co-Authored-By: Claude <noreply@anthropic.com>
2026-01-18 21:44:39 +08:00

446 lines
13 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package services
import (
"fmt"
"os"
"path/filepath"
"strconv"
"strings"
"time"
models "github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/infrastructure/external/ffmpeg"
"github.com/drama-generator/backend/infrastructure/storage"
"github.com/drama-generator/backend/pkg/config"
"github.com/drama-generator/backend/pkg/logger"
"gorm.io/gorm"
)
// AssetService implements the asset-library operations: CRUD on models.Asset,
// importing finished image/video generations as assets, and extracting still
// frames from video assets.
type AssetService struct {
db *gorm.DB
log *logger.Logger
// ffmpeg is used to probe video durations and to extract frames.
ffmpeg *ffmpeg.FFmpeg
// cfg is optional; it stays nil when NewAssetService is called without a config.
cfg *config.Config
// ossStorage is non-nil only when OSS storage was configured and initialized
// successfully; otherwise extracted frames are stored on local disk.
ossStorage *storage.OssStorage
}
// NewAssetService builds an AssetService.
//
// cfg is optional and only its first element is used. When that config is
// present, non-nil, has Storage.Type == "oss" and a complete OSS
// configuration, an OSS client is created for uploading extracted frames.
// If the OSS settings are incomplete or the client cannot be created, the
// service logs a warning and falls back to local storage.
func NewAssetService(db *gorm.DB, log *logger.Logger, cfg ...*config.Config) *AssetService {
	service := &AssetService{
		db:     db,
		log:    log,
		ffmpeg: ffmpeg.NewFFmpeg(log),
	}

	// An explicitly passed nil config would otherwise panic on the field
	// accesses below; treat it the same as "no config".
	if len(cfg) == 0 || cfg[0] == nil {
		return service
	}

	conf := cfg[0]
	service.cfg = conf

	if conf.Storage.Type != "oss" {
		return service
	}
	if !storage.IsOssConfigured(&conf.Storage.Oss) {
		// Make the silent fallback visible: the operator asked for OSS.
		log.Warnw("Storage type is oss but OSS settings are incomplete, falling back to local",
			"bucket", conf.Storage.Oss.BucketName)
		return service
	}

	ossStorage, err := storage.NewOssStorage(&conf.Storage.Oss)
	if err != nil {
		log.Warnw("Failed to initialize OSS storage, falling back to local", "error", err)
		return service
	}
	service.ossStorage = ossStorage
	log.Infow("OSS storage initialized", "bucket", conf.Storage.Oss.BucketName)
	return service
}
// CreateAssetRequest is the payload for CreateAsset. Name, Type and URL are
// marked required via binding tags; the remaining fields are optional and are
// copied onto the new asset, except where noted below.
type CreateAssetRequest struct {
// DramaID is the owning drama's numeric ID encoded as a string; nil or ""
// means the asset is not attached to a drama.
DramaID *string `json:"drama_id"`
Name string `json:"name" binding:"required"`
Description *string `json:"description"`
Type models.AssetType `json:"type" binding:"required"`
Category *string `json:"category"`
URL string `json:"url" binding:"required"`
ThumbnailURL *string `json:"thumbnail_url"`
LocalPath *string `json:"local_path"`
FileSize *int64 `json:"file_size"`
MimeType *string `json:"mime_type"`
Width *int `json:"width"`
Height *int `json:"height"`
Duration *int `json:"duration"`
Format *string `json:"format"`
// ImageGenID / VideoGenID link the asset to the generation it came from, if any.
ImageGenID *uint `json:"image_gen_id"`
VideoGenID *uint `json:"video_gen_id"`
// NOTE(review): TagIDs is accepted but CreateAsset does not persist it.
TagIDs []uint `json:"tag_ids"`
}
// UpdateAssetRequest is the payload for UpdateAsset. Every field is optional;
// only non-nil fields are written to the asset.
type UpdateAssetRequest struct {
Name *string `json:"name"`
Description *string `json:"description"`
Category *string `json:"category"`
ThumbnailURL *string `json:"thumbnail_url"`
// NOTE(review): TagIDs is accepted but UpdateAsset does not apply it.
TagIDs []uint `json:"tag_ids"`
IsFavorite *bool `json:"is_favorite"`
}
// ListAssetsRequest holds the filters and pagination for ListAssets. Nil
// pointers and empty strings mean "do not filter on this field".
type ListAssetsRequest struct {
// DramaID is the drama's numeric ID encoded as a string.
DramaID *string `json:"drama_id"`
EpisodeID *uint `json:"episode_id"`
StoryboardID *uint `json:"storyboard_id"`
Type *models.AssetType `json:"type"`
Category string `json:"category"`
// NOTE(review): TagIDs is accepted but ListAssets does not filter on it.
TagIDs []uint `json:"tag_ids"`
IsFavorite *bool `json:"is_favorite"`
// Search is matched case-insensitively as a substring of name or description.
Search string `json:"search"`
// Page is 1-based; the offset is computed as (Page-1)*PageSize.
Page int `json:"page"`
PageSize int `json:"page_size"`
}
// CreateAsset validates req and inserts a new asset.
//
// req.DramaID, when non-nil and non-empty, must be a decimal drama ID that
// refers to an existing drama. A malformed ID is rejected rather than
// silently ignored, because ignoring it would create an asset detached from
// the drama the caller asked for. req.TagIDs is not persisted.
//
// It returns the created asset, or an error when the drama ID is malformed,
// the drama cannot be loaded, or the insert fails.
func (s *AssetService) CreateAsset(req *CreateAssetRequest) (*models.Asset, error) {
	var dramaID *uint
	if req.DramaID != nil && *req.DramaID != "" {
		id, err := strconv.ParseUint(*req.DramaID, 10, 32)
		if err != nil {
			return nil, fmt.Errorf("invalid drama_id %q: %w", *req.DramaID, err)
		}
		uid := uint(id)

		var drama models.Drama
		if err := s.db.Where("id = ?", uid).First(&drama).Error; err != nil {
			return nil, fmt.Errorf("drama not found")
		}
		dramaID = &uid
	}

	asset := &models.Asset{
		DramaID:      dramaID,
		Name:         req.Name,
		Description:  req.Description,
		Type:         req.Type,
		Category:     req.Category,
		URL:          req.URL,
		ThumbnailURL: req.ThumbnailURL,
		LocalPath:    req.LocalPath,
		FileSize:     req.FileSize,
		MimeType:     req.MimeType,
		Width:        req.Width,
		Height:       req.Height,
		Duration:     req.Duration,
		Format:       req.Format,
		ImageGenID:   req.ImageGenID,
		VideoGenID:   req.VideoGenID,
	}
	if err := s.db.Create(asset).Error; err != nil {
		return nil, fmt.Errorf("failed to create asset: %w", err)
	}
	return asset, nil
}
// UpdateAsset writes the non-nil fields of req onto the asset identified by
// assetID and returns the row as re-read from the database.
//
// Any failure to load the asset is reported as "asset not found". When req
// sets no fields, no UPDATE is issued and the current row is simply re-read.
// req.TagIDs is ignored.
func (s *AssetService) UpdateAsset(assetID uint, req *UpdateAssetRequest) (*models.Asset, error) {
	var existing models.Asset
	lookup := s.db.Where("id = ?", assetID).First(&existing)
	if lookup.Error != nil {
		return nil, fmt.Errorf("asset not found")
	}

	changes := map[string]interface{}{}
	if v := req.Name; v != nil {
		changes["name"] = *v
	}
	if v := req.Description; v != nil {
		changes["description"] = *v
	}
	if v := req.Category; v != nil {
		changes["category"] = *v
	}
	if v := req.ThumbnailURL; v != nil {
		changes["thumbnail_url"] = *v
	}
	if v := req.IsFavorite; v != nil {
		changes["is_favorite"] = *v
	}

	if len(changes) != 0 {
		if err := s.db.Model(&existing).Updates(changes).Error; err != nil {
			return nil, fmt.Errorf("failed to update asset: %w", err)
		}
	}

	// Reload so the caller sees the persisted state.
	if err := s.db.First(&existing, assetID).Error; err != nil {
		return nil, err
	}
	return &existing, nil
}
// GetAsset loads the asset with the given ID and increments its view_count
// column.
//
// The increment is best-effort: if it fails, the error is logged and the
// loaded asset is still returned. Errors from the lookup itself are returned
// unchanged.
func (s *AssetService) GetAsset(assetID uint) (*models.Asset, error) {
	var asset models.Asset
	if err := s.db.Where("id = ?", assetID).First(&asset).Error; err != nil {
		return nil, err
	}

	bump := s.db.Model(&asset).UpdateColumn("view_count", gorm.Expr("view_count + ?", 1))
	if bump.Error != nil {
		s.log.Warnw("Failed to increment asset view count", "asset_id", assetID, "error", bump.Error)
	}
	return &asset, nil
}
// ListAssets returns one page of assets matching req, newest first, together
// with the total number of matching assets (ignoring pagination).
//
// Filters applied:
//   - DramaID: decimal string; nil or "" means no drama filter, and a
//     malformed value is an error.
//   - EpisodeID, StoryboardID, Type, Category and IsFavorite.
//   - Search: case-insensitive substring match on name or description.
//
// req.TagIDs is not applied.
//
// Pagination is 1-based. Page < 1 is treated as 1, and PageSize < 1 falls
// back to defaultPageSize. A zero-valued request therefore returns the first
// page instead of an empty result (LIMIT 0).
func (s *AssetService) ListAssets(req *ListAssetsRequest) ([]models.Asset, int64, error) {
	const defaultPageSize = 20

	query := s.db.Model(&models.Asset{})

	if req.DramaID != nil && *req.DramaID != "" {
		dramaID, err := strconv.ParseUint(*req.DramaID, 10, 32)
		if err != nil {
			return nil, 0, fmt.Errorf("invalid drama_id %q: %w", *req.DramaID, err)
		}
		query = query.Where("drama_id = ?", uint(dramaID))
	}
	if req.EpisodeID != nil {
		query = query.Where("episode_id = ?", *req.EpisodeID)
	}
	if req.StoryboardID != nil {
		query = query.Where("storyboard_id = ?", *req.StoryboardID)
	}
	if req.Type != nil {
		query = query.Where("type = ?", *req.Type)
	}
	if req.Category != "" {
		query = query.Where("category = ?", req.Category)
	}
	if req.IsFavorite != nil {
		query = query.Where("is_favorite = ?", *req.IsFavorite)
	}
	if req.Search != "" {
		searchTerm := "%" + strings.ToLower(req.Search) + "%"
		// Explicit parentheses keep the OR from binding with the other
		// AND-ed filters regardless of how the ORM composes clauses.
		query = query.Where("(LOWER(name) LIKE ? OR LOWER(description) LIKE ?)", searchTerm, searchTerm)
	}

	var total int64
	if err := query.Count(&total).Error; err != nil {
		return nil, 0, err
	}

	page := req.Page
	if page < 1 {
		page = 1
	}
	pageSize := req.PageSize
	if pageSize < 1 {
		pageSize = defaultPageSize
	}

	var assets []models.Asset
	if err := query.Order("created_at DESC").
		Offset((page - 1) * pageSize).
		Limit(pageSize).
		Find(&assets).Error; err != nil {
		return nil, 0, err
	}
	return assets, total, nil
}
// DeleteAsset deletes the asset with the given ID.
//
// It returns the database error if the delete fails, and "asset not found"
// when no row matched assetID.
func (s *AssetService) DeleteAsset(assetID uint) error {
	res := s.db.Where("id = ?", assetID).Delete(&models.Asset{})
	switch {
	case res.Error != nil:
		return res.Error
	case res.RowsAffected == 0:
		return fmt.Errorf("asset not found")
	default:
		return nil
	}
}
// ImportFromImageGen copies a finished image generation into the asset
// library as an image asset named "Image_<id>".
//
// The asset takes the generation's image URL, drama and dimensions, and links
// back to the generation through ImageGenID. It fails with "image generation
// not found" if the record cannot be loaded. It fails with "image is not
// ready" unless the generation is completed and has an image URL.
func (s *AssetService) ImportFromImageGen(imageGenID uint) (*models.Asset, error) {
	var gen models.ImageGeneration
	if err := s.db.Where("id = ? ", imageGenID).First(&gen).Error; err != nil {
		return nil, fmt.Errorf("image generation not found")
	}

	ready := gen.Status == models.ImageStatusCompleted && gen.ImageURL != nil
	if !ready {
		return nil, fmt.Errorf("image is not ready")
	}

	dramaID := gen.DramaID
	asset := &models.Asset{
		Type:       models.AssetTypeImage,
		Name:       fmt.Sprintf("Image_%d", gen.ID),
		URL:        *gen.ImageURL,
		Width:      gen.Width,
		Height:     gen.Height,
		DramaID:    &dramaID,
		ImageGenID: &imageGenID,
	}

	if err := s.db.Create(asset).Error; err != nil {
		return nil, fmt.Errorf("failed to create asset: %w", err)
	}
	return asset, nil
}
// ImportFromVideoGen copies a finished video generation into the asset
// library as a video asset named "Video_<id>".
//
// The asset inherits the generation's drama, storyboard, size and first-frame
// thumbnail. When a storyboard is linked, it also takes the storyboard's
// episode ID and storyboard number.
//
// If the generation has no recorded duration (nil or 0), the duration is
// probed with FFmpeg from the video URL and rounded to the nearest integer.
// The probed value is then written back to the video generation; that
// write-back is best-effort and only logged on failure.
//
// It fails with "video generation not found" if the record cannot be loaded.
// It fails with "video is not ready" unless the generation is completed and
// has a video URL.
func (s *AssetService) ImportFromVideoGen(videoGenID uint) (*models.Asset, error) {
	var videoGen models.VideoGeneration
	if err := s.db.Preload("Storyboard.Episode").Where("id = ? ", videoGenID).First(&videoGen).Error; err != nil {
		return nil, fmt.Errorf("video generation not found")
	}
	if videoGen.Status != models.VideoStatusCompleted || videoGen.VideoURL == nil {
		return nil, fmt.Errorf("video is not ready")
	}

	dramaID := videoGen.DramaID
	var episodeID *uint
	var storyboardNum *int
	if videoGen.Storyboard != nil {
		// NOTE(review): relies on the "Storyboard.Episode" preload above; confirm
		// Episode is always populated when a storyboard is linked.
		episodeID = &videoGen.Storyboard.Episode.ID
		storyboardNum = &videoGen.Storyboard.StoryboardNumber
	}

	// If the duration is missing, try to probe it with FFmpeg.
	duration := videoGen.Duration
	if duration == nil || *duration == 0 {
		s.log.Infow("Duration is empty, probing video duration", "video_gen_id", videoGenID)
		probed, err := s.ffmpeg.GetVideoDuration(*videoGen.VideoURL)
		if err != nil || probed <= 0 {
			s.log.Warnw("Failed to probe video duration", "video_gen_id", videoGenID, "error", err)
		} else {
			rounded := int(probed + 0.5) // round half up
			duration = &rounded
			s.log.Infow("Probed video duration", "video_gen_id", videoGenID, "duration", rounded)
			// Persist the probed value on the video generation as well; a
			// failure here must not block the import, but should be visible.
			if err := s.db.Model(&videoGen).Update("duration", rounded).Error; err != nil {
				s.log.Warnw("Failed to persist probed video duration", "video_gen_id", videoGenID, "error", err)
			}
		}
	}

	asset := &models.Asset{
		Name:          fmt.Sprintf("Video_%d", videoGen.ID),
		Type:          models.AssetTypeVideo,
		URL:           *videoGen.VideoURL,
		DramaID:       &dramaID,
		EpisodeID:     episodeID,
		StoryboardID:  videoGen.StoryboardID,
		StoryboardNum: storyboardNum,
		VideoGenID:    &videoGenID,
		Duration:      duration,
		Width:         videoGen.Width,
		Height:        videoGen.Height,
	}
	if videoGen.FirstFrameURL != nil {
		asset.ThumbnailURL = videoGen.FirstFrameURL
	}

	if err := s.db.Create(asset).Error; err != nil {
		// No trailing newline: error strings are composed by callers.
		return nil, fmt.Errorf("failed to create asset: %w", err)
	}
	return asset, nil
}
// ExtractFrameRequest is the payload for ExtractFrameFromAsset (video frame
// extraction).
type ExtractFrameRequest struct {
Position string `json:"position"` // which frame to extract: "first" or "last"; "" defaults to "last"
StoryboardID uint `json:"storyboard_id"` // ID of the storyboard the extracted frame is attached to
FrameType string `json:"frame_type"` // frame type stored on the image_generations row; "" defaults to "first"
}
// ExtractFrameFromAsset extracts one still frame from a video asset and
// records it as a completed row in the image_generations table. A typical use
// is the last frame becoming the next storyboard's first frame, for visual
// continuity.
//
// Request fields (a nil req is treated as all defaults):
//   - Position selects the frame: "first" or "last" (default "last"). Any
//     other value is rejected.
//   - FrameType is stored on the image generation (default "first").
//   - StoryboardID links the row to a storyboard; 0 means no storyboard.
//
// The frame is written to ./data/tmp by FFmpeg and uploaded to OSS when OSS
// storage is available. If OSS is not configured or the upload fails, the
// file is moved under <Storage.LocalPath>/extracted_frames (default
// ./data/storage). Its URL is then <Storage.BaseURL>/extracted_frames/<file>
// (default base /static). The temporary file is removed before returning on
// every path.
func (s *AssetService) ExtractFrameFromAsset(assetID uint, req *ExtractFrameRequest) (*models.ImageGeneration, error) {
	if req == nil {
		req = &ExtractFrameRequest{}
	}

	var asset models.Asset
	if err := s.db.First(&asset, assetID).Error; err != nil {
		return nil, fmt.Errorf("asset not found: %w", err)
	}
	if asset.Type != models.AssetTypeVideo {
		return nil, fmt.Errorf("asset is not a video")
	}

	position := req.Position
	if position == "" {
		position = "last"
	}
	// position comes from the client and is interpolated into a file path.
	// Restrict it to the supported values so input like "../x" cannot escape
	// the temp/storage directories.
	if position != "first" && position != "last" {
		return nil, fmt.Errorf("invalid position %q: must be \"first\" or \"last\"", position)
	}

	frameType := req.FrameType
	if frameType == "" {
		frameType = "first" // the extracted (tail) frame is used as the next storyboard's first frame
	}

	// Sub-second precision keeps two extractions of the same asset within one
	// second from overwriting each other's file.
	now := time.Now()
	fileName := fmt.Sprintf("frame_%s_%d_%s_%09d.jpg", position, assetID, now.Format("20060102_150405"), now.Nanosecond())

	tmpDir := "./data/tmp"
	if err := os.MkdirAll(tmpDir, 0755); err != nil {
		return nil, fmt.Errorf("failed to create temp directory: %w", err)
	}
	outputPath := filepath.Join(tmpDir, fileName)
	// Remove the temp file on every exit path, including FFmpeg, mkdir and
	// rename failures. After a successful upload or move the file is already
	// gone and this is a harmless no-op.
	defer func() { _ = os.Remove(outputPath) }()

	if _, err := s.ffmpeg.ExtractFrame(asset.URL, outputPath, position); err != nil {
		return nil, fmt.Errorf("failed to extract frame: %w", err)
	}

	var imageURL string
	if s.ossStorage != nil {
		url, err := s.ossStorage.UploadWithFilename(outputPath, "extracted_frames", fileName)
		if err != nil {
			s.log.Errorw("Failed to upload to OSS, falling back to local", "error", err)
		} else {
			imageURL = url
			s.log.Infow("Frame uploaded to OSS", "url", url)
		}
	}

	// Local storage: used when OSS is not configured or the upload failed.
	if imageURL == "" {
		localPath := "./data/storage"
		baseURL := "/static"
		if s.cfg != nil && s.cfg.Storage.LocalPath != "" {
			localPath = s.cfg.Storage.LocalPath
		}
		if s.cfg != nil && s.cfg.Storage.BaseURL != "" {
			baseURL = s.cfg.Storage.BaseURL
		}

		finalDir := filepath.Join(localPath, "extracted_frames")
		if err := os.MkdirAll(finalDir, 0755); err != nil {
			return nil, fmt.Errorf("failed to create output directory: %w", err)
		}
		if err := os.Rename(outputPath, filepath.Join(finalDir, fileName)); err != nil {
			return nil, fmt.Errorf("failed to move file: %w", err)
		}
		// Avoid a double slash when the configured base URL ends with "/".
		imageURL = fmt.Sprintf("%s/extracted_frames/%s", strings.TrimSuffix(baseURL, "/"), fileName)
	}

	// NOTE(review): an asset without a drama yields DramaID 0 here; confirm
	// that is acceptable for image_generations.
	var dramaID uint
	if asset.DramaID != nil {
		dramaID = *asset.DramaID
	}

	// Copy the storyboard ID instead of aliasing the caller's request, and
	// leave it NULL rather than storing 0 when no storyboard was given.
	var storyboardID *uint
	if req.StoryboardID != 0 {
		sid := req.StoryboardID
		storyboardID = &sid
	}

	imageGen := &models.ImageGeneration{
		DramaID:      dramaID,
		StoryboardID: storyboardID,
		Prompt:       fmt.Sprintf("Extracted %s frame from video asset #%d", position, assetID),
		Status:       models.ImageStatusCompleted,
		ImageURL:     &imageURL,
		FrameType:    &frameType,
	}
	if err := s.db.Create(imageGen).Error; err != nil {
		return nil, fmt.Errorf("failed to create image generation record: %w", err)
	}

	s.log.Infow("Frame extracted and saved",
		"asset_id", assetID,
		"position", position,
		"image_gen_id", imageGen.ID,
		"frame_type", frameType,
		"image_url", imageURL)
	return imageGen, nil
}