"""
Task Manager

In-memory task management for video generation jobs.
"""
|
|
|
|
import asyncio
|
|
import uuid
|
|
from datetime import datetime, timedelta
|
|
from typing import Dict, List, Optional, Callable
|
|
from loguru import logger
|
|
|
|
from api.tasks.models import Task, TaskStatus, TaskType, TaskProgress
|
|
from api.config import api_config
|
|
|
|
|
|
class TaskManager:
    """
    Task manager for handling async video generation tasks

    Features:
    - In-memory storage (can be replaced with Redis later)
    - Task lifecycle management
    - Progress tracking
    - Auto cleanup of old tasks

    Lifecycle: ``create_task`` registers a PENDING task; ``execute_task``
    schedules its coroutine on the running event loop, which moves the task
    to RUNNING and then to COMPLETED, FAILED or CANCELLED. A background loop
    started by ``start`` periodically drops finished tasks whose
    ``completed_at`` is older than ``api_config.task_retention_time`` seconds.
    """

    # Statuses a task never leaves once reached. Tasks in these states can no
    # longer be cancelled and become eligible for cleanup.
    _TERMINAL_STATUSES = frozenset(
        {TaskStatus.COMPLETED, TaskStatus.FAILED, TaskStatus.CANCELLED}
    )

    def __init__(self):
        # task_id -> Task record
        self._tasks: Dict[str, Task] = {}
        # task_id -> asyncio.Task running the job. Keeping this strong
        # reference prevents a running job from being garbage-collected.
        self._task_futures: Dict[str, asyncio.Task] = {}
        self._cleanup_task: Optional[asyncio.Task] = None
        self._running = False

    async def start(self):
        """Start task manager and cleanup scheduler"""
        if self._running:
            logger.warning("Task manager already running")
            return

        self._running = True
        self._cleanup_task = asyncio.create_task(self._cleanup_loop())
        logger.info("✅ Task manager started")

    async def stop(self):
        """
        Stop task manager and cancel all tasks.

        Cancels the cleanup loop and every unfinished job, waits for the
        cancelled jobs to unwind, then forgets all stored tasks.
        """
        self._running = False

        # Cancel cleanup task
        if self._cleanup_task:
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass
            self._cleanup_task = None

        # Cancel all running tasks
        pending = []
        for task_id, future in self._task_futures.items():
            if not future.done():
                future.cancel()
                pending.append(future)
                logger.info(f"Cancelled task: {task_id}")

        # Wait for the cancellations to be processed, so no job is still
        # executing (or left pending on the loop) after stop() returns.
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)

        self._tasks.clear()
        self._task_futures.clear()
        logger.info("✅ Task manager stopped")

    def create_task(
        self,
        task_type: TaskType,
        request_params: Optional[dict] = None
    ) -> Task:
        """
        Create a new task

        Args:
            task_type: Type of task
            request_params: Original request parameters

        Returns:
            Created task, stored with status PENDING under a fresh UUID4 id
        """
        task_id = str(uuid.uuid4())
        task = Task(
            task_id=task_id,
            task_type=task_type,
            status=TaskStatus.PENDING,
            request_params=request_params,
        )

        self._tasks[task_id] = task
        logger.info(f"Created task {task_id} ({task_type})")
        return task

    async def execute_task(
        self,
        task_id: str,
        coro_func: Callable,
        *args,
        **kwargs
    ):
        """
        Execute task asynchronously

        Schedules ``coro_func(*args, **kwargs)`` as a background asyncio task
        and returns immediately, without waiting for the job to finish. The
        task's status, result, error and timestamps are updated as the job
        progresses.

        Args:
            task_id: Task ID
            coro_func: Async function to execute
            *args: Positional arguments
            **kwargs: Keyword arguments
        """
        task = self._tasks.get(task_id)
        if not task:
            logger.error(f"Task {task_id} not found")
            return

        # A task cancelled while still pending must not be revived.
        if task.status == TaskStatus.CANCELLED:
            logger.warning(f"Task {task_id} was cancelled; not executing")
            return

        # Starting a second run would orphan the first future, and
        # cancel_task could then no longer reach it.
        existing = self._task_futures.get(task_id)
        if existing is not None and not existing.done():
            logger.warning(f"Task {task_id} is already executing")
            return

        # Create async task
        async def _execute():
            try:
                task.status = TaskStatus.RUNNING
                task.started_at = datetime.now()
                logger.info(f"Task {task_id} started")

                # Execute the actual work
                result = await coro_func(*args, **kwargs)

                # Update task with result
                task.status = TaskStatus.COMPLETED
                task.result = result
                task.completed_at = datetime.now()
                logger.info(f"Task {task_id} completed")

            except asyncio.CancelledError:
                # CancelledError is a BaseException (Python 3.8+), so the
                # handler below would not see it. Record a terminal state so
                # the task is not stuck in RUNNING and can be cleaned up.
                task.status = TaskStatus.CANCELLED
                task.completed_at = task.completed_at or datetime.now()
                logger.info(f"Task {task_id} cancelled")
                raise

            except Exception as e:
                task.status = TaskStatus.FAILED
                task.error = str(e)
                task.completed_at = datetime.now()
                logger.exception(f"Task {task_id} failed: {e}")

        # Start execution
        future = asyncio.create_task(_execute())
        self._task_futures[task_id] = future

    def get_task(self, task_id: str) -> Optional[Task]:
        """Get task by ID (None if unknown or already cleaned up)"""
        return self._tasks.get(task_id)

    def list_tasks(
        self,
        status: Optional[TaskStatus] = None,
        limit: int = 100
    ) -> List[Task]:
        """
        List tasks with optional filtering

        Args:
            status: Filter by status (None returns all statuses)
            limit: Maximum number of tasks to return

        Returns:
            List of tasks, newest first (by ``created_at``)
        """
        tasks = list(self._tasks.values())

        if status is not None:
            tasks = [t for t in tasks if t.status == status]

        # Sort by created_at descending
        tasks.sort(key=lambda t: t.created_at, reverse=True)

        return tasks[:limit]

    def update_progress(
        self,
        task_id: str,
        current: int,
        total: int,
        message: str = ""
    ):
        """
        Update task progress

        Unknown task IDs are ignored silently.

        Args:
            task_id: Task ID
            current: Current progress
            total: Total steps; percentage is reported as 0 when total <= 0
            message: Progress message
        """
        task = self._tasks.get(task_id)
        if not task:
            return

        percentage = (current / total * 100) if total > 0 else 0
        task.progress = TaskProgress(
            current=current,
            total=total,
            percentage=percentage,
            message=message
        )

    def cancel_task(self, task_id: str) -> bool:
        """
        Cancel a pending or running task

        Args:
            task_id: Task ID

        Returns:
            True if cancelled, False if the task is unknown or has already
            finished (completed, failed or cancelled)
        """
        task = self._tasks.get(task_id)
        if not task:
            return False

        # Finished tasks keep their final status, result and timestamps.
        if task.status in self._TERMINAL_STATUSES:
            return False

        # Cancel future if running
        future = self._task_futures.get(task_id)
        if future and not future.done():
            future.cancel()

        # Update task status
        task.status = TaskStatus.CANCELLED
        task.completed_at = datetime.now()
        logger.info(f"Cancelled task {task_id}")
        return True

    async def _cleanup_loop(self):
        """Periodically clean up old completed tasks"""
        while self._running:
            try:
                await asyncio.sleep(api_config.task_cleanup_interval)
                self._cleanup_old_tasks()
            except asyncio.CancelledError:
                break
            except Exception as e:
                # Keep the loop alive; one failed sweep must not stop cleanup.
                logger.exception(f"Error in cleanup loop: {e}")

    def _cleanup_old_tasks(self):
        """Remove completed/failed/cancelled tasks older than the retention time"""
        cutoff_time = datetime.now() - timedelta(seconds=api_config.task_retention_time)

        tasks_to_remove = []
        for task_id, task in self._tasks.items():
            if task.status not in self._TERMINAL_STATUSES:
                continue
            if not (task.completed_at and task.completed_at < cutoff_time):
                continue
            # A cancelled job may still be unwinding. Keep our strong
            # reference until it is done so it is not garbage-collected.
            future = self._task_futures.get(task_id)
            if future is not None and not future.done():
                continue
            tasks_to_remove.append(task_id)

        for task_id in tasks_to_remove:
            del self._tasks[task_id]
            self._task_futures.pop(task_id, None)

        if tasks_to_remove:
            logger.info(f"Cleaned up {len(tasks_to_remove)} old tasks")
|
|
|
|
|
|
# Module-level TaskManager instance, shared by every importer of this module.
task_manager: TaskManager = TaskManager()
|
|
|