支持fastapi服务
This commit is contained in:
15
api/__init__.py
Normal file
15
api/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
||||
"""
|
||||
ReelForge API Layer
|
||||
|
||||
FastAPI-based REST API for video generation services.
|
||||
"""
|
||||
|
||||
# Lazy import to avoid loading dependencies until needed
|
||||
def get_app():
    """Return the FastAPI application instance.

    The import happens inside the function so that merely importing the
    ``api`` package does not pull in FastAPI and the rest of the app's
    dependency graph.
    """
    from api.app import app as fastapi_app
    return fastapi_app
|
||||
|
||||
|
||||
__all__ = ["get_app"]
|
||||
|
||||
133
api/app.py
Normal file
133
api/app.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""
|
||||
ReelForge FastAPI Application
|
||||
|
||||
Main FastAPI app with all routers and middleware.
|
||||
"""
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from loguru import logger
|
||||
|
||||
from api.config import api_config
|
||||
from api.tasks import task_manager
|
||||
from api.dependencies import shutdown_reelforge
|
||||
|
||||
# Import routers
|
||||
from api.routers import (
|
||||
health_router,
|
||||
llm_router,
|
||||
tts_router,
|
||||
image_router,
|
||||
content_router,
|
||||
video_router,
|
||||
tasks_router,
|
||||
files_router,
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """
    Application lifespan manager.

    Handles startup and shutdown events:
    - startup: starts the background task manager before any request is served
    - shutdown: stops the task manager first, then releases the ReelForge core

    Args:
        app: The FastAPI application (required by the lifespan protocol;
             not otherwise used here).
    """
    # Startup
    logger.info("🚀 Starting ReelForge API...")
    await task_manager.start()
    logger.info("✅ ReelForge API started successfully\n")

    # Application serves requests while suspended here.
    yield

    # Shutdown — stop accepting new background work before tearing down the
    # core that running tasks may still reference.
    logger.info("🛑 Shutting down ReelForge API...")
    await task_manager.stop()
    await shutdown_reelforge()
    logger.info("✅ ReelForge API shutdown complete")
|
||||
|
||||
|
||||
# Create FastAPI app
# The markdown description below is rendered on the Swagger/ReDoc pages.
# Doc/openapi URLs come from config so they can be disabled (set to None)
# in production.
app = FastAPI(
    title="ReelForge API",
    description="""
## ReelForge - AI Video Generation Platform API

### Features
- 🤖 **LLM**: Large language model integration
- 🔊 **TTS**: Text-to-speech synthesis
- 🎨 **Image**: AI image generation
- 📝 **Content**: Automated content generation
- 🎬 **Video**: End-to-end video generation

### Video Generation Modes
- **Sync**: `/api/video/generate/sync` - For small videos (< 30s)
- **Async**: `/api/video/generate/async` - For large videos with task tracking

### Getting Started
1. Check health: `GET /health`
2. Generate narrations: `POST /api/content/narration`
3. Generate video: `POST /api/video/generate/sync` or `/async`
4. Track task progress: `GET /api/tasks/{task_id}`
""",
    version="0.1.0",
    docs_url=api_config.docs_url,
    redoc_url=api_config.redoc_url,
    openapi_url=api_config.openapi_url,
    lifespan=lifespan,
)

# Add CORS middleware
# NOTE(review): the default config is allow_origins=["*"] together with
# allow_credentials=True — per the CORS spec browsers reject credentialed
# wildcard responses; confirm the intended origin list for production.
if api_config.cors_enabled:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=api_config.cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )
    logger.info(f"CORS enabled for origins: {api_config.cors_origins}")

# Include routers
# Health check (no prefix) — stays at /health so load balancers can probe it.
app.include_router(health_router)

# API routers (with /api prefix from config, default "/api")
app.include_router(llm_router, prefix=api_config.api_prefix)
app.include_router(tts_router, prefix=api_config.api_prefix)
app.include_router(image_router, prefix=api_config.api_prefix)
app.include_router(content_router, prefix=api_config.api_prefix)
app.include_router(video_router, prefix=api_config.api_prefix)
app.include_router(tasks_router, prefix=api_config.api_prefix)
app.include_router(files_router, prefix=api_config.api_prefix)
|
||||
|
||||
|
||||
@app.get("/")
async def root():
    """Root endpoint with API information"""
    prefix = api_config.api_prefix
    # Build the per-service index from one list so the set of advertised
    # endpoints is maintained in a single place.
    service_index = {
        name: f"{prefix}/{name}"
        for name in ("llm", "tts", "image", "content", "video", "tasks")
    }
    return {
        "service": "ReelForge API",
        "version": "0.1.0",
        "docs": api_config.docs_url,
        "health": "/health",
        "api": service_index,
    }
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Local import: uvicorn is only needed when running this module directly.
    import uvicorn

    # Pass the import string ("api.app:app") rather than the app object —
    # required for reload=True to work.
    uvicorn.run(
        "api.app:app",
        host=api_config.host,
        port=api_config.port,
        reload=api_config.reload,
    )
|
||||
|
||||
38
api/config.py
Normal file
38
api/config.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
API Configuration
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class APIConfig(BaseModel):
    """API configuration.

    Plain pydantic model with hard-coded defaults; values are not read from
    the environment here — NOTE(review): confirm whether env overrides
    (e.g. pydantic-settings) are intended.
    """

    # Server settings
    host: str = "0.0.0.0"  # bind on all interfaces
    port: int = 8000
    reload: bool = False  # uvicorn auto-reload; development only

    # CORS settings
    cors_enabled: bool = True
    cors_origins: list[str] = ["*"]  # NOTE(review): wildcard default — tighten for production

    # Task settings
    max_concurrent_tasks: int = 5
    task_cleanup_interval: int = 3600  # Clean completed tasks every hour
    task_retention_time: int = 86400  # Keep task results for 24 hours

    # File upload settings
    max_upload_size: int = 100 * 1024 * 1024  # 100MB

    # API settings
    api_prefix: str = "/api"
    # Setting any of the three URLs below to None disables that page/route.
    docs_url: Optional[str] = "/docs"
    redoc_url: Optional[str] = "/redoc"
    openapi_url: Optional[str] = "/openapi.json"


# Global config instance
# Import-time singleton shared by api.app and the routers.
api_config = APIConfig()
|
||||
|
||||
45
api/dependencies.py
Normal file
45
api/dependencies.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
FastAPI Dependencies
|
||||
|
||||
Provides dependency injection for ReelForgeCore and other services.
|
||||
"""
|
||||
|
||||
from typing import Annotated
|
||||
from fastapi import Depends
|
||||
from loguru import logger
|
||||
|
||||
from reelforge.service import ReelForgeCore
|
||||
|
||||
|
||||
# Global ReelForge instance
# Module-level singleton, created lazily on first request. Left un-annotated
# with ReelForgeCore on purpose: the previous `_reelforge_instance:
# ReelForgeCore = None` annotation was wrong for the initial None value.
_reelforge_instance = None


async def get_reelforge() -> ReelForgeCore:
    """
    Get ReelForge core instance (dependency injection).

    Lazily constructs and initializes the shared ReelForgeCore on first use.
    The global is only published after initialize() succeeds, so a failed
    initialization is retried on the next request instead of caching a
    half-initialized instance (the previous code assigned the global before
    awaiting initialize()).

    NOTE(review): not lock-guarded — two concurrent first requests may both
    initialize and the later one wins; confirm initialize() is idempotent or
    add an asyncio.Lock.

    Returns:
        ReelForgeCore instance
    """
    global _reelforge_instance

    if _reelforge_instance is None:
        instance = ReelForgeCore()
        await instance.initialize()
        # Publish only after successful initialization.
        _reelforge_instance = instance
        logger.info("✅ ReelForge initialized for API")

    return _reelforge_instance
|
||||
|
||||
|
||||
async def shutdown_reelforge():
    """Shutdown ReelForge instance.

    Called from the FastAPI lifespan shutdown hook.
    NOTE(review): this only drops the module-level reference — no
    close()/cleanup method is invoked on the instance; confirm ReelForgeCore
    needs no explicit teardown.
    """
    global _reelforge_instance
    if _reelforge_instance:
        logger.info("Shutting down ReelForge...")
        _reelforge_instance = None


# Type alias for dependency injection
# Route handlers declare `reelforge: ReelForgeDep` to receive the shared
# core instance via FastAPI's Depends(get_reelforge).
ReelForgeDep = Annotated[ReelForgeCore, Depends(get_reelforge)]
|
||||
|
||||
24
api/routers/__init__.py
Normal file
24
api/routers/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
"""
|
||||
API Routers
|
||||
"""
|
||||
|
||||
from api.routers.health import router as health_router
|
||||
from api.routers.llm import router as llm_router
|
||||
from api.routers.tts import router as tts_router
|
||||
from api.routers.image import router as image_router
|
||||
from api.routers.content import router as content_router
|
||||
from api.routers.video import router as video_router
|
||||
from api.routers.tasks import router as tasks_router
|
||||
from api.routers.files import router as files_router
|
||||
|
||||
__all__ = [
|
||||
"health_router",
|
||||
"llm_router",
|
||||
"tts_router",
|
||||
"image_router",
|
||||
"content_router",
|
||||
"video_router",
|
||||
"tasks_router",
|
||||
"files_router",
|
||||
]
|
||||
|
||||
125
api/routers/content.py
Normal file
125
api/routers/content.py
Normal file
@@ -0,0 +1,125 @@
|
||||
"""
|
||||
Content generation endpoints
|
||||
|
||||
Endpoints for generating narrations, image prompts, and titles.
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from api.dependencies import ReelForgeDep
|
||||
from api.schemas.content import (
|
||||
NarrationGenerateRequest,
|
||||
NarrationGenerateResponse,
|
||||
ImagePromptGenerateRequest,
|
||||
ImagePromptGenerateResponse,
|
||||
TitleGenerateRequest,
|
||||
TitleGenerateResponse,
|
||||
)
|
||||
|
||||
router = APIRouter(prefix="/content", tags=["Content Generation"])
|
||||
|
||||
|
||||
@router.post("/narration", response_model=NarrationGenerateResponse)
async def generate_narration(
    request: NarrationGenerateRequest,
    reelforge: ReelForgeDep
):
    """
    Break source text into narration segments.

    Delegates to the LLM-backed narration generator service.

    - **text**: Source text
    - **n_scenes**: How many narration segments to produce
    - **min_words** / **max_words**: Word-count bounds per segment

    Returns the list of generated narration strings.
    """
    try:
        logger.info(f"Generating {request.n_scenes} narrations from text")
        segments = await reelforge.narration_generator(
            text=request.text,
            n_scenes=request.n_scenes,
            min_words=request.min_words,
            max_words=request.max_words,
        )
        return NarrationGenerateResponse(narrations=segments)
    except Exception as exc:
        # Any service failure becomes a generic 500 carrying the error text.
        logger.error(f"Narration generation error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/image-prompt", response_model=ImagePromptGenerateResponse)
async def generate_image_prompt(
    request: ImagePromptGenerateRequest,
    reelforge: ReelForgeDep
):
    """
    Produce one image-generation prompt per narration.

    Delegates to the LLM-backed image prompt generator service.

    - **narrations**: List of narration texts
    - **min_words** / **max_words**: Word-count bounds per prompt

    Returns the list of generated image prompts.
    """
    try:
        logger.info(f"Generating image prompts for {len(request.narrations)} narrations")
        prompts = await reelforge.image_prompt_generator(
            narrations=request.narrations,
            min_words=request.min_words,
            max_words=request.max_words,
        )
        return ImagePromptGenerateResponse(image_prompts=prompts)
    except Exception as exc:
        # Any service failure becomes a generic 500 carrying the error text.
        logger.error(f"Image prompt generation error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.post("/title", response_model=TitleGenerateResponse)
async def generate_title(
    request: TitleGenerateRequest,
    reelforge: ReelForgeDep
):
    """
    Produce an engaging video title from source text via the LLM.

    - **text**: Source text
    - **style**: Optional title style hint (accepted by the schema;
      not forwarded to the service here)

    Returns the generated title.
    """
    try:
        logger.info("Generating title from text")
        generated = await reelforge.title_generator(text=request.text)
        return TitleGenerateResponse(title=generated)
    except Exception as exc:
        # Any service failure becomes a generic 500 carrying the error text.
        logger.error(f"Title generation error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
72
api/routers/files.py
Normal file
72
api/routers/files.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
File service endpoints
|
||||
|
||||
Provides access to generated files (videos, images, audio).
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import FileResponse
|
||||
from loguru import logger
|
||||
|
||||
router = APIRouter(prefix="/files", tags=["Files"])
|
||||
|
||||
|
||||
@router.get("/{file_path:path}")
async def get_file(file_path: str):
    """
    Serve a generated file from the output directory.

    - **file_path**: File name or path relative to output/
      (e.g. "abc123.mp4" or "subfolder/abc123.mp4")

    Returns the file with an inline Content-Disposition for browser preview.

    Raises:
        403 if the resolved path escapes the output directory,
        404 if it does not exist, 400 if it is not a regular file.
    """
    try:
        # Anchor all lookups at <cwd>/output and resolve symlinks/".."
        # components BEFORE the containment check. The previous version
        # compared unresolved paths with str(...).startswith("output"),
        # which let "../" segments escape (e.g. "../../etc/passwd" kept
        # the literal "output/../.." prefix) and also matched sibling
        # directories such as "output-private".
        output_root = (Path.cwd() / "output").resolve()
        abs_path = (output_root / file_path).resolve()

        try:
            abs_path.relative_to(output_root)
        except ValueError:
            raise HTTPException(status_code=403, detail="Access denied: only output directory is accessible")

        if not abs_path.exists():
            raise HTTPException(status_code=404, detail=f"File not found: {file_path}")

        if not abs_path.is_file():
            raise HTTPException(status_code=400, detail=f"Path is not a file: {file_path}")

        # Determine media type from the extension; unknown types fall back
        # to a generic binary stream.
        media_types = {
            '.mp4': 'video/mp4',
            '.mp3': 'audio/mpeg',
            '.wav': 'audio/wav',
            '.png': 'image/png',
            '.jpg': 'image/jpeg',
            '.jpeg': 'image/jpeg',
            '.gif': 'image/gif',
        }
        media_type = media_types.get(abs_path.suffix.lower(), 'application/octet-stream')

        # Use inline disposition for browser preview
        return FileResponse(
            path=str(abs_path),
            media_type=media_type,
            headers={
                "Content-Disposition": f'inline; filename="{abs_path.name}"'
            }
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"File access error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
42
api/routers/health.py
Normal file
42
api/routers/health.py
Normal file
@@ -0,0 +1,42 @@
|
||||
"""
|
||||
Health check and system info endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from pydantic import BaseModel
|
||||
|
||||
router = APIRouter(tags=["Health"])
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
    """Health check response.

    NOTE(review): the version string is duplicated in api.app (FastAPI
    version="0.1.0") — keep them in sync or source both from one constant.
    """
    # All fields are static defaults; nothing is computed per request.
    status: str = "healthy"
    version: str = "0.1.0"
    service: str = "ReelForge API"


class CapabilitiesResponse(BaseModel):
    """Capabilities response.

    NOTE(review): no endpoint in this module returns this model — confirm
    it is used elsewhere or remove it.
    """
    success: bool = True
    # Free-form mapping of capability name -> details.
    capabilities: dict
|
||||
|
||||
|
||||
@router.get("/health", response_model=HealthResponse)
async def health_check():
    """
    Health check endpoint.

    Returns service status and version information. Purely static — the
    model defaults carry the whole payload, so this always reports healthy.
    """
    return HealthResponse()


@router.get("/version", response_model=HealthResponse)
async def get_version():
    """
    Get API version.

    Returns version information. Same static payload as /health (the model
    bundles status, version, and service name together).
    """
    return HealthResponse()
|
||||
|
||||
49
api/routers/image.py
Normal file
49
api/routers/image.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Image generation endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from api.dependencies import ReelForgeDep
|
||||
from api.schemas.image import ImageGenerateRequest, ImageGenerateResponse
|
||||
|
||||
router = APIRouter(prefix="/image", tags=["Image"])
|
||||
|
||||
|
||||
@router.post("/generate", response_model=ImageGenerateResponse)
async def image_generate(
    request: ImageGenerateRequest,
    reelforge: ReelForgeDep
):
    """
    Generate a single image from a text prompt (via ComfyKit).

    - **prompt**: Image description/prompt
    - **width** / **height**: Output dimensions (512-2048)
    - **workflow**: Optional custom workflow filename

    Returns the path of the generated image file.
    """
    try:
        logger.info(f"Image generation request: {request.prompt[:50]}...")
        generated_path = await reelforge.image(
            prompt=request.prompt,
            width=request.width,
            height=request.height,
            workflow=request.workflow,
        )
        return ImageGenerateResponse(image_path=generated_path)
    except Exception as exc:
        # Any service failure becomes a generic 500 carrying the error text.
        logger.error(f"Image generation error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
48
api/routers/llm.py
Normal file
48
api/routers/llm.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
LLM (Large Language Model) endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from api.dependencies import ReelForgeDep
|
||||
from api.schemas.llm import LLMChatRequest, LLMChatResponse
|
||||
|
||||
router = APIRouter(prefix="/llm", tags=["LLM"])
|
||||
|
||||
|
||||
@router.post("/chat", response_model=LLMChatResponse)
async def llm_chat(
    request: LLMChatRequest,
    reelforge: ReelForgeDep
):
    """
    Generate a text completion with the configured LLM.

    - **prompt**: User prompt/question
    - **temperature**: Creativity level (0.0-2.0, lower = more deterministic)
    - **max_tokens**: Maximum response length

    Returns the generated text; token usage is not tracked (always None).
    """
    try:
        logger.info(f"LLM chat request: {request.prompt[:50]}...")
        completion = await reelforge.llm(
            prompt=request.prompt,
            temperature=request.temperature,
            max_tokens=request.max_tokens,
        )
        # tokens_used stays None — token counting is not wired up yet.
        return LLMChatResponse(content=completion, tokens_used=None)
    except Exception as exc:
        # Any service failure becomes a generic 500 carrying the error text.
        logger.error(f"LLM chat error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
93
api/routers/tasks.py
Normal file
93
api/routers/tasks.py
Normal file
@@ -0,0 +1,93 @@
|
||||
"""
|
||||
Task management endpoints
|
||||
|
||||
Endpoints for managing async tasks (checking status, canceling, etc.)
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from fastapi import APIRouter, HTTPException, Query
|
||||
from loguru import logger
|
||||
|
||||
from api.tasks import task_manager, Task, TaskStatus
|
||||
|
||||
router = APIRouter(prefix="/tasks", tags=["Tasks"])
|
||||
|
||||
|
||||
@router.get("", response_model=List[Task])
async def list_tasks(
    status: Optional[TaskStatus] = Query(None, description="Filter by status"),
    limit: int = Query(100, ge=1, le=1000, description="Maximum number of tasks")
):
    """
    List tasks, newest first.

    - **status**: Optional status filter (pending/running/completed/failed/cancelled)
    - **limit**: Cap on the number of returned tasks (default 100)

    Returns the (possibly filtered) task list from the task manager.
    """
    try:
        return task_manager.list_tasks(status=status, limit=limit)
    except Exception as exc:
        logger.error(f"List tasks error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.get("/{task_id}", response_model=Task)
async def get_task(task_id: str):
    """
    Fetch a single task by id.

    - **task_id**: Task ID

    Returns the task's status, progress, and result (when completed);
    responds 404 for an unknown id.
    """
    try:
        found = task_manager.get_task(task_id)
        if not found:
            raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
        return found
    except HTTPException:
        # Re-raise our own 404 untouched.
        raise
    except Exception as exc:
        logger.error(f"Get task error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
@router.delete("/{task_id}")
async def cancel_task(task_id: str):
    """
    Cancel a running or pending task.

    - **task_id**: Task ID

    Returns a success payload; responds 404 when the id is unknown.
    """
    try:
        if not task_manager.cancel_task(task_id):
            raise HTTPException(status_code=404, detail=f"Task {task_id} not found")
        return {
            "success": True,
            "message": f"Task {task_id} cancelled successfully"
        }
    except HTTPException:
        # Re-raise our own 404 untouched.
        raise
    except Exception as exc:
        logger.error(f"Cancel task error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
50
api/routers/tts.py
Normal file
50
api/routers/tts.py
Normal file
@@ -0,0 +1,50 @@
|
||||
"""
|
||||
TTS (Text-to-Speech) endpoints
|
||||
"""
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from loguru import logger
|
||||
|
||||
from api.dependencies import ReelForgeDep
|
||||
from api.schemas.tts import TTSSynthesizeRequest, TTSSynthesizeResponse
|
||||
from reelforge.utils.tts_util import get_audio_duration
|
||||
|
||||
router = APIRouter(prefix="/tts", tags=["TTS"])
|
||||
|
||||
|
||||
@router.post("/synthesize", response_model=TTSSynthesizeResponse)
async def tts_synthesize(
    request: TTSSynthesizeRequest,
    reelforge: ReelForgeDep
):
    """
    Convert text to speech audio.

    - **text**: Text to synthesize
    - **voice_id**: Voice ID (e.g., 'zh-CN-YunjianNeural', 'en-US-AriaNeural')

    Returns the generated audio file path together with its measured duration.
    """
    try:
        logger.info(f"TTS synthesis request: {request.text[:50]}...")
        synthesized_path = await reelforge.tts(
            text=request.text,
            voice_id=request.voice_id,
        )
        # Measure the clip length so callers can plan scene timing.
        clip_duration = get_audio_duration(synthesized_path)
        return TTSSynthesizeResponse(
            audio_path=synthesized_path,
            duration=clip_duration,
        )
    except Exception as exc:
        # Any service failure becomes a generic 500 carrying the error text.
        logger.error(f"TTS synthesis error: {exc}")
        raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
180
api/routers/video.py
Normal file
180
api/routers/video.py
Normal file
@@ -0,0 +1,180 @@
|
||||
"""
|
||||
Video generation endpoints
|
||||
|
||||
Supports both synchronous and asynchronous video generation.
|
||||
"""
|
||||
|
||||
import os
|
||||
from fastapi import APIRouter, HTTPException, Request
|
||||
from loguru import logger
|
||||
|
||||
from api.dependencies import ReelForgeDep
|
||||
from api.schemas.video import (
|
||||
VideoGenerateRequest,
|
||||
VideoGenerateResponse,
|
||||
VideoGenerateAsyncResponse,
|
||||
)
|
||||
from api.tasks import task_manager, TaskType
|
||||
|
||||
router = APIRouter(prefix="/video", tags=["Video Generation"])
|
||||
|
||||
|
||||
def path_to_url(request: Request, file_path: str) -> str:
    """Convert a local file path into a URL served by the files router.

    Args:
        request: Incoming request; supplies the externally visible base URL.
        file_path: Generated file path, typically "output/<name>.mp4".

    Returns:
        Absolute URL under /api/files/ for the given file.
    """
    # The files router already roots lookups in the output directory, so the
    # "output/" prefix is dropped for a cleaner URL. removeprefix replaces
    # the previous magic-number slice (file_path[7:]) and is a no-op when
    # the prefix is absent.
    relative_path = file_path.removeprefix("output/")
    base_url = str(request.base_url).rstrip('/')
    return f"{base_url}/api/files/{relative_path}"
|
||||
|
||||
|
||||
@router.post("/generate/sync", response_model=VideoGenerateResponse)
async def generate_video_sync(
    request_body: VideoGenerateRequest,
    reelforge: ReelForgeDep,
    request: Request
):
    """
    Generate video synchronously

    This endpoint blocks until video generation is complete.
    Suitable for small videos (< 30 seconds).

    **Note**: May timeout for large videos. Use `/generate/async` instead.

    Request body includes all video generation parameters.
    See VideoGenerateRequest schema for details.

    Returns path to generated video, duration, and file size.
    """
    try:
        logger.info(f"Sync video generation: {request_body.text[:50]}...")

        # Call video generator service — the request schema is forwarded
        # field-by-field; keep this list in sync with VideoGenerateRequest.
        result = await reelforge.generate_video(
            text=request_body.text,
            mode=request_body.mode,
            title=request_body.title,
            n_scenes=request_body.n_scenes,
            voice_id=request_body.voice_id,
            use_uuid_filename=True,  # API mode: use UUID filename
            min_narration_words=request_body.min_narration_words,
            max_narration_words=request_body.max_narration_words,
            min_image_prompt_words=request_body.min_image_prompt_words,
            max_image_prompt_words=request_body.max_image_prompt_words,
            image_width=request_body.image_width,
            image_height=request_body.image_height,
            image_workflow=request_body.image_workflow,
            video_width=request_body.video_width,
            video_height=request_body.video_height,
            video_fps=request_body.video_fps,
            frame_template=request_body.frame_template,
            prompt_prefix=request_body.prompt_prefix,
            bgm_path=request_body.bgm_path,
            bgm_volume=request_body.bgm_volume,
        )

        # Get file size (0 if the path unexpectedly does not exist)
        file_size = os.path.getsize(result.video_path) if os.path.exists(result.video_path) else 0

        # Convert path to URL so the client can fetch via the files router
        video_url = path_to_url(request, result.video_path)

        return VideoGenerateResponse(
            video_url=video_url,
            duration=result.duration,
            file_size=file_size
        )

    except Exception as e:
        logger.error(f"Sync video generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/generate/async", response_model=VideoGenerateAsyncResponse)
async def generate_video_async(
    request_body: VideoGenerateRequest,
    reelforge: ReelForgeDep,
    request: Request
):
    """
    Generate video asynchronously

    Creates a background task for video generation.
    Returns immediately with a task_id for tracking progress.

    **Workflow:**
    1. Submit video generation request
    2. Receive task_id in response
    3. Poll `/api/tasks/{task_id}` to check status
    4. When status is "completed", retrieve video from result

    Request body includes all video generation parameters.
    See VideoGenerateRequest schema for details.

    Returns task_id for tracking progress.
    """
    try:
        logger.info(f"Async video generation: {request_body.text[:50]}...")

        # Create task record first so the client can poll it immediately;
        # the raw request params are stored for inspection/debugging.
        task = task_manager.create_task(
            task_type=TaskType.VIDEO_GENERATION,
            request_params=request_body.model_dump()
        )

        # Define async execution function.
        # NOTE(review): the closure captures `request` and `reelforge` from
        # this handler and uses them after the HTTP response has been sent —
        # confirm both remain valid for the lifetime of the background task.
        async def execute_video_generation():
            """Execute video generation in background"""
            result = await reelforge.generate_video(
                text=request_body.text,
                mode=request_body.mode,
                title=request_body.title,
                n_scenes=request_body.n_scenes,
                voice_id=request_body.voice_id,
                use_uuid_filename=True,  # API mode: use UUID filename
                min_narration_words=request_body.min_narration_words,
                max_narration_words=request_body.max_narration_words,
                min_image_prompt_words=request_body.min_image_prompt_words,
                max_image_prompt_words=request_body.max_image_prompt_words,
                image_width=request_body.image_width,
                image_height=request_body.image_height,
                image_workflow=request_body.image_workflow,
                video_width=request_body.video_width,
                video_height=request_body.video_height,
                video_fps=request_body.video_fps,
                frame_template=request_body.frame_template,
                prompt_prefix=request_body.prompt_prefix,
                bgm_path=request_body.bgm_path,
                bgm_volume=request_body.bgm_volume,
                # Progress callback can be added here if needed
                # progress_callback=lambda event: task_manager.update_progress(...)
            )

            # Get file size (0 if the path unexpectedly does not exist)
            file_size = os.path.getsize(result.video_path) if os.path.exists(result.video_path) else 0

            # Convert path to URL so the task result is directly fetchable
            video_url = path_to_url(request, result.video_path)

            # This dict becomes the task's stored result payload.
            return {
                "video_url": video_url,
                "duration": result.duration,
                "file_size": file_size
            }

        # Start execution — presumably schedules the coroutine without
        # awaiting completion here; confirm against task_manager.execute_task.
        await task_manager.execute_task(
            task_id=task.task_id,
            coro_func=execute_video_generation
        )

        return VideoGenerateAsyncResponse(
            task_id=task.task_id
        )

    except Exception as e:
        logger.error(f"Async video generation error: {e}")
        raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
48
api/schemas/__init__.py
Normal file
48
api/schemas/__init__.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""
|
||||
API Schemas (Pydantic models)
|
||||
"""
|
||||
|
||||
from api.schemas.base import BaseResponse, ErrorResponse
|
||||
from api.schemas.llm import LLMChatRequest, LLMChatResponse
|
||||
from api.schemas.tts import TTSSynthesizeRequest, TTSSynthesizeResponse
|
||||
from api.schemas.image import ImageGenerateRequest, ImageGenerateResponse
|
||||
from api.schemas.content import (
|
||||
NarrationGenerateRequest,
|
||||
NarrationGenerateResponse,
|
||||
ImagePromptGenerateRequest,
|
||||
ImagePromptGenerateResponse,
|
||||
TitleGenerateRequest,
|
||||
TitleGenerateResponse,
|
||||
)
|
||||
from api.schemas.video import (
|
||||
VideoGenerateRequest,
|
||||
VideoGenerateResponse,
|
||||
VideoGenerateAsyncResponse,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# Base
|
||||
"BaseResponse",
|
||||
"ErrorResponse",
|
||||
# LLM
|
||||
"LLMChatRequest",
|
||||
"LLMChatResponse",
|
||||
# TTS
|
||||
"TTSSynthesizeRequest",
|
||||
"TTSSynthesizeResponse",
|
||||
# Image
|
||||
"ImageGenerateRequest",
|
||||
"ImageGenerateResponse",
|
||||
# Content
|
||||
"NarrationGenerateRequest",
|
||||
"NarrationGenerateResponse",
|
||||
"ImagePromptGenerateRequest",
|
||||
"ImagePromptGenerateResponse",
|
||||
"TitleGenerateRequest",
|
||||
"TitleGenerateResponse",
|
||||
# Video
|
||||
"VideoGenerateRequest",
|
||||
"VideoGenerateResponse",
|
||||
"VideoGenerateAsyncResponse",
|
||||
]
|
||||
|
||||
21
api/schemas/base.py
Normal file
21
api/schemas/base.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
Base schemas
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseResponse(BaseModel):
    """Base API response envelope (success flag + message + optional data)."""
    success: bool = True
    message: str = "Success"
    # Arbitrary payload; None when the endpoint has nothing extra to return.
    data: Optional[Any] = None


class ErrorResponse(BaseModel):
    """Error response envelope; success is fixed to False."""
    success: bool = False
    # Human-readable description of what failed (required).
    message: str
    # Optional machine-oriented error detail.
    error: Optional[str] = None
|
||||
|
||||
91
api/schemas/content.py
Normal file
91
api/schemas/content.py
Normal file
@@ -0,0 +1,91 @@
|
||||
"""
|
||||
Content generation API schemas
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Narration Generation
|
||||
# ============================================================================
|
||||
|
||||
class NarrationGenerateRequest(BaseModel):
    """Request body for generating scene narrations from source text."""
    text: str = Field(..., description="Source text to generate narrations from")
    n_scenes: int = Field(5, ge=1, le=20, description="Number of scenes")
    # NOTE(review): no cross-field check that min_words <= max_words —
    # confirm the downstream generator tolerates an inverted range.
    min_words: int = Field(5, ge=1, le=100, description="Minimum words per narration")
    max_words: int = Field(20, ge=1, le=200, description="Maximum words per narration")

    class Config:
        # Example payload surfaced in the OpenAPI/Swagger docs.
        json_schema_extra = {
            "example": {
                "text": "Atomic Habits is about making small changes that lead to remarkable results.",
                "n_scenes": 5,
                "min_words": 5,
                "max_words": 20
            }
        }
|
||||
|
||||
|
||||
class NarrationGenerateResponse(BaseModel):
    """Response carrying the generated narrations."""
    # NOTE(review): duplicates BaseResponse's success/message fields rather
    # than inheriting from it — consider subclassing BaseResponse.
    success: bool = True
    message: str = "Success"
    narrations: List[str] = Field(..., description="Generated narrations")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Image Prompt Generation
|
||||
# ============================================================================
|
||||
|
||||
class ImagePromptGenerateRequest(BaseModel):
    """Request body for turning narrations into image-generation prompts."""
    narrations: List[str] = Field(..., description="List of narrations")
    # NOTE(review): no cross-field check that min_words <= max_words.
    min_words: int = Field(30, ge=10, le=100, description="Minimum words per prompt")
    max_words: int = Field(60, ge=10, le=200, description="Maximum words per prompt")

    class Config:
        # Example payload surfaced in the OpenAPI/Swagger docs.
        json_schema_extra = {
            "example": {
                "narrations": [
                    "Small habits compound over time",
                    "Focus on systems, not goals"
                ],
                "min_words": 30,
                "max_words": 60
            }
        }
|
||||
|
||||
|
||||
class ImagePromptGenerateResponse(BaseModel):
    """Response carrying one image prompt per input narration."""
    success: bool = True
    message: str = "Success"
    image_prompts: List[str] = Field(..., description="Generated image prompts")
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Title Generation
|
||||
# ============================================================================
|
||||
|
||||
class TitleGenerateRequest(BaseModel):
    """Request body for generating a video title from source text."""
    text: str = Field(..., description="Source text")
    # Free-form style hint; interpretation is up to the LLM prompt.
    style: Optional[str] = Field(None, description="Title style (e.g., 'engaging', 'formal')")

    class Config:
        # Example payload surfaced in the OpenAPI/Swagger docs.
        json_schema_extra = {
            "example": {
                "text": "Atomic Habits is about making small changes that lead to remarkable results.",
                "style": "engaging"
            }
        }
|
||||
|
||||
|
||||
class TitleGenerateResponse(BaseModel):
    """Response carrying the generated title."""
    success: bool = True
    message: str = "Success"
    title: str = Field(..., description="Generated title")
|
||||
|
||||
31
api/schemas/image.py
Normal file
31
api/schemas/image.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
Image generation API schemas
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ImageGenerateRequest(BaseModel):
    """Request body for generating a single image from a prompt."""
    prompt: str = Field(..., description="Image generation prompt")
    width: int = Field(1024, ge=512, le=2048, description="Image width")
    height: int = Field(1024, ge=512, le=2048, description="Image height")
    # Overrides the default backend workflow when set; presumably a ComfyUI
    # workflow file — confirm against the image service.
    workflow: Optional[str] = Field(None, description="Custom workflow filename")

    class Config:
        # Example payload surfaced in the OpenAPI/Swagger docs.
        json_schema_extra = {
            "example": {
                "prompt": "A serene mountain landscape at sunset, photorealistic style",
                "width": 1024,
                "height": 1024
            }
        }
|
||||
|
||||
|
||||
class ImageGenerateResponse(BaseModel):
    """Response carrying the location of the generated image."""
    success: bool = True
    message: str = "Success"
    # Server-side path; whether it is absolute or relative depends on the
    # image service — confirm before exposing to clients.
    image_path: str = Field(..., description="Path to generated image")
|
||||
|
||||
31
api/schemas/llm.py
Normal file
31
api/schemas/llm.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""
|
||||
LLM API schemas
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LLMChatRequest(BaseModel):
    """Request body for a single-turn LLM chat completion."""
    prompt: str = Field(..., description="User prompt")
    temperature: float = Field(0.7, ge=0.0, le=2.0, description="Temperature (0.0-2.0)")
    max_tokens: int = Field(2000, ge=1, le=32000, description="Maximum tokens")

    class Config:
        # Example payload surfaced in the OpenAPI/Swagger docs.
        json_schema_extra = {
            "example": {
                "prompt": "Explain the concept of atomic habits in 3 sentences",
                "temperature": 0.7,
                "max_tokens": 2000
            }
        }
|
||||
|
||||
|
||||
class LLMChatResponse(BaseModel):
    """Response carrying the LLM's generated text."""
    success: bool = True
    message: str = "Success"
    content: str = Field(..., description="Generated response")
    # None when the upstream provider does not report token usage.
    tokens_used: Optional[int] = Field(None, description="Tokens used (if available)")
|
||||
|
||||
28
api/schemas/tts.py
Normal file
28
api/schemas/tts.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
TTS API schemas
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TTSSynthesizeRequest(BaseModel):
    """Request body for synthesizing speech from text."""
    text: str = Field(..., description="Text to synthesize")
    # Default matches the project's default Edge-TTS-style voice identifier.
    voice_id: str = Field("zh-CN-YunjianNeural", description="Voice ID")

    class Config:
        # Example payload surfaced in the OpenAPI/Swagger docs.
        json_schema_extra = {
            "example": {
                "text": "Hello, welcome to ReelForge!",
                "voice_id": "zh-CN-YunjianNeural"
            }
        }
|
||||
|
||||
|
||||
class TTSSynthesizeResponse(BaseModel):
    """Response carrying the synthesized audio file's location and length."""
    success: bool = True
    message: str = "Success"
    audio_path: str = Field(..., description="Path to generated audio file")
    duration: float = Field(..., description="Audio duration in seconds")
|
||||
|
||||
80
api/schemas/video.py
Normal file
80
api/schemas/video.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""
|
||||
Video generation API schemas
|
||||
"""
|
||||
|
||||
from typing import Optional, Literal
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class VideoGenerateRequest(BaseModel):
    """Request body for the end-to-end video generation pipeline.

    Bundles every tunable the pipeline exposes: input text, processing mode,
    LLM word limits, image and video dimensions, framing template and BGM.
    All fields except ``text`` have defaults.
    """

    # === Input ===
    text: str = Field(..., description="Source text for video generation")

    # === Processing Mode ===
    mode: Literal["generate", "fixed"] = Field(
        "generate",
        description="Processing mode: 'generate' (AI generates narrations) or 'fixed' (use text as-is)"
    )

    # === Optional Title ===
    title: Optional[str] = Field(None, description="Video title (auto-generated if not provided)")

    # === Basic Config ===
    n_scenes: int = Field(5, ge=1, le=20, description="Number of scenes (generate mode only)")
    voice_id: str = Field("zh-CN-YunjianNeural", description="TTS voice ID")

    # === LLM Parameters ===
    # NOTE(review): min/max pairs are not cross-validated (min could exceed max).
    min_narration_words: int = Field(5, ge=1, le=100, description="Min narration words")
    max_narration_words: int = Field(20, ge=1, le=200, description="Max narration words")
    min_image_prompt_words: int = Field(30, ge=10, le=100, description="Min image prompt words")
    max_image_prompt_words: int = Field(60, ge=10, le=200, description="Max image prompt words")

    # === Image Parameters ===
    image_width: int = Field(1024, ge=512, le=2048, description="Image width")
    image_height: int = Field(1024, ge=512, le=2048, description="Image height")
    image_workflow: Optional[str] = Field(None, description="Custom image workflow")

    # === Video Parameters ===
    # Defaults target a 9:16 vertical short-form video.
    video_width: int = Field(1080, ge=512, le=3840, description="Video width")
    video_height: int = Field(1920, ge=512, le=3840, description="Video height")
    video_fps: int = Field(30, ge=15, le=60, description="Video FPS")

    # === Frame Template ===
    frame_template: Optional[str] = Field(None, description="HTML template name (e.g., 'default.html')")

    # === Image Style ===
    prompt_prefix: Optional[str] = Field(None, description="Image style prefix")

    # === BGM ===
    bgm_path: Optional[str] = Field(None, description="Background music path")
    bgm_volume: float = Field(0.3, ge=0.0, le=1.0, description="BGM volume (0.0-1.0)")

    class Config:
        # Example payload surfaced in the OpenAPI/Swagger docs.
        json_schema_extra = {
            "example": {
                "text": "Atomic Habits teaches us that small changes compound over time to produce remarkable results.",
                "mode": "generate",
                "n_scenes": 5,
                "voice_id": "zh-CN-YunjianNeural",
                "title": "The Power of Atomic Habits"
            }
        }
|
||||
|
||||
|
||||
class VideoGenerateResponse(BaseModel):
    """Synchronous video generation response (video already rendered)."""
    success: bool = True
    message: str = "Success"
    video_url: str = Field(..., description="URL to access generated video")
    duration: float = Field(..., description="Video duration in seconds")
    file_size: int = Field(..., description="File size in bytes")
|
||||
|
||||
|
||||
class VideoGenerateAsyncResponse(BaseModel):
    """Asynchronous video generation response: returns a task handle only."""
    success: bool = True
    message: str = "Task created successfully"
    # Poll the tasks endpoints with this ID to track progress / fetch result.
    task_id: str = Field(..., description="Task ID for tracking progress")
|
||||
|
||||
9
api/tasks/__init__.py
Normal file
9
api/tasks/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""
|
||||
Task management for async operations
|
||||
"""
|
||||
|
||||
from api.tasks.models import Task, TaskStatus, TaskType
|
||||
from api.tasks.manager import task_manager
|
||||
|
||||
__all__ = ["Task", "TaskStatus", "TaskType", "task_manager"]
|
||||
|
||||
254
api/tasks/manager.py
Normal file
254
api/tasks/manager.py
Normal file
@@ -0,0 +1,254 @@
|
||||
"""
|
||||
Task Manager
|
||||
|
||||
In-memory task management for video generation jobs.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, List, Optional, Callable
|
||||
from loguru import logger
|
||||
|
||||
from api.tasks.models import Task, TaskStatus, TaskType, TaskProgress
|
||||
from api.config import api_config
|
||||
|
||||
|
||||
class TaskManager:
    """
    Task manager for handling async video generation tasks.

    Features:
    - In-memory storage (can be replaced with Redis later); state is lost
      on process restart.
    - Task lifecycle management (pending -> running -> terminal).
    - Progress tracking via update_progress().
    - Auto cleanup of old terminal tasks on a background loop.

    Not safe for multi-process deployments (state is per-process).
    """

    # Statuses from which a task can no longer transition.
    _TERMINAL_STATUSES = frozenset(
        {TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED}
    )

    def __init__(self):
        # task_id -> Task record
        self._tasks: Dict[str, Task] = {}
        # task_id -> asyncio.Task actually driving the work
        self._task_futures: Dict[str, asyncio.Task] = {}
        self._cleanup_task: Optional[asyncio.Task] = None
        self._running = False

    async def start(self):
        """Start task manager and the periodic cleanup scheduler."""
        if self._running:
            logger.warning("Task manager already running")
            return

        self._running = True
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info("✅ Task manager started")

    async def stop(self):
        """Stop task manager and cancel all tasks.

        Cancellation of in-flight work is awaited so that shutdown does not
        return while task coroutines are still unwinding.
        """
        self._running = False

        # Cancel cleanup task
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass

        # Cancel all running tasks and wait for cancellation to settle.
        # BUGFIX: previously the futures were cancelled but never awaited,
        # so shutdown could return with coroutines still unwinding.
        pending = []
        for task_id, future in self._task_futures.items():
            if not future.done():
                future.cancel()
                pending.append(future)
                logger.info(f"Cancelled task: {task_id}")
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)

        self._tasks.clear()
        self._task_futures.clear()
        logger.info("✅ Task manager stopped")

    def create_task(
        self,
        task_type: TaskType,
        request_params: Optional[dict] = None
    ) -> Task:
        """
        Create a new task in PENDING state.

        Args:
            task_type: Type of task
            request_params: Original request parameters (kept for reference)

        Returns:
            Created task
        """
        task_id = str(uuid.uuid4())
        task = Task(
            task_id=task_id,
            task_type=task_type,
            status=TaskStatus.PENDING,
            request_params=request_params,
        )

        self._tasks[task_id] = task
        logger.info(f"Created task {task_id} ({task_type})")
        return task

    async def execute_task(
        self,
        task_id: str,
        coro_func: Callable,
        *args,
        **kwargs
    ):
        """
        Execute task asynchronously (fire-and-forget; result lands on the Task).

        Args:
            task_id: Task ID previously returned by create_task()
            coro_func: Async function to execute
            *args: Positional arguments
            **kwargs: Keyword arguments
        """
        task = self._tasks.get(task_id)
        if not task:
            logger.error(f"Task {task_id} not found")
            return

        async def _execute():
            try:
                task.status = TaskStatus.RUNNING
                task.started_at = datetime.now()
                logger.info(f"Task {task_id} started")

                # Execute the actual work
                result = await coro_func(*args, **kwargs)

                # Update task with result
                task.status = TaskStatus.COMPLETED
                task.result = result
                task.completed_at = datetime.now()
                logger.info(f"Task {task_id} completed")

            except asyncio.CancelledError:
                # BUGFIX: CancelledError is a BaseException (Python 3.8+) and
                # was not caught before, so a task cancelled directly via its
                # future stayed stuck in RUNNING. Record the terminal state
                # and re-raise so cancellation still propagates.
                task.status = TaskStatus.CANCELLED
                task.completed_at = datetime.now()
                raise

            except Exception as e:
                task.status = TaskStatus.FAILED
                task.error = str(e)
                task.completed_at = datetime.now()
                logger.error(f"Task {task_id} failed: {e}")

        # Start execution
        future = asyncio.create_task(_execute())
        self._task_futures[task_id] = future

    def get_task(self, task_id: str) -> Optional[Task]:
        """Get task by ID (None if unknown or already cleaned up)."""
        return self._tasks.get(task_id)

    def list_tasks(
        self,
        status: Optional[TaskStatus] = None,
        limit: int = 100
    ) -> List[Task]:
        """
        List tasks with optional filtering, newest first.

        Args:
            status: Filter by status
            limit: Maximum number of tasks to return

        Returns:
            List of tasks sorted by created_at descending
        """
        tasks = list(self._tasks.values())

        if status:
            tasks = [t for t in tasks if t.status == status]

        # Sort by created_at descending
        tasks.sort(key=lambda t: t.created_at, reverse=True)

        return tasks[:limit]

    def update_progress(
        self,
        task_id: str,
        current: int,
        total: int,
        message: str = ""
    ):
        """
        Update task progress. Silently ignores unknown task IDs.

        Args:
            task_id: Task ID
            current: Current progress
            total: Total steps
            message: Progress message
        """
        task = self._tasks.get(task_id)
        if not task:
            return

        # Guard against total == 0 to avoid ZeroDivisionError.
        percentage = (current / total * 100) if total > 0 else 0
        task.progress = TaskProgress(
            current=current,
            total=total,
            percentage=percentage,
            message=message
        )

    def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a running (or still pending) task.

        Args:
            task_id: Task ID

        Returns:
            True if cancelled, False if unknown or already finished
        """
        task = self._tasks.get(task_id)
        if not task:
            return False

        # BUGFIX: previously a task that had already COMPLETED or FAILED
        # would be relabelled CANCELLED here; refuse terminal tasks instead.
        if task.status in self._TERMINAL_STATUSES:
            return False

        # Cancel future if running
        future = self._task_futures.get(task_id)
        if future and not future.done():
            future.cancel()

        # Update task status
        task.status = TaskStatus.CANCELLED
        task.completed_at = datetime.now()
        logger.info(f"Cancelled task {task_id}")
        return True

    async def _cleanup_loop(self):
        """Periodically clean up old completed tasks until stop() is called."""
        while self._running:
            try:
                await asyncio.sleep(api_config.task_cleanup_interval)
                self._cleanup_old_tasks()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in cleanup loop: {e}")

    def _cleanup_old_tasks(self):
        """Remove terminal tasks older than api_config.task_retention_time."""
        cutoff_time = datetime.now() - timedelta(seconds=api_config.task_retention_time)

        tasks_to_remove = []
        for task_id, task in self._tasks.items():
            if task.status in self._TERMINAL_STATUSES:
                if task.completed_at and task.completed_at < cutoff_time:
                    tasks_to_remove.append(task_id)

        for task_id in tasks_to_remove:
            del self._tasks[task_id]
            # Drop the finished future alongside the record.
            self._task_futures.pop(task_id, None)

        if tasks_to_remove:
            logger.info(f"Cleaned up {len(tasks_to_remove)} old tasks")


# Global task manager instance (shared by the FastAPI app lifespan hooks)
task_manager = TaskManager()
|
||||
|
||||
58
api/tasks/models.py
Normal file
58
api/tasks/models.py
Normal file
@@ -0,0 +1,58 @@
|
||||
"""
|
||||
Task data models
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
    """Task lifecycle status (str-valued for direct JSON serialization)."""
    PENDING = "pending"        # created, not yet started
    RUNNING = "running"        # execution in progress
    COMPLETED = "completed"    # finished successfully; result is set
    FAILED = "failed"          # raised an exception; error is set
    CANCELLED = "cancelled"    # cancelled before completion
|
||||
|
||||
|
||||
class TaskType(str, Enum):
    """Kind of work a task performs (currently only video generation)."""
    VIDEO_GENERATION = "video_generation"
|
||||
|
||||
|
||||
class TaskProgress(BaseModel):
    """Task progress information (current step out of total, as a percentage)."""
    current: int = 0
    total: int = 0
    # Precomputed current/total * 100; 0.0 when total is 0.
    percentage: float = 0.0
    # Free-form human-readable status line.
    message: str = ""
|
||||
|
||||
|
||||
class Task(BaseModel):
    """In-memory record of one async job managed by the TaskManager."""
    task_id: str
    task_type: TaskType
    status: TaskStatus = TaskStatus.PENDING

    # Progress tracking (None until the first update_progress call)
    progress: Optional[TaskProgress] = None

    # Result of the work on success; error text on failure
    result: Optional[Any] = None
    error: Optional[str] = None

    # Metadata / lifecycle timestamps (naive local time via datetime.now)
    created_at: datetime = Field(default_factory=datetime.now)
    started_at: Optional[datetime] = None
    completed_at: Optional[datetime] = None

    # Request parameters (for reference)
    request_params: Optional[dict] = None

    class Config:
        # Serialize datetimes as ISO-8601 strings in JSON responses.
        # NOTE(review): json_encoders is a Pydantic v1 mechanism (deprecated
        # in v2) — confirm the project's pydantic version.
        json_encoders = {
            datetime: lambda v: v.isoformat()
        }
|
||||
|
||||
Reference in New Issue
Block a user