task_db.py

  1. """
  2. MinerU Tianshu - SQLite Task Database Manager
  3. 天枢任务数据库管理器
  4. 负责任务的持久化存储、状态管理和原子性操作
  5. """
  6. import sqlite3
  7. import json
  8. import uuid
  9. from datetime import datetime
  10. from contextlib import contextmanager
  11. from typing import Optional, List, Dict
  12. from pathlib import Path


class TaskDB:
    """Task database manager."""

    def __init__(self, db_path='mineru_tianshu.db'):
        self.db_path = db_path
        self._init_db()

    def _get_conn(self):
        """Open a new connection on every call (avoids pickling issues when a TaskDB instance is shared across processes)."""
        conn = sqlite3.connect(
            self.db_path,
            check_same_thread=False,
            timeout=30.0
        )
        conn.row_factory = sqlite3.Row
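        # sqlite3.Row lets callers convert rows to plain dicts with dict(row),
        # which get_task() and get_next_task() rely on below.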
        return conn

    @contextmanager
    def get_cursor(self):
        """Context manager with automatic commit, rollback, and cleanup."""
        conn = self._get_conn()
        cursor = conn.cursor()
        try:
            yield cursor
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()  # always close the connection

    def _init_db(self):
        """Initialize the database schema."""
        with self.get_cursor() as cursor:
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS tasks (
                    task_id TEXT PRIMARY KEY,
                    file_name TEXT NOT NULL,
                    file_path TEXT,
                    status TEXT DEFAULT 'pending',
                    priority INTEGER DEFAULT 0,
                    backend TEXT DEFAULT 'pipeline',
                    options TEXT,
                    result_path TEXT,
                    error_message TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    started_at TIMESTAMP,
                    completed_at TIMESTAMP,
                    worker_id TEXT,
                    retry_count INTEGER DEFAULT 0
                )
            ''')
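            # Task lifecycle: pending -> processing -> completed / failed (or cancelled).
            # Transitions are driven by get_next_task() and update_task_status() below.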
            # Indexes to speed up the common queries
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_priority ON tasks(priority DESC)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_created_at ON tasks(created_at)')
            cursor.execute('CREATE INDEX IF NOT EXISTS idx_worker_id ON tasks(worker_id)')

    def create_task(self, file_name: str, file_path: str,
                    backend: str = 'pipeline', options: Optional[dict] = None,
                    priority: int = 0) -> str:
        """
        Create a new task.

        Args:
            file_name: file name
            file_path: file path
            backend: processing backend (pipeline/vlm-transformers/vlm-vllm-engine)
            options: processing options (dict)
            priority: priority; larger numbers are processed first

        Returns:
            task_id: ID of the new task
        """
        task_id = str(uuid.uuid4())
        with self.get_cursor() as cursor:
            cursor.execute('''
                INSERT INTO tasks (task_id, file_name, file_path, backend, options, priority)
                VALUES (?, ?, ?, ?, ?, ?)
            ''', (task_id, file_name, file_path, backend, json.dumps(options or {}), priority))
        return task_id

    def get_next_task(self, worker_id: str) -> Optional[Dict]:
        """
        Fetch the next pending task (atomic; safe with concurrent workers).

        Args:
            worker_id: Worker ID

        Returns:
            task: task dict, or None if no task is pending
        """
        with self.get_cursor() as cursor:
            # Use an explicit transaction to guarantee atomicity
            cursor.execute('BEGIN IMMEDIATE')
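            # BEGIN IMMEDIATE takes SQLite's write (RESERVED) lock up front, so two
            # workers cannot both read the same 'pending' row before one of them
            # marks it 'processing'.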
            # Pick the highest-priority, oldest pending task
            cursor.execute('''
                SELECT * FROM tasks
                WHERE status = 'pending'
                ORDER BY priority DESC, created_at ASC
                LIMIT 1
            ''')
            task = cursor.fetchone()
            if task:
                # Immediately mark it as processing; started_at is stored in UTC so it
                # compares correctly against datetime('now') in reset_stale_tasks()
                cursor.execute('''
                    UPDATE tasks
                    SET status = 'processing',
                        started_at = ?,
                        worker_id = ?
                    WHERE task_id = ?
                ''', (datetime.now(timezone.utc).isoformat(), worker_id, task['task_id']))
                return dict(task)
        return None

    def update_task_status(self, task_id: str, status: str,
                           result_path: Optional[str] = None,
                           error_message: Optional[str] = None):
        """
        Update a task's status.

        Args:
            task_id: task ID
            status: new status (pending/processing/completed/failed/cancelled)
            result_path: result path (optional)
            error_message: error message (optional)
        """
        with self.get_cursor() as cursor:
            updates = ['status = ?']
            params = [status]
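            # Build the SET clause dynamically so only the columns relevant to the
            # new status are written.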
            if status == 'completed':
                updates.append('completed_at = ?')
                params.append(datetime.now(timezone.utc).isoformat())
            if result_path:
                updates.append('result_path = ?')
                params.append(result_path)
            if status == 'failed' and error_message:
                updates.append('error_message = ?')
                params.append(error_message)
                updates.append('completed_at = ?')
                params.append(datetime.now(timezone.utc).isoformat())
            params.append(task_id)
            cursor.execute(f'''
                UPDATE tasks SET {', '.join(updates)}
                WHERE task_id = ?
            ''', params)

    def get_task(self, task_id: str) -> Optional[Dict]:
        """
        Look up a single task.

        Args:
            task_id: task ID

        Returns:
            task: task dict, or None if it does not exist
        """
        with self.get_cursor() as cursor:
            cursor.execute('SELECT * FROM tasks WHERE task_id = ?', (task_id,))
            task = cursor.fetchone()
            return dict(task) if task else None

    def get_queue_stats(self) -> Dict[str, int]:
        """
        Get queue statistics.

        Returns:
            stats: number of tasks per status
        """
        with self.get_cursor() as cursor:
            cursor.execute('''
                SELECT status, COUNT(*) as count
                FROM tasks
                GROUP BY status
            ''')
            stats = {row['status']: row['count'] for row in cursor.fetchall()}
            return stats

    def get_tasks_by_status(self, status: str, limit: int = 100) -> List[Dict]:
        """
        List tasks with a given status.

        Args:
            status: task status
            limit: maximum number of rows to return

        Returns:
            tasks: list of task dicts
        """
        with self.get_cursor() as cursor:
            cursor.execute('''
                SELECT * FROM tasks
                WHERE status = ?
                ORDER BY created_at DESC
                LIMIT ?
            ''', (status, limit))
            return [dict(row) for row in cursor.fetchall()]

    def cleanup_old_tasks(self, days: int = 7):
        """
        Delete old task records.

        Args:
            days: keep tasks from the most recent N days

        Returns:
            deleted_count: number of rows deleted
        """
        with self.get_cursor() as cursor:
            # datetime() normalizes the stored ISO-8601 strings so the comparison
            # with datetime('now') is done on timestamps rather than raw strings
            cursor.execute('''
                DELETE FROM tasks
                WHERE datetime(completed_at) < datetime('now', '-' || ? || ' days')
                AND status IN ('completed', 'failed')
            ''', (days,))
            deleted_count = cursor.rowcount
            return deleted_count

    def reset_stale_tasks(self, timeout_minutes: int = 60):
        """
        Reset timed-out 'processing' tasks back to 'pending'.

        Args:
            timeout_minutes: timeout in minutes

        Returns:
            reset_count: number of tasks reset
        """
        with self.get_cursor() as cursor:
            cursor.execute('''
                UPDATE tasks
                SET status = 'pending',
                    worker_id = NULL,
                    retry_count = retry_count + 1
                WHERE status = 'processing'
                AND datetime(started_at) < datetime('now', '-' || ? || ' minutes')
            ''', (timeout_minutes,))
            reset_count = cursor.rowcount
            return reset_count


if __name__ == '__main__':
    # Quick self-test
    db = TaskDB('test_tianshu.db')

    # Create a test task
    task_id = db.create_task(
        file_name='test.pdf',
        file_path='/tmp/test.pdf',
        backend='pipeline',
        options={'lang': 'ch', 'formula_enable': True},
        priority=1
    )
    print(f"Created task: {task_id}")

    # Look up the task
    task = db.get_task(task_id)
    print(f"Task details: {task}")

    # Queue statistics
    stats = db.get_queue_stats()
    print(f"Queue stats: {stats}")

    # Remove the test database
    Path('test_tianshu.db').unlink(missing_ok=True)
    print("Test completed!")