Files
huobao-drama/application/services/drama_service.go
Connor 9600fc542c init
2026-01-12 13:17:11 +08:00

631 lines
18 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package services
import (
"encoding/json"
"errors"
"fmt"
"strconv"
"time"
"github.com/drama-generator/backend/domain/models"
"github.com/drama-generator/backend/pkg/logger"
"gorm.io/gorm"
)
type DramaService struct {
db *gorm.DB
log *logger.Logger
}
func NewDramaService(db *gorm.DB, log *logger.Logger) *DramaService {
return &DramaService{
db: db,
log: log,
}
}
type CreateDramaRequest struct {
Title string `json:"title" binding:"required,min=1,max=100"`
Description string `json:"description"`
Genre string `json:"genre"`
Tags string `json:"tags"`
}
type UpdateDramaRequest struct {
Title string `json:"title" binding:"omitempty,min=1,max=100"`
Description string `json:"description"`
Genre string `json:"genre"`
Tags string `json:"tags"`
Status string `json:"status" binding:"omitempty,oneof=draft planning production completed archived"`
}
type DramaListQuery struct {
Page int `form:"page,default=1"`
PageSize int `form:"page_size,default=20"`
Status string `form:"status"`
Genre string `form:"genre"`
Keyword string `form:"keyword"`
}
func (s *DramaService) CreateDrama(req *CreateDramaRequest) (*models.Drama, error) {
drama := &models.Drama{
Title: req.Title,
Status: "draft",
}
if req.Description != "" {
drama.Description = &req.Description
}
if req.Genre != "" {
drama.Genre = &req.Genre
}
if err := s.db.Create(drama).Error; err != nil {
s.log.Errorw("Failed to create drama", "error", err)
return nil, err
}
s.log.Infow("Drama created", "drama_id", drama.ID)
return drama, nil
}
func (s *DramaService) GetDrama(dramaID string) (*models.Drama, error) {
var drama models.Drama
err := s.db.Where("id = ? ", dramaID).
Preload("Characters"). // 加载Drama级别的角色
Preload("Scenes"). // 加载Drama级别的场景
Preload("Episodes.Characters"). // 加载每个章节关联的角色
Preload("Episodes.Scenes"). // 加载每个章节关联的场景
Preload("Episodes.Storyboards", func(db *gorm.DB) *gorm.DB {
return db.Order("storyboards.storyboard_number ASC")
}).
First(&drama).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("drama not found")
}
s.log.Errorw("Failed to get drama", "error", err)
return nil, err
}
// 统计每个剧集的时长(基于场景时长之和)
for i := range drama.Episodes {
totalDuration := 0
for _, scene := range drama.Episodes[i].Storyboards {
totalDuration += scene.Duration
}
// 更新剧集时长(秒转分钟,向上取整)
durationMinutes := (totalDuration + 59) / 60
drama.Episodes[i].Duration = durationMinutes
// 如果数据库中的时长与计算的不一致,更新数据库
if drama.Episodes[i].Duration != durationMinutes {
s.db.Model(&models.Episode{}).Where("id = ?", drama.Episodes[i].ID).Update("duration", durationMinutes)
}
// 查询角色的图片生成状态
for j := range drama.Episodes[i].Characters {
var imageGen models.ImageGeneration
err := s.db.Where("character_id = ? AND (status = ? OR status = ?)",
drama.Episodes[i].Characters[j].ID, "pending", "processing").
Order("created_at DESC").
First(&imageGen).Error
if err == nil {
// 找到生成中的记录,设置状态
statusStr := string(imageGen.Status)
drama.Episodes[i].Characters[j].ImageGenerationStatus = &statusStr
if imageGen.ErrorMsg != nil {
drama.Episodes[i].Characters[j].ImageGenerationError = imageGen.ErrorMsg
}
} else if errors.Is(err, gorm.ErrRecordNotFound) {
// 检查是否有失败的记录
err := s.db.Where("character_id = ? AND status = ?",
drama.Episodes[i].Characters[j].ID, "failed").
Order("created_at DESC").
First(&imageGen).Error
if err == nil {
statusStr := string(imageGen.Status)
drama.Episodes[i].Characters[j].ImageGenerationStatus = &statusStr
if imageGen.ErrorMsg != nil {
drama.Episodes[i].Characters[j].ImageGenerationError = imageGen.ErrorMsg
}
}
}
}
// 查询场景的图片生成状态
for j := range drama.Episodes[i].Scenes {
var imageGen models.ImageGeneration
err := s.db.Where("scene_id = ? AND (status = ? OR status = ?)",
drama.Episodes[i].Scenes[j].ID, "pending", "processing").
Order("created_at DESC").
First(&imageGen).Error
if err == nil {
// 找到生成中的记录,设置状态
statusStr := string(imageGen.Status)
drama.Episodes[i].Scenes[j].ImageGenerationStatus = &statusStr
if imageGen.ErrorMsg != nil {
drama.Episodes[i].Scenes[j].ImageGenerationError = imageGen.ErrorMsg
}
} else if errors.Is(err, gorm.ErrRecordNotFound) {
// 检查是否有失败的记录
err := s.db.Where("scene_id = ? AND status = ?",
drama.Episodes[i].Scenes[j].ID, "failed").
Order("created_at DESC").
First(&imageGen).Error
if err == nil {
statusStr := string(imageGen.Status)
drama.Episodes[i].Scenes[j].ImageGenerationStatus = &statusStr
if imageGen.ErrorMsg != nil {
drama.Episodes[i].Scenes[j].ImageGenerationError = imageGen.ErrorMsg
}
}
}
}
}
// 整合所有剧集的场景到Drama级别的Scenes字段
sceneMap := make(map[uint]*models.Scene) // 用于去重
for i := range drama.Episodes {
for j := range drama.Episodes[i].Scenes {
scene := &drama.Episodes[i].Scenes[j]
sceneMap[scene.ID] = scene
}
}
// 将整合的场景添加到drama.Scenes
drama.Scenes = make([]models.Scene, 0, len(sceneMap))
for _, scene := range sceneMap {
drama.Scenes = append(drama.Scenes, *scene)
}
return &drama, nil
}
func (s *DramaService) ListDramas(query *DramaListQuery) ([]models.Drama, int64, error) {
var dramas []models.Drama
var total int64
db := s.db.Model(&models.Drama{})
if query.Status != "" {
db = db.Where("status = ?", query.Status)
}
if query.Genre != "" {
db = db.Where("genre = ?", query.Genre)
}
if query.Keyword != "" {
db = db.Where("title LIKE ? OR description LIKE ?", "%"+query.Keyword+"%", "%"+query.Keyword+"%")
}
if err := db.Count(&total).Error; err != nil {
s.log.Errorw("Failed to count dramas", "error", err)
return nil, 0, err
}
offset := (query.Page - 1) * query.PageSize
err := db.Order("updated_at DESC").
Offset(offset).
Limit(query.PageSize).
Preload("Episodes.Storyboards", func(db *gorm.DB) *gorm.DB {
return db.Order("storyboards.storyboard_number ASC")
}).
Find(&dramas).Error
if err != nil {
s.log.Errorw("Failed to list dramas", "error", err)
return nil, 0, err
}
// 统计每个剧本的每个剧集的时长(基于场景时长之和)
for i := range dramas {
for j := range dramas[i].Episodes {
totalDuration := 0
for _, scene := range dramas[i].Episodes[j].Storyboards {
totalDuration += scene.Duration
}
// 更新剧集时长(秒转分钟,向上取整)
durationMinutes := (totalDuration + 59) / 60
dramas[i].Episodes[j].Duration = durationMinutes
}
}
return dramas, total, nil
}
func (s *DramaService) UpdateDrama(dramaID string, req *UpdateDramaRequest) (*models.Drama, error) {
var drama models.Drama
if err := s.db.Where("id = ? ", dramaID).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("drama not found")
}
return nil, err
}
updates := make(map[string]interface{})
if req.Title != "" {
updates["title"] = req.Title
}
if req.Description != "" {
updates["description"] = req.Description
}
if req.Genre != "" {
updates["genre"] = req.Genre
}
if req.Tags != "" {
updates["tags"] = req.Tags
}
if req.Status != "" {
updates["status"] = req.Status
}
updates["updated_at"] = time.Now()
if err := s.db.Model(&drama).Updates(updates).Error; err != nil {
s.log.Errorw("Failed to update drama", "error", err)
return nil, err
}
s.log.Infow("Drama updated", "drama_id", dramaID)
return &drama, nil
}
func (s *DramaService) DeleteDrama(dramaID string) error {
result := s.db.Where("id = ? ", dramaID).Delete(&models.Drama{})
if result.Error != nil {
s.log.Errorw("Failed to delete drama", "error", result.Error)
return result.Error
}
if result.RowsAffected == 0 {
return errors.New("drama not found")
}
s.log.Infow("Drama deleted", "drama_id", dramaID)
return nil
}
func (s *DramaService) GetDramaStats() (map[string]interface{}, error) {
var total int64
var byStatus []struct {
Status string
Count int64
}
if err := s.db.Model(&models.Drama{}).Count(&total).Error; err != nil {
return nil, err
}
if err := s.db.Model(&models.Drama{}).
Select("status, count(*) as count").
Group("status").
Scan(&byStatus).Error; err != nil {
return nil, err
}
stats := map[string]interface{}{
"total": total,
"by_status": byStatus,
}
return stats, nil
}
type SaveOutlineRequest struct {
Title string `json:"title" binding:"required"`
Summary string `json:"summary" binding:"required"`
Genre string `json:"genre"`
Tags []string `json:"tags"`
}
type SaveCharactersRequest struct {
Characters []models.Character `json:"characters" binding:"required"`
EpisodeID *uint `json:"episode_id"` // 可选:如果提供则关联到指定章节
}
type SaveProgressRequest struct {
CurrentStep string `json:"current_step" binding:"required"`
StepData map[string]interface{} `json:"step_data"`
}
type SaveEpisodesRequest struct {
Episodes []models.Episode `json:"episodes" binding:"required"`
}
func (s *DramaService) SaveOutline(dramaID string, req *SaveOutlineRequest) error {
var drama models.Drama
if err := s.db.Where("id = ? ", dramaID).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("drama not found")
}
return err
}
updates := map[string]interface{}{
"title": req.Title,
"description": req.Summary,
"updated_at": time.Now(),
}
if req.Genre != "" {
updates["genre"] = req.Genre
}
if len(req.Tags) > 0 {
tagsJSON, err := json.Marshal(req.Tags)
if err != nil {
s.log.Errorw("Failed to marshal tags", "error", err)
return err
}
updates["tags"] = tagsJSON
}
if err := s.db.Model(&drama).Updates(updates).Error; err != nil {
s.log.Errorw("Failed to save outline", "error", err)
return err
}
s.log.Infow("Outline saved", "drama_id", dramaID)
return nil
}
func (s *DramaService) GetCharacters(dramaID string, episodeID *string) ([]models.Character, error) {
var drama models.Drama
if err := s.db.Where("id = ? ", dramaID).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("drama not found")
}
return nil, err
}
var characters []models.Character
// 如果指定了episodeID只获取该章节关联的角色
if episodeID != nil {
var episode models.Episode
if err := s.db.Preload("Characters").Where("id = ? AND drama_id = ?", *episodeID, dramaID).First(&episode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errors.New("episode not found")
}
return nil, err
}
characters = episode.Characters
} else {
// 如果没有指定episodeID获取项目的所有角色
if err := s.db.Where("drama_id = ?", dramaID).Find(&characters).Error; err != nil {
s.log.Errorw("Failed to get characters", "error", err)
return nil, err
}
}
// 查询每个角色的图片生成任务状态
for i := range characters {
// 查询该角色最新的图片生成任务
var imageGen models.ImageGeneration
err := s.db.Where("character_id = ?", characters[i].ID).
Order("created_at DESC").
First(&imageGen).Error
if err == nil {
// 如果有进行中的任务,填充状态信息
if imageGen.Status == models.ImageStatusPending || imageGen.Status == models.ImageStatusProcessing {
statusStr := string(imageGen.Status)
characters[i].ImageGenerationStatus = &statusStr
} else if imageGen.Status == models.ImageStatusFailed {
statusStr := "failed"
characters[i].ImageGenerationStatus = &statusStr
if imageGen.ErrorMsg != nil {
characters[i].ImageGenerationError = imageGen.ErrorMsg
}
}
}
}
return characters, nil
}
func (s *DramaService) SaveCharacters(dramaID string, req *SaveCharactersRequest) error {
// 转换dramaID
id, err := strconv.ParseUint(dramaID, 10, 32)
if err != nil {
return fmt.Errorf("invalid drama ID")
}
dramaIDUint := uint(id)
var drama models.Drama
if err := s.db.Where("id = ? ", dramaIDUint).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("drama not found")
}
return err
}
// 如果指定了EpisodeID验证章节存在性
if req.EpisodeID != nil {
var episode models.Episode
if err := s.db.Where("id = ? AND drama_id = ?", *req.EpisodeID, dramaIDUint).First(&episode).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("episode not found")
}
return err
}
}
// 获取该项目已存在的所有角色
var existingCharacters []models.Character
if err := s.db.Where("drama_id = ?", dramaIDUint).Find(&existingCharacters).Error; err != nil {
s.log.Errorw("Failed to get existing characters", "error", err)
return err
}
// 创建角色名称到角色的映射
existingCharMap := make(map[string]*models.Character)
for i := range existingCharacters {
existingCharMap[existingCharacters[i].Name] = &existingCharacters[i]
}
// 收集需要关联到章节的角色ID
var characterIDs []uint
// 创建新角色或复用已有角色
for _, char := range req.Characters {
if existingChar, exists := existingCharMap[char.Name]; exists {
// 角色已存在,直接复用
s.log.Infow("Character already exists, reusing", "name", char.Name, "character_id", existingChar.ID)
characterIDs = append(characterIDs, existingChar.ID)
continue
}
// 角色不存在,创建新角色
character := models.Character{
DramaID: dramaIDUint,
Name: char.Name,
Role: char.Role,
Description: char.Description,
Personality: char.Personality,
Appearance: char.Appearance,
}
if err := s.db.Create(&character).Error; err != nil {
s.log.Errorw("Failed to create character", "error", err, "name", char.Name)
continue
}
s.log.Infow("New character created", "character_id", character.ID, "name", char.Name)
characterIDs = append(characterIDs, character.ID)
}
// 如果指定了EpisodeID建立角色与章节的关联
if req.EpisodeID != nil && len(characterIDs) > 0 {
var episode models.Episode
if err := s.db.First(&episode, *req.EpisodeID).Error; err != nil {
return err
}
// 获取角色对象
var characters []models.Character
if err := s.db.Where("id IN ?", characterIDs).Find(&characters).Error; err != nil {
s.log.Errorw("Failed to get characters", "error", err)
return err
}
// 使用GORM的Association API建立多对多关系会自动去重
if err := s.db.Model(&episode).Association("Characters").Append(&characters); err != nil {
s.log.Errorw("Failed to associate characters with episode", "error", err)
return err
}
s.log.Infow("Characters associated with episode", "episode_id", *req.EpisodeID, "character_count", len(characterIDs))
}
if err := s.db.Model(&drama).Update("updated_at", time.Now()).Error; err != nil {
s.log.Errorw("Failed to update drama timestamp", "error", err)
}
s.log.Infow("Characters saved", "drama_id", dramaID, "count", len(req.Characters))
return nil
}
func (s *DramaService) SaveEpisodes(dramaID string, req *SaveEpisodesRequest) error {
// 转换dramaID
id, err := strconv.ParseUint(dramaID, 10, 32)
if err != nil {
return fmt.Errorf("invalid drama ID")
}
dramaIDUint := uint(id)
var drama models.Drama
if err := s.db.Where("id = ? ", dramaIDUint).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("drama not found")
}
return err
}
// 删除旧剧集
if err := s.db.Where("drama_id = ?", dramaIDUint).Delete(&models.Episode{}).Error; err != nil {
s.log.Errorw("Failed to delete old episodes", "error", err)
return err
}
// 创建新剧集(不包含场景,场景由后续步骤生成)
for _, ep := range req.Episodes {
episode := models.Episode{
DramaID: dramaIDUint,
EpisodeNum: ep.EpisodeNum,
Title: ep.Title,
Description: ep.Description,
ScriptContent: ep.ScriptContent,
Duration: ep.Duration,
Status: "draft",
}
if err := s.db.Create(&episode).Error; err != nil {
s.log.Errorw("Failed to create episode", "error", err, "episode", ep.EpisodeNum)
continue
}
}
if err := s.db.Model(&drama).Update("updated_at", time.Now()).Error; err != nil {
s.log.Errorw("Failed to update drama timestamp", "error", err)
}
s.log.Infow("Episodes saved", "drama_id", dramaID, "count", len(req.Episodes))
return nil
}
func (s *DramaService) SaveProgress(dramaID string, req *SaveProgressRequest) error {
var drama models.Drama
if err := s.db.Where("id = ? ", dramaID).First(&drama).Error; err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("drama not found")
}
return err
}
// 构建metadata对象
metadata := make(map[string]interface{})
// 保留现有metadata
if drama.Metadata != nil {
if err := json.Unmarshal(drama.Metadata, &metadata); err != nil {
s.log.Warnw("Failed to unmarshal existing metadata", "error", err)
}
}
// 更新progress信息
metadata["current_step"] = req.CurrentStep
if req.StepData != nil {
metadata["step_data"] = req.StepData
}
// 序列化metadata
metadataJSON, err := json.Marshal(metadata)
if err != nil {
s.log.Errorw("Failed to marshal metadata", "error", err)
return err
}
updates := map[string]interface{}{
"metadata": metadataJSON,
"updated_at": time.Now(),
}
if err := s.db.Model(&drama).Updates(updates).Error; err != nil {
s.log.Errorw("Failed to save progress", "error", err)
return err
}
s.log.Infow("Progress saved", "drama_id", dramaID, "step", req.CurrentStep)
return nil
}