||
- """
- MinerU Tianshu - SQLite Task Database Manager
- 天枢任务数据库管理器
- 负责任务的持久化存储、状态管理和原子性操作
- """
- import sqlite3
- import json
- import uuid
- from contextlib import contextmanager
- from typing import Optional, List, Dict
- from pathlib import Path
class TaskDB:
    """SQLite-backed store for the task queue.

    Persists tasks in a single ``tasks`` table and provides the operations
    needed to run it as a queue: creating tasks, atomically claiming the next
    pending task for a worker, updating status, querying, and housekeeping
    (purging old finished tasks, re-queueing stale ones).

    No connection is kept on the instance; every operation opens its own
    connection through :meth:`get_cursor` and closes it when done.
    """

    def __init__(self, db_path='mineru_tianshu.db'):
        """Remember the database location and make sure the schema exists.

        Args:
            db_path: Path of the SQLite database file.
        """
        self.db_path = db_path
        self._init_db()

    def _get_conn(self):
        """Open and return a new SQLite connection.

        A fresh connection is created on every call instead of being cached
        on the instance (the original author notes this avoids pickling
        problems). Rows are returned as ``sqlite3.Row`` so columns can be
        read by name.
        """
        conn = sqlite3.connect(
            self.db_path,
            # Allow the connection to be used from a thread other than the
            # one that created it.
            check_same_thread=False,
            # Wait up to 30 seconds for a database lock before raising.
            timeout=30.0
        )
        conn.row_factory = sqlite3.Row
        return conn

    @contextmanager
    def get_cursor(self):
        """Yield a cursor on a fresh connection, with commit/rollback handling.

        The transaction is committed when the ``with`` body finishes normally.
        If the body (or the commit) raises an ``Exception``, the transaction
        is rolled back and the exception propagates unchanged. The connection
        is always closed on exit.

        Yields:
            sqlite3.Cursor: Cursor bound to the newly opened connection.
        """
        conn = self._get_conn()
        cursor = conn.cursor()
        try:
            yield cursor
            conn.commit()
        except Exception:
            conn.rollback()
            # Bare ``raise`` re-raises the active exception as-is.
            raise
        finally:
            conn.close()

    def _init_db(self):
        """Create the ``tasks`` table and its indexes if they do not exist."""
        with self.get_cursor() as cursor:
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS tasks (
                    task_id TEXT PRIMARY KEY,
                    file_name TEXT NOT NULL,
                    file_path TEXT,
                    status TEXT DEFAULT 'pending',
                    priority INTEGER DEFAULT 0,
                    backend TEXT DEFAULT 'pipeline',
                    options TEXT,
                    result_path TEXT,
                    error_message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    started_at TIMESTAMP,
                    completed_at TIMESTAMP,
                    worker_id TEXT,
                    retry_count INTEGER DEFAULT 0
                )
            ''')

            # Indexes to speed up the lookups used by the queue operations.
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_priority ON tasks(priority DESC)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_created_at ON tasks(created_at)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_worker_id ON tasks(worker_id)')

    def create_task(self, file_name: str, file_path: str,
                    backend: str = 'pipeline', options: Optional[dict] = None,
                    priority: int = 0) -> str:
        """Insert a new task and return its id.

        The task starts in the table's default status (``'pending'``).

        Args:
            file_name: Name of the file to process.
            file_path: Path of the file to process.
            backend: Processing backend
                (pipeline/vlm-transformers/vlm-vllm-engine).
            options: Processing options; stored as a JSON string. ``None``
                (or any falsy value) is stored as ``{}``.
            priority: Task priority; larger numbers are served first.

        Returns:
            The generated task id (a UUID4 string).
        """
        task_id = str(uuid.uuid4())
        with self.get_cursor() as cursor:
            cursor.execute('''
                INSERT INTO tasks (task_id, file_name, file_path, backend, options, priority)
                VALUES (?, ?, ?, ?, ?, ?)
            ''', (task_id, file_name, file_path, backend, json.dumps(options or {}), priority))
        return task_id

    def get_next_task(self, worker_id: str) -> Optional[Dict]:
        """Atomically claim the next pending task for a worker.

        The highest-priority pending task is selected (oldest first within a
        priority) and marked ``'processing'`` in the same transaction, so two
        workers cannot claim the same task.

        Args:
            worker_id: Identifier of the worker claiming the task.

        Returns:
            The claimed task as a dict reflecting its state after the claim
            (``status='processing'``, ``worker_id`` and ``started_at`` set),
            or ``None`` when no task is pending.
        """
        with self.get_cursor() as cursor:
            # Take the write lock up front so the SELECT and the UPDATE below
            # form one atomic unit with respect to other connections.
            cursor.execute('BEGIN IMMEDIATE')

            # created_at only has one-second resolution, so rowid is used as
            # a tie-breaker to keep insertion (FIFO) order deterministic.
            cursor.execute('''
                SELECT task_id FROM tasks
                WHERE status = 'pending'
                ORDER BY priority DESC, created_at ASC, rowid ASC
                LIMIT 1
            ''')
            candidate = cursor.fetchone()
            if candidate is None:
                return None

            task_id = candidate['task_id']
            cursor.execute('''
                UPDATE tasks
                SET status = 'processing',
                    started_at = CURRENT_TIMESTAMP,
                    worker_id = ?
                WHERE task_id = ?
            ''', (worker_id, task_id))

            # Re-read the row so the caller sees the post-claim values rather
            # than the stale 'pending' snapshot.
            cursor.execute('SELECT * FROM tasks WHERE task_id = ?', (task_id,))
            return dict(cursor.fetchone())

    def update_task_status(self, task_id: str, status: str,
                           result_path: str = None, error_message: str = None):
        """Set a task's status, recording completion details where relevant.

        ``completed_at`` is stamped whenever the new status is ``'completed'``
        or ``'failed'``. ``result_path`` is stored only for ``'completed'``
        and ``error_message`` only for ``'failed'``, and each only when a
        non-empty value is given.

        Args:
            task_id: Id of the task to update.
            status: New status
                (pending/processing/completed/failed/cancelled).
            result_path: Location of the result (optional).
            error_message: Failure description (optional).
        """
        updates = ['status = ?']
        params = [status]

        if status == 'completed':
            updates.append('completed_at = CURRENT_TIMESTAMP')
            if result_path:
                updates.append('result_path = ?')
                params.append(result_path)
        elif status == 'failed':
            # Always stamp the finish time, even without an error message;
            # cleanup_old_tasks relies on completed_at to age out the row.
            updates.append('completed_at = CURRENT_TIMESTAMP')
            if error_message:
                updates.append('error_message = ?')
                params.append(error_message)

        params.append(task_id)
        with self.get_cursor() as cursor:
            # Only fixed column fragments are interpolated into the SQL text;
            # every value is passed as a bound parameter.
            cursor.execute(f'''
                UPDATE tasks SET {', '.join(updates)}
                WHERE task_id = ?
            ''', params)

    def get_task(self, task_id: str) -> Optional[Dict]:
        """Fetch a single task by id.

        Args:
            task_id: Id of the task to look up.

        Returns:
            The task row as a dict, or ``None`` if no such task exists.
        """
        with self.get_cursor() as cursor:
            cursor.execute('SELECT * FROM tasks WHERE task_id = ?', (task_id,))
            task = cursor.fetchone()
            return dict(task) if task else None

    def get_queue_stats(self) -> Dict[str, int]:
        """Count tasks per status.

        Returns:
            Mapping of status to number of tasks. Statuses with no tasks are
            absent from the mapping.
        """
        with self.get_cursor() as cursor:
            cursor.execute('''
                SELECT status, COUNT(*) as count
                FROM tasks
                GROUP BY status
            ''')
            return {row['status']: row['count'] for row in cursor.fetchall()}

    def get_tasks_by_status(self, status: str, limit: int = 100) -> List[Dict]:
        """List tasks in a given status, newest first.

        Args:
            status: Status to filter on.
            limit: Maximum number of tasks to return.

        Returns:
            Task rows as dicts, ordered by ``created_at`` descending.
        """
        with self.get_cursor() as cursor:
            cursor.execute('''
                SELECT * FROM tasks
                WHERE status = ?
                ORDER BY created_at DESC
                LIMIT ?
            ''', (status, limit))
            return [dict(row) for row in cursor.fetchall()]

    def cleanup_old_tasks(self, days: int = 7):
        """Delete finished tasks older than the retention window.

        Only tasks whose status is ``'completed'`` or ``'failed'`` and whose
        ``completed_at`` is more than ``days`` days in the past are removed.

        Args:
            days: Number of most recent days of finished tasks to keep.

        Returns:
            Number of rows deleted.
        """
        with self.get_cursor() as cursor:
            cursor.execute('''
                DELETE FROM tasks
                WHERE completed_at < datetime('now', '-' || ? || ' days')
                AND status IN ('completed', 'failed')
            ''', (days,))
            return cursor.rowcount

    def reset_stale_tasks(self, timeout_minutes: int = 60):
        """Return timed-out ``'processing'`` tasks to the pending queue.

        A task is considered stale when it has been in ``'processing'`` since
        more than ``timeout_minutes`` minutes ago. Stale tasks get their
        worker assignment cleared and their ``retry_count`` incremented.

        Args:
            timeout_minutes: Processing timeout in minutes.

        Returns:
            Number of tasks reset to ``'pending'``.
        """
        with self.get_cursor() as cursor:
            cursor.execute('''
                UPDATE tasks
                SET status = 'pending',
                    worker_id = NULL,
                    retry_count = retry_count + 1
                WHERE status = 'processing'
                AND started_at < datetime('now', '-' || ? || ' minutes')
            ''', (timeout_minutes,))
            return cursor.rowcount
if __name__ == '__main__':
    # Manual smoke test: create one task, read it back, print the queue
    # statistics, then remove the throwaway database file.
    scratch_db = 'test_tianshu.db'
    db = TaskDB(scratch_db)

    # Insert a sample task.
    sample_options = {'lang': 'ch', 'formula_enable': True}
    task_id = db.create_task(
        'test.pdf',
        '/tmp/test.pdf',
        backend='pipeline',
        options=sample_options,
        priority=1,
    )
    print(f"Created task: {task_id}")

    # Look the task up by its id.
    task = db.get_task(task_id)
    print(f"Task details: {task}")

    # Report the per-status task counts.
    stats = db.get_queue_stats()
    print(f"Queue stats: {stats}")

    # Delete the test database file.
    Path(scratch_db).unlink(missing_ok=True)
    print("Test completed!")
|