Files
AI-Video/api/routers/editor.py
empty 79a6c2ef3e feat: Add inpainting (局部重绘) feature for timeline editor
- Add canvas-based mask drawing tools (brush, eraser, rect, lasso)
- Add undo/redo history support for mask editing
- Integrate inpainting UI into preview player
- Add backend API endpoint for inpainting requests
- Add MediaService.inpaint method with ComfyUI workflow support
- Add Flux inpainting workflows for selfhost and RunningHub

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-05 23:44:51 +08:00

686 lines
24 KiB
Python

# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Editor API router for timeline editor operations
Provides endpoints for:
- Fetching storyboard data
- Reordering frames
- Updating frame duration
- Generating preview
"""
from typing import Optional

from fastapi import APIRouter, HTTPException, Path
from loguru import logger

from api.schemas.editor import (
    StoryboardSchema,
    StoryboardFrameSchema,
    ReorderFramesRequest,
    UpdateDurationRequest,
    PreviewRequest,
    PreviewResponse,
    UpdateFrameRequest,
    UpdateFrameResponse,
    RegenerateImageRequest,
    RegenerateImageResponse,
    RegenerateAudioRequest,
    RegenerateAudioResponse,
    InpaintRequest,
    InpaintResponse,
)
# Router that every endpoint in this module registers on; routes are served
# under the "/editor" prefix and tagged "Editor".
router = APIRouter(prefix="/editor", tags=["Editor"])
def _path_to_url(file_path: Optional[str], base_url: str = "http://localhost:8000") -> Optional[str]:
    """Convert a local file path into a URL served through the files API.

    Backslashes are normalized to forward slashes. Everything after the first
    ``output`` path component becomes the URL path; when the path has no
    ``output`` component, only the file name is used.

    Args:
        file_path: Local path using either separator style, or a falsy value.
        base_url: Scheme/host prefix placed in front of ``/api/files/``.

    Returns:
        ``"{base_url}/api/files/{relative_path}"``, or ``None`` when
        ``file_path`` is empty or ``None``.
    """
    if not file_path:
        return None

    # Imported locally on purpose: at module level the name ``Path`` refers to
    # ``fastapi.Path`` (the path-parameter declaration helper).
    from pathlib import Path

    normalized = file_path.replace("\\", "/")
    parts = normalized.split("/")

    try:
        output_idx = parts.index("output")
    except ValueError:
        # Not under an "output" directory: fall back to the bare file name.
        relative_path = Path(normalized).name
    else:
        relative_path = "/".join(parts[output_idx + 1:])

    return f"{base_url}/api/files/{relative_path}"
# Process-local storyboard store keyed by storyboard ID.
# Demo-grade only: a production deployment should use a database instead.
_storyboard_cache: dict = {}

# Sample storyboard used for testing (five frames).
_demo_storyboard = {
    "id": "demo-1",
    "title": "演示视频",
    "total_duration": 15.5,
    "final_video_path": None,
    "created_at": None,
    "frames": [
        {
            "id": "frame-0",
            "index": 0,
            "order": 0,
            "narration": "在一个宁静的早晨,阳光洒满了整个城市",
            "image_prompt": "A peaceful morning",
            "duration": 3.2,
        },
        {
            "id": "frame-1",
            "index": 1,
            "order": 1,
            "narration": "小明决定出门去探索这个美丽的世界",
            "image_prompt": "A young man stepping out",
            "duration": 2.8,
        },
        {
            "id": "frame-2",
            "index": 2,
            "order": 2,
            "narration": "他走过熟悉的街道,感受着微风的吹拂",
            "image_prompt": "Walking through streets",
            "duration": 3.5,
        },
        {
            "id": "frame-3",
            "index": 3,
            "order": 3,
            "narration": "公园里的花朵正在盛开,散发着迷人的芬芳",
            "image_prompt": "Blooming flowers",
            "duration": 3.0,
        },
        {
            "id": "frame-4",
            "index": 4,
            "order": 4,
            "narration": "这是新的一天的开始,充满了无限可能",
            "image_prompt": "New day begins",
            "duration": 3.0,
        },
    ],
}
# Import task manager
from api.tasks.manager import task_manager
@router.get("/storyboard/{storyboard_id}", response_model=StoryboardSchema)
async def get_storyboard(storyboard_id: str = Path(..., description="Storyboard/task ID")):
    """
    Get storyboard by ID

    Sources are tried in this order:
    - 'demo-1': seeded into the cache from the built-in demo data on first access
    - The module-level storyboard cache
    - The task manager (a task whose result carries a storyboard)
    - The persistence service (history stored under "output")

    A storyboard found in the task manager or persistence service is converted
    to the editor schema format and cached before being returned.

    Raises:
        HTTPException: 404 when none of the sources has the storyboard.
    """
    import copy

    if storyboard_id == "demo-1" and "demo-1" not in _storyboard_cache:
        # Deep copy: the editing endpoints mutate cached frame dicts in place,
        # and a shallow copy would share those dicts with the demo template.
        _storyboard_cache["demo-1"] = copy.deepcopy(_demo_storyboard)

    # Cache hit (this also serves the demo storyboard seeded above)
    if storyboard_id in _storyboard_cache:
        return _storyboard_cache[storyboard_id]

    # Try to load from task manager (in-memory task)
    task = task_manager.get_task(storyboard_id)
    if task and task.result:
        result = task.result

        # The result may expose the storyboard as an attribute or as a dict key
        storyboard_data = None
        if hasattr(result, 'storyboard'):
            storyboard_data = result.storyboard
        elif isinstance(result, dict) and 'storyboard' in result:
            storyboard_data = result['storyboard']

        if storyboard_data:
            schema = _convert_storyboard_to_schema(storyboard_id, storyboard_data)
            _storyboard_cache[storyboard_id] = schema
            logger.info(f"Loaded storyboard from task {storyboard_id}")
            return schema

    # Try to load from persistence service (history). This is best-effort:
    # any failure is logged and falls through to the 404 below.
    try:
        from pixelle_video.services.persistence import PersistenceService

        persistence = PersistenceService(output_dir="output")
        storyboard = await persistence.load_storyboard(storyboard_id)
        if storyboard:
            schema = _convert_storyboard_to_schema(storyboard_id, storyboard)
            _storyboard_cache[storyboard_id] = schema
            logger.info(f"Loaded storyboard from persistence {storyboard_id}")
            return schema
    except Exception as e:
        logger.warning(f"Failed to load from persistence: {e}")

    raise HTTPException(status_code=404, detail=f"Storyboard {storyboard_id} not found")
def _convert_storyboard_to_schema(storyboard_id: str, storyboard) -> dict:
    """Convert an internal storyboard (object or dict) to the API schema dict.

    Args:
        storyboard_id: ID stored in the result; also the fallback title.
        storyboard: Either an object exposing ``frames``/``title``/... as
            attributes, or a dict with the same keys. Any other value yields
            an empty storyboard.

    Returns:
        Dict with keys ``id``, ``title``, ``frames``, ``total_duration``,
        ``final_video_path`` and ``created_at``. Frame media paths are
        converted to URLs with ``_path_to_url``. ``total_duration`` falls back
        to the sum of frame durations when the source value is missing or 0.
    """

    def _read(source, key, default=None):
        # Uniform access for dict-style and attribute-style sources.
        if isinstance(source, dict):
            return source.get(key, default)
        return getattr(source, key, default)

    # ``or []`` guards against an explicit ``frames: None``.
    frame_list = _read(storyboard, 'frames', []) or []
    title = _read(storyboard, 'title', storyboard_id)
    total_duration = _read(storyboard, 'total_duration', 0)
    final_video_path = _read(storyboard, 'final_video_path')
    created_at = _read(storyboard, 'created_at')

    frames = []
    for i, frame in enumerate(frame_list):
        # Entries that are neither dicts nor narration-bearing objects are
        # skipped; ``i`` still counts them so IDs track the source position.
        if not (isinstance(frame, dict) or hasattr(frame, 'narration')):
            continue

        duration = _read(frame, 'duration', 3.0)
        if duration is None:
            # A None duration would break the summation below.
            duration = 3.0

        frames.append({
            "id": f"frame-{i}",
            "index": _read(frame, 'index', i),
            "order": i,
            # Narration is coerced to "" for both source shapes.
            "narration": _read(frame, 'narration') or "",
            "image_prompt": _read(frame, 'image_prompt', ""),
            "image_path": _path_to_url(_read(frame, 'image_path')),
            "audio_path": _path_to_url(_read(frame, 'audio_path')),
            "video_segment_path": _path_to_url(_read(frame, 'video_segment_path')),
            "duration": duration,
        })

    if not created_at:
        created_at_text = None
    elif hasattr(created_at, 'isoformat'):
        created_at_text = created_at.isoformat()
    else:
        # Already serialized (e.g. a string in a dict-shaped storyboard);
        # calling .isoformat() on it would raise AttributeError.
        created_at_text = str(created_at)

    return {
        "id": storyboard_id,
        "title": title,
        "frames": frames,
        "total_duration": total_duration or sum(f["duration"] for f in frames),
        "final_video_path": final_video_path,
        "created_at": created_at_text,
    }
@router.patch("/storyboard/{storyboard_id}/reorder", response_model=StoryboardSchema)
async def reorder_frames(
    storyboard_id: str = Path(..., description="Storyboard/task ID"),
    request: ReorderFramesRequest = None
):
    """
    Reorder frames in storyboard

    Frames are rearranged to follow ``request.order`` (a list of frame IDs) and
    each frame's ``order`` field is rewritten to its new position. Frames that
    are not mentioned in ``request.order`` are kept, after the listed ones and
    in their previous relative order, so a partial list never drops frames.

    Raises:
        HTTPException: 404 if the storyboard is not cached; 400 if the request
            body is missing, names an unknown frame ID, or repeats a frame ID.
    """
    if storyboard_id not in _storyboard_cache:
        raise HTTPException(status_code=404, detail=f"Storyboard {storyboard_id} not found in cache")

    # The body parameter defaults to None, so it may legitimately be absent.
    if request is None:
        raise HTTPException(status_code=400, detail="Request body is required")

    storyboard = _storyboard_cache[storyboard_id]
    frames = storyboard["frames"]
    frame_map = {f["id"]: f for f in frames}

    # Validate before mutating anything: every ID must exist and be unique.
    seen = set()
    for frame_id in request.order:
        if frame_id not in frame_map:
            raise HTTPException(status_code=400, detail=f"Frame {frame_id} not found")
        if frame_id in seen:
            raise HTTPException(status_code=400, detail=f"Frame {frame_id} listed more than once")
        seen.add(frame_id)

    # Listed frames first, then any unlisted frames in their existing order.
    ordered_ids = list(request.order) + [f["id"] for f in frames if f["id"] not in seen]

    # Copy each frame so the previous frame dicts are left untouched.
    reordered = [
        {**frame_map[frame_id], "order": position}
        for position, frame_id in enumerate(ordered_ids)
    ]

    storyboard["frames"] = reordered
    _storyboard_cache[storyboard_id] = storyboard

    logger.info(f"Reordered {len(reordered)} frames in storyboard {storyboard_id}")
    return storyboard
@router.patch(
    "/storyboard/{storyboard_id}/frames/{frame_id}/duration",
    response_model=StoryboardFrameSchema
)
async def update_frame_duration(
    storyboard_id: str = Path(..., description="Storyboard/task ID"),
    frame_id: str = Path(..., description="Frame ID"),
    request: UpdateDurationRequest = None
):
    """
    Update frame duration

    Sets the duration of one frame to ``request.duration`` and recalculates the
    storyboard's total duration as the sum of all frame durations.

    Returns:
        The updated frame dict.

    Raises:
        HTTPException: 404 if the storyboard is not cached or the frame does
            not exist; 400 if the request body is missing.
    """
    if storyboard_id not in _storyboard_cache:
        raise HTTPException(status_code=404, detail=f"Storyboard {storyboard_id} not found in cache")

    # The body parameter defaults to None, so it may legitimately be absent.
    if request is None:
        raise HTTPException(status_code=400, detail="Request body is required")

    storyboard = _storyboard_cache[storyboard_id]
    frames = storyboard["frames"]

    updated_frame = next((f for f in frames if f["id"] == frame_id), None)
    if updated_frame is None:
        raise HTTPException(status_code=404, detail=f"Frame {frame_id} not found")

    # Mutated in place: the cached storyboard holds this same dict.
    updated_frame["duration"] = request.duration

    # Frames without a duration count as 3.0s, the default used elsewhere in
    # this module when summing durations.
    storyboard["total_duration"] = sum(f.get("duration", 3.0) for f in frames)
    _storyboard_cache[storyboard_id] = storyboard

    logger.info(f"Updated frame {frame_id} duration to {request.duration}s")
    return updated_frame
@router.post("/storyboard/{storyboard_id}/preview", response_model=PreviewResponse)
async def generate_preview(
    storyboard_id: str = Path(..., description="Storyboard/task ID"),
    request: PreviewRequest = None
):
    """
    Generate preview video for selected frames

    Selects ``frames[start:end]`` from the cached storyboard and reports the
    path, total duration and frame count of the preview. No video is rendered
    yet: the returned path is a placeholder built from the storyboard ID and
    the frame range.

    Raises:
        HTTPException: 404 if the storyboard is not cached; 400 if the start
            frame is at or beyond the number of frames.
    """
    if storyboard_id not in _storyboard_cache:
        raise HTTPException(status_code=404, detail=f"Storyboard {storyboard_id} not found in cache")

    frames = _storyboard_cache[storyboard_id]["frames"]
    frame_count = len(frames)

    # Without a request the whole storyboard is previewed; a missing or falsy
    # end_frame means "through the last frame".
    if request:
        start = request.start_frame
        end = request.end_frame if request.end_frame else frame_count
    else:
        start = 0
        end = frame_count

    if start >= frame_count:
        raise HTTPException(status_code=400, detail="Start frame out of range")

    selected = frames[start:end]
    selected_duration = sum(f["duration"] for f in selected)

    # TODO: Implement actual preview generation logic
    # For now, return mock response
    mock_path = f"/output/{storyboard_id}/preview_{start}_{end}.mp4"
    logger.info(f"Generated preview for frames {start}-{end} ({len(selected)} frames)")

    return PreviewResponse(
        preview_path=mock_path,
        duration=selected_duration,
        frames_count=len(selected)
    )
def _storyboard_to_schema(storyboard_id: str, storyboard) -> dict:
    """Build the API schema dict from a storyboard object.

    Reads ``frames``, ``title``, ``total_duration``, ``final_video_path`` and
    ``created_at`` as attributes of ``storyboard`` and copies each frame's
    fields verbatim; frame IDs and ``order`` come from the frame's position in
    the list.
    """
    frame_dicts = [
        {
            "id": f"frame-{position}",
            "index": item.index,
            "order": position,
            "narration": item.narration,
            "image_prompt": item.image_prompt,
            "image_path": item.image_path,
            "audio_path": item.audio_path,
            "video_segment_path": item.video_segment_path,
            "duration": item.duration,
        }
        for position, item in enumerate(storyboard.frames)
    ]

    schema = {"id": storyboard_id, "title": storyboard.title, "frames": frame_dicts}
    schema["total_duration"] = storyboard.total_duration
    schema["final_video_path"] = storyboard.final_video_path
    schema["created_at"] = storyboard.created_at
    return schema
@router.put(
    "/storyboard/{storyboard_id}/frames/{frame_id}",
    response_model=UpdateFrameResponse
)
async def update_frame(
    storyboard_id: str = Path(..., description="Storyboard/task ID"),
    frame_id: str = Path(..., description="Frame ID"),
    request: UpdateFrameRequest = None
):
    """
    Update frame content (narration and/or image prompt)

    Only the fields that are not None in the request are written to the cached
    frame; no media is regenerated.

    Raises:
        HTTPException: 404 if the storyboard is not cached or the frame does
            not exist; 400 if the request body is missing.
    """
    if storyboard_id not in _storyboard_cache:
        raise HTTPException(status_code=404, detail=f"Storyboard {storyboard_id} not found in cache")

    # The body parameter defaults to None, so it may legitimately be absent.
    if request is None:
        raise HTTPException(status_code=400, detail="Request body is required")

    storyboard = _storyboard_cache[storyboard_id]

    updated_frame = next((f for f in storyboard["frames"] if f["id"] == frame_id), None)
    if updated_frame is None:
        raise HTTPException(status_code=404, detail=f"Frame {frame_id} not found")

    # None means "leave this field unchanged"; an empty string is a real value.
    if request.narration is not None:
        updated_frame["narration"] = request.narration
    if request.image_prompt is not None:
        updated_frame["image_prompt"] = request.image_prompt

    _storyboard_cache[storyboard_id] = storyboard
    logger.info(f"Updated frame {frame_id} content")

    return UpdateFrameResponse(
        id=frame_id,
        narration=updated_frame["narration"],
        image_prompt=updated_frame.get("image_prompt"),
        updated=True
    )
@router.post(
    "/storyboard/{storyboard_id}/frames/{frame_id}/regenerate-image",
    response_model=RegenerateImageResponse
)
async def regenerate_frame_image(
    storyboard_id: str = Path(..., description="Storyboard/task ID"),
    frame_id: str = Path(..., description="Frame ID"),
    request: RegenerateImageRequest = None
):
    """
    Regenerate image for a frame

    Uses ``request.image_prompt`` when provided, otherwise the frame's stored
    ``image_prompt``, to run the "image_gen" workflow. The first returned image
    is downloaded to ``output/{storyboard_id}/frame_{n}_regenerated.png`` and
    the frame's ``image_path`` is updated to the corresponding URL.
    Requires ComfyUI service to be running.

    Raises:
        HTTPException: 404 if the storyboard or frame is not found; 400 if no
            prompt is available; 500 if generation, download, or a dependency
            import fails.
    """
    if storyboard_id not in _storyboard_cache:
        raise HTTPException(status_code=404, detail=f"Storyboard {storyboard_id} not found")

    storyboard = _storyboard_cache[storyboard_id]
    frames = storyboard["frames"]

    # Position of the frame in the current list; used in the output file name.
    frame_index, target_frame = next(
        ((i, f) for i, f in enumerate(frames) if f["id"] == frame_id),
        (0, None),
    )
    if target_frame is None:
        raise HTTPException(status_code=404, detail=f"Frame {frame_id} not found")

    prompt = request.image_prompt if request and request.image_prompt else target_frame.get("image_prompt", "")
    if not prompt:
        raise HTTPException(status_code=400, detail="No image prompt available")

    try:
        import inspect
        import os

        import aiohttp
        from api.dependencies import get_pixelle_video

        # Works whether get_pixelle_video is a plain function or a coroutine
        # function. NOTE(review): confirm which it is and simplify.
        pixelle_video = get_pixelle_video()
        if inspect.isawaitable(pixelle_video):
            pixelle_video = await pixelle_video

        result = await pixelle_video.comfy(
            workflow="image_gen",
            prompt=prompt,
            task_id=storyboard_id,
        )
        if not (result and result.get("images")):
            raise HTTPException(status_code=500, detail="Image generation failed")

        image_url = result["images"][0]
        output_dir = f"output/{storyboard_id}"
        os.makedirs(output_dir, exist_ok=True)
        image_path = f"{output_dir}/frame_{frame_index}_regenerated.png"

        async with aiohttp.ClientSession() as session:
            async with session.get(image_url) as resp:
                # Fail loudly: otherwise the frame would be pointed at a file
                # that was never written.
                if resp.status != 200:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Failed to download generated image (HTTP {resp.status})",
                    )
                image_bytes = await resp.read()

        with open(image_path, 'wb') as f:
            f.write(image_bytes)

        target_frame["image_path"] = _path_to_url(image_path)
        _storyboard_cache[storyboard_id] = storyboard

        logger.info(f"Regenerated image for frame {frame_id}")
        return RegenerateImageResponse(
            image_path=target_frame["image_path"],
            success=True
        )
    except HTTPException:
        # Deliberate HTTP errors raised above must not be re-wrapped by the
        # generic handler below.
        raise
    except ImportError as e:
        logger.error(f"Failed to import dependencies: {e}")
        raise HTTPException(status_code=500, detail="Image generation service not available") from e
    except Exception as e:
        logger.error(f"Image regeneration failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post(
    "/storyboard/{storyboard_id}/frames/{frame_id}/regenerate-audio",
    response_model=RegenerateAudioResponse
)
async def regenerate_frame_audio(
    storyboard_id: str = Path(..., description="Storyboard/task ID"),
    frame_id: str = Path(..., description="Frame ID"),
    request: RegenerateAudioRequest = None
):
    """
    Regenerate audio for a frame

    Uses ``request.narration`` when provided, otherwise the frame's stored
    narration, to synthesize new audio via TTS into
    ``output/{storyboard_id}/frame_{n}_audio_regenerated.mp3``. The frame's
    ``audio_path`` and ``duration`` are updated and the storyboard's total
    duration is recalculated.

    Raises:
        HTTPException: 404 if the storyboard or frame is not found; 400 if no
            narration is available; 500 if TTS or a dependency import fails.
    """
    if storyboard_id not in _storyboard_cache:
        raise HTTPException(status_code=404, detail=f"Storyboard {storyboard_id} not found")

    storyboard = _storyboard_cache[storyboard_id]
    frames = storyboard["frames"]

    # Position of the frame in the current list; used in the output file name.
    frame_index, target_frame = next(
        ((i, f) for i, f in enumerate(frames) if f["id"] == frame_id),
        (0, None),
    )
    if target_frame is None:
        raise HTTPException(status_code=404, detail=f"Frame {frame_id} not found")

    narration = request.narration if request and request.narration else target_frame.get("narration", "")
    if not narration:
        raise HTTPException(status_code=400, detail="No narration text available")

    try:
        import inspect
        import os

        from api.dependencies import get_pixelle_video

        # Works whether get_pixelle_video is a plain function or a coroutine
        # function. NOTE(review): confirm which it is and simplify.
        pixelle_video = get_pixelle_video()
        if inspect.isawaitable(pixelle_video):
            pixelle_video = await pixelle_video

        output_dir = f"output/{storyboard_id}"
        os.makedirs(output_dir, exist_ok=True)
        audio_path = f"{output_dir}/frame_{frame_index}_audio_regenerated.mp3"

        voice = request.voice if request and request.voice else None
        result_path = await pixelle_video.tts(
            text=narration,
            voice=voice,
            output_path=audio_path
        )

        from mutagen.mp3 import MP3
        try:
            duration = MP3(result_path).info.length
        except Exception as e:
            # Best-effort: an unreadable file falls back to the default
            # duration, but the reason is logged and KeyboardInterrupt /
            # SystemExit are no longer swallowed.
            logger.warning(f"Could not read audio duration, using default: {e}")
            duration = 3.0  # Default duration

        target_frame["audio_path"] = _path_to_url(result_path)
        target_frame["duration"] = duration

        storyboard["total_duration"] = sum(f.get("duration", 3.0) for f in frames)
        _storyboard_cache[storyboard_id] = storyboard

        logger.info(f"Regenerated audio for frame {frame_id}, duration: {duration}s")
        return RegenerateAudioResponse(
            audio_path=target_frame["audio_path"],
            duration=duration,
            success=True
        )
    except HTTPException:
        # Keep deliberate HTTP errors intact instead of re-wrapping them.
        raise
    except ImportError as e:
        logger.error(f"Failed to import dependencies: {e}")
        raise HTTPException(status_code=500, detail="TTS service not available") from e
    except Exception as e:
        logger.error(f"Audio regeneration failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
@router.post(
    "/storyboard/{storyboard_id}/frames/{frame_id}/inpaint",
    response_model=InpaintResponse
)
async def inpaint_frame_image(
    storyboard_id: str = Path(..., description="Storyboard/task ID"),
    frame_id: str = Path(..., description="Frame ID"),
    request: InpaintRequest = None
):
    """
    Inpaint (partial redraw) frame image

    Decodes the base64 mask from the request, writes it to
    ``output/{storyboard_id}/mask_{n}.png``, and calls the media inpaint
    service with the frame's current image, the mask, a prompt
    (``request.prompt`` or the frame's ``image_prompt``) and
    ``request.denoise_strength``. The result is downloaded to
    ``output/{storyboard_id}/frame_{n}_inpainted.png`` and the frame's
    ``image_path`` is updated to the corresponding URL.

    Raises:
        HTTPException: 404 if the storyboard or frame is not found; 400 if the
            frame has no image or the mask is missing/undecodable; 500 if
            inpainting, download, or a dependency import fails.
    """
    if storyboard_id not in _storyboard_cache:
        raise HTTPException(status_code=404, detail=f"Storyboard {storyboard_id} not found")

    storyboard = _storyboard_cache[storyboard_id]
    frames = storyboard["frames"]

    # Position of the frame in the current list; used in the output file names.
    frame_index, target_frame = next(
        ((i, f) for i, f in enumerate(frames) if f["id"] == frame_id),
        (0, None),
    )
    if target_frame is None:
        raise HTTPException(status_code=404, detail=f"Frame {frame_id} not found")

    original_image = target_frame.get("image_path")
    if not original_image:
        raise HTTPException(status_code=400, detail="No image to inpaint")

    if not request or not request.mask:
        raise HTTPException(status_code=400, detail="Mask is required")

    import base64
    import binascii

    # Tolerate a "data:<mime>;base64," prefix: b64decode's default mode would
    # silently fold those characters into the decoded bytes.
    mask_b64 = request.mask
    if isinstance(mask_b64, str) and mask_b64.startswith("data:"):
        mask_b64 = mask_b64.partition(",")[2]
    try:
        mask_data = base64.b64decode(mask_b64)
    except (binascii.Error, ValueError) as e:
        # Bad client input is a 400, not a 500.
        raise HTTPException(status_code=400, detail="Mask is not valid base64") from e
    if not mask_data:
        raise HTTPException(status_code=400, detail="Mask is required")

    try:
        import inspect
        import os

        import aiohttp
        from api.dependencies import get_pixelle_video

        # Works whether get_pixelle_video is a plain function or a coroutine
        # function. NOTE(review): confirm which it is and simplify.
        pixelle_video = get_pixelle_video()
        if inspect.isawaitable(pixelle_video):
            pixelle_video = await pixelle_video

        # The mask is stored next to the frame outputs, not in a temp dir.
        output_dir = f"output/{storyboard_id}"
        os.makedirs(output_dir, exist_ok=True)
        mask_path = f"{output_dir}/mask_{frame_index}.png"
        with open(mask_path, 'wb') as f:
            f.write(mask_data)

        prompt = request.prompt or target_frame.get("image_prompt", "")

        # Convert URL back to file path
        # (URL format: http://localhost:8000/api/files/{relative_path})
        image_file_path = original_image
        if "/api/files/" in original_image:
            image_file_path = "output/" + original_image.split("/api/files/")[-1]

        result = await pixelle_video.media.inpaint(
            image_path=image_file_path,
            mask_path=mask_path,
            prompt=prompt,
            denoise_strength=request.denoise_strength,
        )
        if not (result and result.url):
            raise HTTPException(status_code=500, detail="Inpainting failed")

        image_path = f"{output_dir}/frame_{frame_index}_inpainted.png"
        async with aiohttp.ClientSession() as session:
            async with session.get(result.url) as resp:
                # Fail loudly: otherwise the frame would be pointed at a file
                # that was never written.
                if resp.status != 200:
                    raise HTTPException(
                        status_code=500,
                        detail=f"Failed to download inpainted image (HTTP {resp.status})",
                    )
                image_bytes = await resp.read()

        with open(image_path, 'wb') as f:
            f.write(image_bytes)

        target_frame["image_path"] = _path_to_url(image_path)
        _storyboard_cache[storyboard_id] = storyboard

        logger.info(f"Inpainted image for frame {frame_id}")
        return InpaintResponse(
            image_path=target_frame["image_path"],
            success=True
        )
    except HTTPException:
        # Deliberate HTTP errors raised above must not be re-wrapped by the
        # generic handler below.
        raise
    except ImportError as e:
        logger.error(f"Failed to import dependencies: {e}")
        raise HTTPException(status_code=500, detail="Inpainting service not available") from e
    except Exception as e:
        logger.error(f"Inpainting failed: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e