# File viewer listing: 300 lines, 9.5 KiB, Python

# Copyright (C) 2025 AIDC-AI
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Publish Task Manager - Background task queue for video publishing.
Features:
- Async task queue with configurable workers
- Task persistence (in-memory, Redis optional)
- Progress tracking and callbacks
- Retry logic for failed tasks
"""
import asyncio
import uuid
from datetime import datetime
from typing import Optional, Dict, List, Callable
from dataclasses import dataclass, field
from enum import Enum
from loguru import logger
from pixelle_video.services.publishing import (
Publisher,
Platform,
PublishStatus,
VideoMetadata,
PublishResult,
PublishTask,
)
from pixelle_video.services.publishing.export_publisher import ExportPublisher
from pixelle_video.services.publishing.bilibili_publisher import BilibiliPublisher
from pixelle_video.services.publishing.youtube_publisher import YouTubePublisher
class TaskPriority(Enum):
    """Priority levels that can be attached to a queued publish task.

    The numeric values increase with urgency: LOW < NORMAL < HIGH.
    """

    # Least urgent level.
    LOW = 0

    # Default level used when the caller does not specify one.
    NORMAL = 1

    # Most urgent level.
    HIGH = 2
@dataclass
class QueuedTask:
    """A publish task bundled with the bookkeeping the queue keeps for it."""

    # --- What to publish and how urgently -------------------------------
    # The underlying publish job (video path, platform, metadata, status).
    task: PublishTask
    # Priority requested by the caller when the task was queued.
    priority: TaskPriority = TaskPriority.NORMAL

    # --- Retry bookkeeping ----------------------------------------------
    # Number of retries performed so far.
    retries: int = 0
    # Upper bound on `retries`; once reached, a failure becomes final.
    max_retries: int = 3
    # Pause before a retry, in seconds (passed to `asyncio.sleep`).
    retry_delay: float = 5.0

    # --- Timestamps (naive local time from `datetime.now`) --------------
    # Set when this record is created.
    created_at: datetime = field(default_factory=datetime.now)
    # Set when a worker picks the task up; None until then.
    started_at: Optional[datetime] = None
    # Set when the task reaches a final outcome; None until then.
    completed_at: Optional[datetime] = None

    # --- Progress reporting ---------------------------------------------
    # Latest progress value reported by the publisher.
    # NOTE(review): the scale (0-1 vs 0-100) is defined by the publishers — confirm.
    progress: float = 0.0
    # Latest human-readable progress message reported by the publisher.
    progress_message: str = ""
class PublishTaskManager:
    """
    Manages background publishing tasks with an async priority queue.

    Tasks are processed by a fixed pool of worker coroutines. Higher
    ``TaskPriority`` values are dequeued first; tasks of equal priority are
    processed in the order they were queued. A task whose publisher reports
    failure is re-queued (after ``retry_delay`` seconds) until ``max_retries``
    is exhausted.

    Usage:
        manager = PublishTaskManager()
        await manager.start()
        task_id = await manager.enqueue(
            video_path="/path/to/video.mp4",
            platform=Platform.BILIBILI,
            metadata=VideoMetadata(title="My Video")
        )
        status = manager.get_task(task_id)
    """

    # Statuses that count as "currently being processed".
    _ACTIVE_STATUSES = (
        PublishStatus.CONVERTING,
        PublishStatus.UPLOADING,
        PublishStatus.PROCESSING,
    )

    def __init__(
        self,
        max_workers: int = 3,
        max_queue_size: int = 100,
    ):
        """
        Create a manager; no workers run until ``start()`` is awaited.

        Args:
            max_workers: Number of worker coroutines started by ``start()``.
            max_queue_size: Maximum number of tasks waiting in the queue;
                ``enqueue()`` waits while the queue is full.
        """
        self.max_workers = max_workers
        self.max_queue_size = max_queue_size

        # Task storage
        self._tasks: Dict[str, QueuedTask] = {}
        # Created in start() so it is built inside the running event loop.
        self._queue: Optional[asyncio.PriorityQueue] = None
        # Monotonic counter used as a FIFO tie-breaker inside the priority queue.
        self._sequence: int = 0
        self._workers: List[asyncio.Task] = []
        # Background "sleep, then re-queue" tasks; referenced here so they are
        # not garbage-collected and so stop() can cancel them.
        self._retry_tasks: set = set()
        self._running = False

        # Publishers
        self._publishers: Dict[Platform, Publisher] = {
            Platform.EXPORT: ExportPublisher(),
            Platform.BILIBILI: BilibiliPublisher(),
            Platform.YOUTUBE: YouTubePublisher(),
        }

        # Callbacks
        self._on_complete: Optional[Callable] = None
        self._on_progress: Optional[Callable] = None

    async def start(self):
        """Start the task manager and workers. Does nothing if already running."""
        if self._running:
            return
        # A fresh queue is created on every start; entries left in a previous
        # queue are not carried over.
        self._queue = asyncio.PriorityQueue(maxsize=self.max_queue_size)
        self._running = True

        # Start worker tasks
        for i in range(self.max_workers):
            self._workers.append(asyncio.create_task(self._worker(i)))
        logger.info(f"✅ Publish task manager started with {self.max_workers} workers")

    async def stop(self):
        """
        Stop all workers and cancel pending retry timers.

        Tasks still waiting in the queue are left unprocessed; they remain
        visible through ``get_task()`` / ``get_all_tasks()``.
        """
        self._running = False

        # Cancel workers and any scheduled re-queues, then wait for them to exit.
        to_cancel = [*self._workers, *self._retry_tasks]
        for pending in to_cancel:
            pending.cancel()
        await asyncio.gather(*to_cancel, return_exceptions=True)

        self._workers.clear()
        self._retry_tasks.clear()
        logger.info("✅ Publish task manager stopped")

    async def enqueue(
        self,
        video_path: str,
        platform: Platform,
        metadata: VideoMetadata,
        priority: TaskPriority = TaskPriority.NORMAL,
    ) -> str:
        """
        Add a publish task to the queue.

        Waits if the queue is currently full.

        Args:
            video_path: Path of the video to publish.
            platform: Target platform; selects the publisher.
            metadata: Video metadata handed to the publisher.
            priority: Scheduling priority; higher values are processed first.

        Returns:
            Task ID for tracking

        Raises:
            RuntimeError: If ``start()`` has not been awaited yet.
        """
        if self._queue is None:
            raise RuntimeError(
                "PublishTaskManager is not started; await start() before enqueue()"
            )

        # Short IDs are only 8 hex chars, so guard against overwriting a
        # tracked task on the (rare) collision.
        task_id = str(uuid.uuid4())[:8]
        while task_id in self._tasks:
            task_id = str(uuid.uuid4())[:8]

        task = PublishTask(
            id=task_id,
            video_path=video_path,
            platform=platform,
            metadata=metadata,
            status=PublishStatus.PENDING,
        )
        queued_task = QueuedTask(task=task, priority=priority)
        self._tasks[task_id] = queued_task
        await self._queue.put(self._queue_entry(queued_task))
        logger.info(f"📥 Queued task {task_id}: {metadata.title} → {platform.value}")
        return task_id

    def get_task(self, task_id: str) -> Optional[QueuedTask]:
        """Get task by ID, or None if the ID is unknown."""
        return self._tasks.get(task_id)

    def get_all_tasks(self) -> List[QueuedTask]:
        """Get all tasks known to this manager, in insertion order."""
        return list(self._tasks.values())

    def get_pending_tasks(self) -> List[QueuedTask]:
        """Get tasks whose status is PENDING (queued or waiting for a retry)."""
        return [
            queued
            for queued in self._tasks.values()
            if queued.task.status == PublishStatus.PENDING
        ]

    def get_active_tasks(self) -> List[QueuedTask]:
        """Get tasks whose status is CONVERTING, UPLOADING or PROCESSING."""
        return [
            queued
            for queued in self._tasks.values()
            if queued.task.status in self._ACTIVE_STATUSES
        ]

    def set_on_complete(self, callback: Callable):
        """
        Set callback for task completion.

        The callback is called synchronously as ``callback(task_id, result)``
        once a task reaches a final outcome (success or final failure).
        """
        self._on_complete = callback

    def set_on_progress(self, callback: Callable):
        """
        Set callback for progress updates.

        The callback is called synchronously as
        ``callback(task_id, progress, message)``.
        """
        self._on_progress = callback

    def _queue_entry(self, queued_task: QueuedTask) -> tuple:
        """Wrap a task as a priority-queue entry: (negated priority, sequence, task)."""
        # PriorityQueue pops the smallest entry first, so the priority value is
        # negated to make HIGH come out before LOW. The unique sequence number
        # keeps FIFO order within a priority and ensures tuple comparison never
        # reaches the (non-orderable) QueuedTask.
        self._sequence += 1
        return (-queued_task.priority.value, self._sequence, queued_task)

    async def _worker(self, worker_id: int):
        """Worker coroutine that processes tasks from queue until stopped."""
        logger.debug(f"Worker {worker_id} started")
        while self._running:
            try:
                # Poll with a timeout so the loop re-checks self._running.
                try:
                    entry = await asyncio.wait_for(self._queue.get(), timeout=1.0)
                except asyncio.TimeoutError:
                    continue

                try:
                    await self._process_task(entry[-1], worker_id)
                finally:
                    # Always balance the get(), even when processing raised or
                    # was cancelled, so Queue.join() accounting stays correct.
                    self._queue.task_done()
            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Worker {worker_id} error: {e}")

    async def _process_task(self, queued_task: QueuedTask, worker_id: int):
        """
        Process a single publish task.

        Runs the platform's publisher and records the result on the task. A
        failed result is scheduled for retry while retries remain; otherwise
        the task is finalised and the completion callback is invoked.
        """
        task = queued_task.task
        task_id = task.id
        logger.info(f"🔄 Worker {worker_id} processing task {task_id}")
        queued_task.started_at = datetime.now()
        task.status = PublishStatus.UPLOADING

        publisher = self._publishers.get(task.platform)
        if publisher is None:
            logger.warning(f"❌ Task {task_id} failed: no publisher for {task.platform}")
            self._mark_failed(task, f"No publisher for platform: {task.platform}")
            # Finalise like any other failure so completed_at is set and the
            # completion callback still fires.
            self._finalize(queued_task)
            return

        # Progress callback
        def progress_callback(progress: float, message: str):
            queued_task.progress = progress
            queued_task.progress_message = message
            if self._on_progress:
                self._on_progress(task_id, progress, message)

        try:
            # Execute publish
            result = await publisher.publish(
                task.video_path,
                task.metadata,
                progress_callback=progress_callback,
            )
            task.result = result
            task.status = result.status
            if result.success:
                logger.info(f"✅ Task {task_id} completed: {result.video_url or result.export_path}")
            else:
                logger.warning(f"❌ Task {task_id} failed: {result.error_message}")
                # Retry if applicable
                if queued_task.retries < queued_task.max_retries:
                    queued_task.retries += 1
                    task.status = PublishStatus.PENDING
                    logger.info(f"🔄 Retrying task {task_id} ({queued_task.retries}/{queued_task.max_retries})")
                    self._schedule_retry(queued_task)
                    return
        except Exception as e:
            # An exception from the publisher is a final failure (no retry).
            logger.error(f"Task {task_id} exception: {e}")
            self._mark_failed(task, str(e))

        self._finalize(queued_task)

    @staticmethod
    def _mark_failed(task: PublishTask, error_message: str) -> None:
        """Set the task to FAILED with a failure result carrying error_message."""
        task.status = PublishStatus.FAILED
        task.result = PublishResult(
            success=False,
            platform=task.platform,
            status=PublishStatus.FAILED,
            error_message=error_message,
        )

    def _finalize(self, queued_task: QueuedTask) -> None:
        """Stamp completion times and invoke the completion callback, if set."""
        task = queued_task.task
        queued_task.completed_at = datetime.now()
        task.updated_at = datetime.now()
        # Call completion callback
        if self._on_complete:
            self._on_complete(task.id, task.result)

    def _schedule_retry(self, queued_task: QueuedTask) -> None:
        """Re-queue the task after its retry delay without blocking a worker."""
        # Waiting and re-queuing inside the worker would stall it for the whole
        # delay and, with a full bounded queue, could leave every worker blocked
        # in put() with nobody left to drain the queue.
        retry = asyncio.create_task(self._requeue_after_delay(queued_task))
        self._retry_tasks.add(retry)
        retry.add_done_callback(self._retry_tasks.discard)

    async def _requeue_after_delay(self, queued_task: QueuedTask) -> None:
        """Sleep for the task's retry delay, then put it back on the queue."""
        await asyncio.sleep(queued_task.retry_delay)
        await self._queue.put(self._queue_entry(queued_task))
# Process-wide manager instance; stays None until first requested.
_publish_manager: Optional[PublishTaskManager] = None


def get_publish_manager() -> PublishTaskManager:
    """Return the shared PublishTaskManager, building it on first use.

    The instance is constructed with default arguments and is not started
    here.
    """
    global _publish_manager
    if _publish_manager is not None:
        return _publish_manager
    _publish_manager = PublishTaskManager()
    return _publish_manager