282 lines
9.5 KiB
Python
282 lines
9.5 KiB
Python
# Copyright (C) 2025 AIDC-AI
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
"""
|
|
Task Center Page - View and manage video generation tasks
|
|
|
|
Features:
|
|
- View all pending/running/completed tasks
|
|
- Real-time progress updates (polling)
|
|
- Jump to editor for completed tasks
|
|
"""
|
|
|
|
import html
import time
from datetime import datetime
from urllib.parse import quote

import requests
import streamlit as st

from web.i18n import tr, get_language
|
|
|
|
# Page-level settings: the tab title follows the active UI language, and the
# layout uses the full browser width.
_PAGE_TITLE = "任务中心" if get_language() == "zh_CN" else "Task Center"
st.set_page_config(page_title=_PAGE_TITLE, page_icon="📋", layout="wide")

# Base URL of the backend REST API that serves task data.
API_BASE = "http://localhost:8000/api"
|
|
|
|
|
|
def get_all_tasks():
    """Fetch every task from the backend ``GET {API_BASE}/tasks`` endpoint.

    The endpoint may return either a bare JSON list of task dicts or an
    object whose ``"tasks"`` key holds that list; both shapes are accepted.

    Returns:
        A list of task dicts. On failure (network error, non-200 response,
        undecodable JSON) an error is shown on the page and an empty list is
        returned, so the caller can keep rendering. A payload of any other
        shape (including a null ``"tasks"`` value) is treated as no tasks.
    """
    try:
        response = requests.get(f"{API_BASE}/tasks", timeout=5)
    except requests.RequestException as e:
        st.error(f"无法连接到 API: {e}")
        return []

    if response.status_code != 200:
        # Surface HTTP errors instead of presenting them as an empty task list.
        st.error(f"获取任务列表失败: HTTP {response.status_code}")
        return []

    try:
        data = response.json()
    except ValueError as e:
        st.error(f"无法解析 API 响应: {e}")
        return []

    if isinstance(data, dict):
        data = data.get("tasks")
    return data if isinstance(data, list) else []
|
|
|
|
|
|
def get_task_status(task_id: str):
    """Fetch a single task from ``GET {API_BASE}/tasks/{task_id}``.

    Args:
        task_id: Identifier of the task to look up. It is percent-encoded so
            characters such as ``/`` or ``?`` cannot change the request path.

    Returns:
        The decoded JSON body on HTTP 200, otherwise ``None``. Request/transport
        errors and undecodable bodies also yield ``None``. Other exceptions
        (including KeyboardInterrupt) propagate.
    """
    url = f"{API_BASE}/tasks/{quote(str(task_id), safe='')}"
    try:
        response = requests.get(url, timeout=5)
        if response.status_code != 200:
            return None
        return response.json()
    except (requests.RequestException, ValueError):
        # Best effort: an unreachable API or a non-JSON body means "no status".
        return None
|
|
|
|
|
|
def format_time(iso_string):
    """Format an ISO-8601 timestamp as ``YYYY-MM-DD HH:MM:SS`` for display.

    A trailing ``Z`` is accepted as UTC. The clock time is shown as written
    in the string; no timezone conversion is performed.

    Args:
        iso_string: ISO-8601 timestamp string, or a falsy value.

    Returns:
        ``"-"`` for falsy input, the formatted timestamp when it parses, and
        the input unchanged when it cannot be parsed.
    """
    if not iso_string:
        return "-"
    try:
        dt = datetime.fromisoformat(iso_string.replace("Z", "+00:00"))
    except (AttributeError, TypeError, ValueError):
        # Not a string, or not ISO-8601: show the raw value rather than failing.
        return iso_string
    return dt.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
|
|
# Emoji shown in front of each task title, keyed by task status.
_STATUS_EMOJI = {
    "pending": "🟡",
    "running": "🔵",
    "completed": "🟢",
    "failed": "🔴",
    "cancelled": "⚪",
}

# Storyboard editor page that completed tasks link to.
_EDITOR_BASE_URL = "http://localhost:3000/editor"


def _progress_fraction(percentage):
    """Convert a 0-100 percentage into a 0.0-1.0 fraction for ``st.progress``.

    ``None``, non-numeric and NaN values map to 0.0; values outside 0-100 are
    clamped, because ``st.progress`` rejects out-of-range floats.
    """
    try:
        value = float(percentage)
    except (TypeError, ValueError):
        return 0.0
    if value != value:  # NaN is the only float not equal to itself
        return 0.0
    return min(max(value, 0.0), 100.0) / 100


def _duration_text(start, end):
    """Return the elapsed time between two ISO-8601 timestamps as e.g. ``"12.3s"``.

    Returns ``"-"`` when either timestamp is missing or unparsable, or when
    one is timezone-aware and the other naive (they cannot be subtracted).
    """
    if not (start and end):
        return "-"
    try:
        started = datetime.fromisoformat(start.replace("Z", "+00:00"))
        finished = datetime.fromisoformat(end.replace("Z", "+00:00"))
        return f"{(finished - started).total_seconds():.1f}s"
    except (AttributeError, TypeError, ValueError):
        return "-"


def _truncate(text, limit=100):
    """Return ``text`` cut to ``limit`` characters, adding "..." only if it was cut."""
    text = "" if text is None else str(text)
    return text if len(text) <= limit else text[:limit] + "..."


def _editor_url(task_id):
    """Build the editor URL for a task, safe to embed in an HTML attribute.

    The task id comes from the API and ends up inside raw HTML rendered with
    ``unsafe_allow_html=True``, so it is percent-encoded and the final URL is
    HTML-escaped.
    """
    url = f"{_EDITOR_BASE_URL}?storyboard_id={quote(str(task_id), safe='')}"
    return html.escape(url, quote=True)


def _render_details(task, status, params, result):
    """Render the body of the "execution details" expander for one task."""
    started_at = task.get("started_at")
    completed_at = task.get("completed_at")

    # Timing
    t_col1, t_col2, t_col3 = st.columns(3)
    with t_col1:
        st.write(f"⏱️ **开始**: {format_time(started_at)}")
    with t_col2:
        st.write(f"🏁 **结束**: {format_time(completed_at)}")
    with t_col3:
        st.write(f"⏳ **用时**: {_duration_text(started_at, completed_at)}")

    st.divider()

    # Input parameters
    st.markdown("**📝 输入参数**")
    p_col1, p_col2 = st.columns(2)
    with p_col1:
        st.write(f"- **文本**: {_truncate(params.get('text'))}")
        st.write(f"- **场景数**: {params.get('n_scenes', 5)}")
    with p_col2:
        st.write(f"- **模板**: {params.get('frame_template', '默认')}")
        st.write(f"- **TTS 模式**: {params.get('tts_inference_mode', 'local')}")

    # Execution log, newest entry first
    logs = task.get("logs") or []
    if logs:
        st.divider()
        st.markdown("**📜 执行步骤**")
        for log in reversed(logs):
            t = format_time(log.get("timestamp"))
            m = log.get("message", "")
            p = log.get("percentage", 0)
            st.write(f"`{t}` | **{p}%** | {m}")

    # Result summary
    if status == "completed" and result:
        st.divider()
        st.markdown("**📊 结果详情**")
        r_col1, r_col2, r_col3 = st.columns(3)
        with r_col1:
            duration = result.get("duration") or 0
            st.write(f"📐 **视频长度**: {duration:.1f}s")
        with r_col2:
            size_mb = (result.get("file_size") or 0) / (1024 * 1024)
            st.write(f"📦 **文件大小**: {size_mb:.2f} MB")
        with r_col3:
            st.write(f"🔗 [查看视频]({result.get('video_url') or '#'})")


def _render_completed_actions(task_id):
    """Render the success banner and the "open in editor" button for a completed task."""
    st.markdown("---")
    col_a, col_b = st.columns(2)
    with col_a:
        st.success("✨ 视频生成成功")
    with col_b:
        st.markdown(
            f'<a href="{_editor_url(task_id)}" target="_blank" style="text-decoration: none;">'
            '<button style="width: 100%; padding: 0.5rem 1rem; background-color: #262730; '
            'color: white; border: 1px solid #262730; border-radius: 0.5rem; cursor: pointer;">'
            "✏️ 在编辑器中打开"
            "</button></a>",
            unsafe_allow_html=True,
        )


def render_task_card(task):
    """Render one task as a bordered Streamlit card.

    The card shows a header (title, status, creation time), a progress bar
    while the task is running, an expander with timing, parameters, logs and
    result details, and status-specific footers: an editor link when the task
    completed, the error message when it failed.

    Args:
        task: Task dict as returned by the ``/tasks`` API. Keys read here:
            ``task_id``, ``status``, ``progress``, ``result``,
            ``request_params``, ``created_at``, ``started_at``,
            ``completed_at``, ``logs`` and ``error``. Any of them may be
            missing or ``None``.
    """
    task_id = str(task.get("task_id") or "unknown")
    status = task.get("status") or "unknown"
    # `or {}` also covers keys that are present but explicitly null.
    progress = task.get("progress") or {}
    result = task.get("result") or {}
    params = task.get("request_params") or {}

    status_emoji = _STATUS_EMOJI.get(status, "⚪")

    with st.container(border=True):
        # Header row
        col1, col2, col3 = st.columns([3, 1, 1.5])
        with col1:
            title = params.get("title") or f"任务 {task_id[:8]}"
            st.markdown(f"### {status_emoji} {title}")
            st.caption(f"ID: `{task_id}`")
        with col2:
            st.markdown(f"**状态**\n\n{status}")
        with col3:
            created_at = format_time(task.get("created_at"))
            st.markdown(f"**创建时间**\n\n{created_at}")

        # Progress bar for running tasks
        if status == "running" and progress:
            st.progress(
                _progress_fraction(progress.get("percentage", 0)),
                text=progress.get("message", ""),
            )

        with st.expander("🔍 查看执行细节"):
            _render_details(task, status, params, result)

        if status == "completed":
            _render_completed_actions(task_id)

        if status == "failed":
            error = task.get("error") or "未知错误"
            st.error(f"❌ 执行失败: {error}")
|
|
|
|
|
|
# Seconds to wait before each automatic rerun when auto-refresh is enabled.
_AUTO_REFRESH_SECONDS = 3

# Status filter choices; "全部" ("all") disables filtering. Covers every status
# the task cards know how to render, including "cancelled".
_STATUS_FILTER_OPTIONS = ["全部", "running", "completed", "failed", "pending", "cancelled"]


def _count_status(tasks, status):
    """Return how many task dicts in ``tasks`` have the given ``status``."""
    return sum(1 for t in tasks if t.get("status") == status)


def main():
    """Render the task center page: toolbar, summary metrics and task cards.

    Tasks are fetched once per run. The summary metrics always describe the
    full task list, while the status filter only narrows the cards shown.
    When auto-refresh is on, the script sleeps ``_AUTO_REFRESH_SECONDS`` and
    calls ``st.rerun()``, which re-executes the page and re-fetches tasks.
    """
    st.title("📋 任务中心")
    st.caption("查看和管理视频生成任务")

    # Fetch first: the running count drives the auto-refresh default below.
    all_tasks = get_all_tasks() or []
    running_count = _count_status(all_tasks, "running")

    col1, col2, col3 = st.columns([2, 1, 1])

    with col1:
        label = f"🔄 自动刷新 ({_AUTO_REFRESH_SECONDS}秒)"
        if running_count > 0:
            label += f" - {running_count}个任务运行中"
        auto_refresh = st.checkbox(
            label,
            value=running_count > 0,  # default on while something is running
            help="有运行中任务时建议开启",
        )

    with col2:
        if st.button("🔄 刷新", use_container_width=True):
            st.rerun()

    with col3:
        filter_status = st.selectbox(
            "筛选状态",
            _STATUS_FILTER_OPTIONS,
            label_visibility="collapsed",
        )

    st.markdown("---")

    if not all_tasks:
        st.info("📭 暂无任务")
        st.caption("在首页生成视频后,任务会显示在这里")
    else:
        # Stats row: computed over all tasks so every metric uses the same set.
        m1, m2, m3, m4 = st.columns(4)
        m1.metric("总任务", len(all_tasks))
        m2.metric("🔵 运行中", running_count)
        m3.metric("🟢 已完成", _count_status(all_tasks, "completed"))
        m4.metric("🔴 失败", _count_status(all_tasks, "failed"))

        st.markdown("---")

        if filter_status == "全部":
            tasks = all_tasks
        else:
            tasks = [t for t in all_tasks if t.get("status") == filter_status]

        # Newest first; `or ""` keeps tasks with a null created_at sortable.
        tasks = sorted(
            tasks,
            key=lambda t: str(t.get("created_at") or ""),
            reverse=True,
        )

        if not tasks:
            st.info("📭 没有符合筛选条件的任务")

        for task in tasks:
            render_task_card(task)

    if auto_refresh:
        time.sleep(_AUTO_REFRESH_SECONDS)
        st.rerun()


if __name__ == "__main__":
    main()
|