diff --git a/api/handlers/character_library.go b/api/handlers/character_library.go index 3ffeee5..aa97936 100644 --- a/api/handlers/character_library.go +++ b/api/handlers/character_library.go @@ -4,6 +4,7 @@ import ( "strconv" services2 "github.com/drama-generator/backend/application/services" + "github.com/drama-generator/backend/infrastructure/storage" "github.com/drama-generator/backend/pkg/config" "github.com/drama-generator/backend/pkg/logger" "github.com/drama-generator/backend/pkg/response" @@ -17,10 +18,10 @@ type CharacterLibraryHandler struct { log *logger.Logger } -func NewCharacterLibraryHandler(db *gorm.DB, cfg *config.Config, log *logger.Logger, transferService *services2.ResourceTransferService) *CharacterLibraryHandler { +func NewCharacterLibraryHandler(db *gorm.DB, cfg *config.Config, log *logger.Logger, transferService *services2.ResourceTransferService, localStorage *storage.LocalStorage) *CharacterLibraryHandler { return &CharacterLibraryHandler{ libraryService: services2.NewCharacterLibraryService(db, log), - imageService: services2.NewImageGenerationService(db, transferService, log), + imageService: services2.NewImageGenerationService(db, transferService, localStorage, log), log: log, } } diff --git a/api/handlers/image_generation.go b/api/handlers/image_generation.go index 93d3ec3..c389c99 100644 --- a/api/handlers/image_generation.go +++ b/api/handlers/image_generation.go @@ -4,6 +4,7 @@ import ( "strconv" "github.com/drama-generator/backend/application/services" + "github.com/drama-generator/backend/infrastructure/storage" "github.com/drama-generator/backend/pkg/config" "github.com/drama-generator/backend/pkg/logger" "github.com/drama-generator/backend/pkg/response" @@ -16,9 +17,9 @@ type ImageGenerationHandler struct { log *logger.Logger } -func NewImageGenerationHandler(db *gorm.DB, cfg *config.Config, log *logger.Logger, transferService *services.ResourceTransferService) *ImageGenerationHandler { +func NewImageGenerationHandler(db *gorm.DB, cfg *config.Config, log *logger.Logger, transferService *services.ResourceTransferService, localStorage *storage.LocalStorage) *ImageGenerationHandler { return &ImageGenerationHandler{ - imageService: services.NewImageGenerationService(db, transferService, log), + imageService: services.NewImageGenerationService(db, transferService, localStorage, log), log: log, } } diff --git a/api/routes/routes.go b/api/routes/routes.go index d786ae7..00ec503 100644 --- a/api/routes/routes.go +++ b/api/routes/routes.go @@ -30,18 +30,18 @@ func SetupRouter(cfg *config.Config, db *gorm.DB, log *logger.Logger, localStora }) aiService := services2.NewAIService(db, log) + localStoragePtr := localStorage.(*storage2.LocalStorage) + transferService := services2.NewResourceTransferService(db, log) dramaHandler := handlers2.NewDramaHandler(db, cfg, log, nil) aiConfigHandler := handlers2.NewAIConfigHandler(db, cfg, log) scriptGenHandler := handlers2.NewScriptGenerationHandler(db, cfg, log) - imageGenService := services2.NewImageGenerationService(db, nil, log) - imageGenHandler := handlers2.NewImageGenerationHandler(db, cfg, log, nil) - localStoragePtr := localStorage.(*storage2.LocalStorage) - transferService := services2.NewResourceTransferService(db, log) + imageGenService := services2.NewImageGenerationService(db, transferService, localStoragePtr, log) + imageGenHandler := handlers2.NewImageGenerationHandler(db, cfg, log, transferService, localStoragePtr) videoGenHandler := handlers2.NewVideoGenerationHandler(db, transferService, localStoragePtr, aiService, log) videoMergeHandler := handlers2.NewVideoMergeHandler(db, nil, cfg.Storage.LocalPath, cfg.Storage.BaseURL, log) assetHandler := handlers2.NewAssetHandler(db, cfg, log) characterLibraryService := services2.NewCharacterLibraryService(db, log) - characterLibraryHandler := handlers2.NewCharacterLibraryHandler(db, cfg, log, nil) + characterLibraryHandler := handlers2.NewCharacterLibraryHandler(db, cfg, log, transferService, localStoragePtr) uploadHandler, err := handlers2.NewUploadHandler(cfg, log, characterLibraryService) if err != nil { log.Fatalw("Failed to create upload handler", "error", err) diff --git a/application/services/image_generation_service.go b/application/services/image_generation_service.go index 4f4d8d9..0368fdf 100644 --- a/application/services/image_generation_service.go +++ b/application/services/image_generation_service.go @@ -7,6 +7,7 @@ import ( "time" models "github.com/drama-generator/backend/domain/models" + "github.com/drama-generator/backend/infrastructure/storage" "github.com/drama-generator/backend/pkg/ai" "github.com/drama-generator/backend/pkg/image" "github.com/drama-generator/backend/pkg/logger" @@ -18,14 +19,16 @@ type ImageGenerationService struct { db *gorm.DB aiService *AIService transferService *ResourceTransferService + localStorage *storage.LocalStorage log *logger.Logger } -func NewImageGenerationService(db *gorm.DB, transferService *ResourceTransferService, log *logger.Logger) *ImageGenerationService { +func NewImageGenerationService(db *gorm.DB, transferService *ResourceTransferService, localStorage *storage.LocalStorage, log *logger.Logger) *ImageGenerationService { return &ImageGenerationService{ db: db, aiService: NewAIService(db, log), transferService: transferService, + localStorage: localStorage, log: log, } } @@ -241,6 +244,23 @@ func (s *ImageGenerationService) pollTaskStatus(imageGenID uint, client image.Im func (s *ImageGenerationService) completeImageGeneration(imageGenID uint, result *image.ImageResult) { now := time.Now() + + // 下载图片到本地存储(仅用于缓存,不更新数据库) + if s.localStorage != nil && result.ImageURL != "" { + _, err := s.localStorage.DownloadFromURL(result.ImageURL, "images") + if err != nil { + s.log.Warnw("Failed to download image to local storage", + "error", err, + "id", imageGenID, + "original_url", result.ImageURL) + } else { + s.log.Infow("Image downloaded to local storage for caching", + "id", imageGenID, + "original_url", result.ImageURL) + } + } + + // 数据库中保持使用原始URL updates := map[string]interface{}{ "status": models.ImageStatusCompleted, "image_url": result.ImageURL, diff --git a/application/services/video_generation_service.go b/application/services/video_generation_service.go index 8169222..92e8c7a 100644 --- a/application/services/video_generation_service.go +++ b/application/services/video_generation_service.go @@ -316,6 +316,37 @@ func (s *VideoGenerationService) pollTaskStatus(videoGenID uint, taskID string, } func (s *VideoGenerationService) completeVideoGeneration(videoGenID uint, videoURL string, duration *int, width *int, height *int, firstFrameURL *string) { + // 下载视频到本地存储(仅用于缓存,不更新数据库) + if s.localStorage != nil && videoURL != "" { + _, err := s.localStorage.DownloadFromURL(videoURL, "videos") + if err != nil { + s.log.Warnw("Failed to download video to local storage", + "error", err, + "id", videoGenID, + "original_url", videoURL) + } else { + s.log.Infow("Video downloaded to local storage for caching", + "id", videoGenID, + "original_url", videoURL) + } + } + + // 下载首帧图片到本地存储(仅用于缓存,不更新数据库) + if firstFrameURL != nil && *firstFrameURL != "" && s.localStorage != nil { + _, err := s.localStorage.DownloadFromURL(*firstFrameURL, "video_frames") + if err != nil { + s.log.Warnw("Failed to download first frame to local storage", + "error", err, + "id", videoGenID, + "original_url", *firstFrameURL) + } else { + s.log.Infow("First frame downloaded to local storage for caching", + "id", videoGenID, + "original_url", *firstFrameURL) + } + } + + // 数据库中保持使用原始URL updates := map[string]interface{}{ "status": models.VideoStatusCompleted, "video_url": videoURL, diff --git a/infrastructure/storage/local_storage.go b/infrastructure/storage/local_storage.go index d75da66..45a4ec6 100644 --- a/infrastructure/storage/local_storage.go +++ b/infrastructure/storage/local_storage.go @@ -3,8 +3,10 @@ package storage import ( "fmt" "io" + "net/http" "os" "path/filepath" + "strings" "time" ) @@ -55,3 +57,81 @@ func (s *LocalStorage) Delete(url string) error { func (s *LocalStorage) GetURL(path string) string { return fmt.Sprintf("%s/%s", s.baseURL, path) } + +// DownloadFromURL 从远程URL下载文件到本地存储 +func (s *LocalStorage) DownloadFromURL(url, category string) (string, error) { + // 发送HTTP请求下载文件 + resp, err := http.Get(url) + if err != nil { + return "", fmt.Errorf("failed to download file: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("failed to download file: HTTP %d", resp.StatusCode) + } + + // 从URL或Content-Type推断文件扩展名 + ext := getFileExtension(url, resp.Header.Get("Content-Type")) + + // 创建目录 + dir := filepath.Join(s.basePath, category) + if err := os.MkdirAll(dir, 0755); err != nil { + return "", fmt.Errorf("failed to create category directory: %w", err) + } + + // 生成唯一文件名 + timestamp := time.Now().Format("20060102_150405_000") + filename := fmt.Sprintf("%s%s", timestamp, ext) + filePath := filepath.Join(dir, filename) + + // 保存文件 + dst, err := os.Create(filePath) + if err != nil { + return "", fmt.Errorf("failed to create file: %w", err) + } + defer dst.Close() + + if _, err := io.Copy(dst, resp.Body); err != nil { + return "", fmt.Errorf("failed to save file: %w", err) + } + + // 返回本地URL + localURL := fmt.Sprintf("%s/%s/%s", s.baseURL, category, filename) + return localURL, nil +} + +// getFileExtension 从URL或Content-Type推断文件扩展名 +func getFileExtension(url, contentType string) string { + // 首先尝试从URL获取扩展名 + if idx := strings.LastIndex(url, "."); idx != -1 { + ext := url[idx:] + // 只取扩展名部分,忽略查询参数 + if qIdx := strings.Index(ext, "?"); qIdx != -1 { + ext = ext[:qIdx] + } + if len(ext) <= 5 { // 合理的扩展名长度 + return ext + } + } + + // 根据Content-Type推断扩展名 + switch { + case strings.Contains(contentType, "image/jpeg"): + return ".jpg" + case strings.Contains(contentType, "image/png"): + return ".png" + case strings.Contains(contentType, "image/gif"): + return ".gif" + case strings.Contains(contentType, "image/webp"): + return ".webp" + case strings.Contains(contentType, "video/mp4"): + return ".mp4" + case strings.Contains(contentType, "video/webm"): + return ".webm" + case strings.Contains(contentType, "video/quicktime"): + return ".mov" + default: + return ".bin" + } +} diff --git a/pkg/video/volces_ark_client.go b/pkg/video/volces_ark_client.go index cfd5c8a..8844a0f 100644 --- a/pkg/video/volces_ark_client.go +++ b/pkg/video/volces_ark_client.go @@ -230,10 +230,12 @@ func (c *VolcesArkClient) GenerateVideo(imageURL, prompt string, opts ...VideoOp } func (c *VolcesArkClient) GetTaskStatus(taskID string) (*VideoResult, error) { - // 替换占位符{taskId}或直接拼接 + // 替换占位符{taskId}、{task_id}或直接拼接 queryPath := c.QueryEndpoint - if contains := bytes.Contains([]byte(queryPath), []byte("{taskId}")); contains { - queryPath = string(bytes.ReplaceAll([]byte(queryPath), []byte("{taskId}"), []byte(taskID))) + if strings.Contains(queryPath, "{taskId}") { + queryPath = strings.ReplaceAll(queryPath, "{taskId}", taskID) + } else if strings.Contains(queryPath, "{task_id}") { + queryPath = strings.ReplaceAll(queryPath, "{task_id}", taskID) } else { queryPath = queryPath + "/" + taskID } diff --git a/web/src/stores/episode.ts b/web/src/stores/episode.ts new file mode 100644 index 0000000..4bea741 --- /dev/null +++ b/web/src/stores/episode.ts @@ -0,0 +1,233 @@ +import { ref, computed, reactive } from 'vue' +import { defineStore } from 'pinia' +import { dramaAPI } from '@/api/drama' +import type { Episode, Character, Scene } from '@/types/drama' + +interface EpisodeCache { + data: Episode + loading: boolean + error: string | null + lastFetch: number +} + +interface EpisodeOperations { + refresh: () => Promise + set: (params: SetOperationParams) => Promise + del: (params: DeleteOperationParams) => Promise + saveScript: (content: string) => Promise + extractData: () => Promise + generateImages: (options?: GenerateImageOptions) => Promise + generateStoryboards: () => Promise +} + +interface SetOperationParams { + type: 'character' | 'scene' | 'storyboard' + data: any +} + +interface DeleteOperationParams { + type: 'character' | 'scene' | 'storyboard' + id: string | number +} + +interface GenerateImageOptions { + characterIds?: number[] + sceneIds?: string[] +} + +export interface CachedEpisode { + value: Episode + loading: boolean + error: string | null + refresh: () => Promise + set: (params: SetOperationParams) => Promise + del: (params: DeleteOperationParams) => Promise + saveScript: (content: string) => Promise + extractData: () => Promise + generateImages: (options?: GenerateImageOptions) => Promise + generateStoryboards: () => Promise +} + +export const useEpisodeStore = defineStore('episode', () => { + const caches = reactive>(new Map()) + + const getCacheByEpisodeId = (episodeId: string): CachedEpisode => { + if (!caches.has(episodeId)) { + caches.set(episodeId, { + data: {} as Episode, + loading: false, + error: null, + lastFetch: 0 + }) + fetchEpisode(episodeId) + } + + const cache = caches.get(episodeId)! + + const operations: EpisodeOperations = { + async refresh() { + await fetchEpisode(episodeId, true) + }, + + async set(params: SetOperationParams) { + const { type, data } = params + + switch (type) { + case 'character': + await dramaAPI.saveCharacters(cache.data.drama_id, [data], episodeId) + await fetchEpisode(episodeId, true) + break + case 'scene': + await dramaAPI.updateScene(data.id, data) + await fetchEpisode(episodeId, true) + break + case 'storyboard': + await dramaAPI.updateStoryboard(data.id, data) + await fetchEpisode(episodeId, true) + break + } + }, + + async del(params: DeleteOperationParams) { + const { type, id } = params + + switch (type) { + case 'character': + const characters = cache.data.characters?.filter(c => c.id !== id) || [] + await dramaAPI.saveCharacters(cache.data.drama_id, characters, episodeId) + await fetchEpisode(episodeId, true) + break + case 'scene': + break + case 'storyboard': + break + } + }, + + async saveScript(content: string) { + const parts = episodeId.split('-') + const dramaId = parts[0] + const episodeNumber = parseInt(parts.length > 1 ? parts[1] : cache.data.episode_number?.toString() || '1') + + await dramaAPI.saveEpisodes(dramaId, [{ + episode_number: episodeNumber, + script_content: content + }]) + + await fetchEpisode(episodeId, true) + }, + + async extractData() { + await dramaAPI.extractBackgrounds(episodeId) + await fetchEpisode(episodeId, true) + }, + + async generateImages(options?: GenerateImageOptions) { + const promises: Promise[] = [] + + if (options?.characterIds && options.characterIds.length > 0) { + options.characterIds.forEach(id => { + const character = cache.data.characters?.find(c => c.id === id) + if (character) { + promises.push( + dramaAPI.generateSceneImage({ + scene_id: character.id.toString(), + prompt: character.appearance || character.description || character.name, + model: undefined + }) + ) + } + }) + } + + if (options?.sceneIds && options.sceneIds.length > 0) { + options.sceneIds.forEach(sceneId => { + promises.push( + dramaAPI.generateSceneImage({ + scene_id: sceneId, + model: undefined + }) + ) + }) + } + + if (promises.length > 0) { + await Promise.allSettled(promises) + } + + await fetchEpisode(episodeId, true) + }, + + async generateStoryboards() { + await dramaAPI.generateStoryboard(episodeId) + await fetchEpisode(episodeId, true) + } + } + + return { + get value() { + return cache.data + }, + get loading() { + return cache.loading + }, + get error() { + return cache.error + }, + ...operations + } + } + + const fetchEpisode = async (episodeId: string, force = false) => { + const cache = caches.get(episodeId) + if (!cache) return + + const now = Date.now() + if (!force && cache.lastFetch && (now - cache.lastFetch) < 3000) { + return + } + + cache.loading = true + cache.error = null + + try { + const parts = episodeId.split('-') + const dramaId = parts[0] + const episodeNumber = parts.length > 1 ? parseInt(parts[1]) : null + + const drama = await dramaAPI.get(dramaId) + + let episode: Episode | undefined + if (episodeNumber !== null) { + episode = drama.episodes?.find(e => e.episode_number === episodeNumber) + } else { + episode = drama.episodes?.find(e => e.id === episodeId) + } + + if (episode) { + cache.data = episode + cache.lastFetch = now + } else { + cache.error = '未找到章节数据' + } + } catch (error: any) { + cache.error = error.message || '加载章节数据失败' + console.error('Failed to fetch episode:', error) + } finally { + cache.loading = false + } + } + + const clearCache = (episodeId?: string) => { + if (episodeId) { + caches.delete(episodeId) + } else { + caches.clear() + } + } + + return { + getCacheByEpisodeId, + clearCache + } +}) diff --git a/web/src/types/video.ts b/web/src/types/video.ts index 455965e..5077a53 100644 --- a/web/src/types/video.ts +++ b/web/src/types/video.ts @@ -38,7 +38,7 @@ export interface GenerateVideoRequest { scene_id?: string // 已废弃,保留用于兼容 drama_id: string image_gen_id?: number - image_url: string + image_url?: string prompt: string provider?: string model?: string @@ -49,8 +49,10 @@ export interface GenerateVideoRequest { motion_level?: number camera_motion?: string seed?: number + reference_mode?: string // 参考图模式:single, first_last, multiple, none first_frame_url?: string // 首帧图片URL last_frame_url?: string // 尾帧图片URL + reference_image_urls?: string[] // 多图参考模式 } export interface VideoGenerationListParams { diff --git a/web/src/views/drama/DramaManagement.vue b/web/src/views/drama/DramaManagement.vue index 7b9e358..45871d2 100644 --- a/web/src/views/drama/DramaManagement.vue +++ b/web/src/views/drama/DramaManagement.vue @@ -126,14 +126,11 @@ {{ formatDate(row.created_at) }} - +