task_db.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436
  1. """
  2. MinerU Tianshu - SQLite Task Database Manager
  3. 天枢任务数据库管理器
  4. 负责任务的持久化存储、状态管理和原子性操作
  5. """
  6. import sqlite3
  7. import json
  8. import uuid
  9. from contextlib import contextmanager
  10. from typing import Optional, List, Dict
  11. from pathlib import Path
  12. class TaskDB:
  13. """任务数据库管理类"""
  14. def __init__(self, db_path='mineru_tianshu.db'):
  15. self.db_path = db_path
  16. self._init_db()
  17. def _get_conn(self):
  18. """获取数据库连接(每次创建新连接,避免 pickle 问题)
  19. 并发安全说明:
  20. - 使用 check_same_thread=False 是安全的,因为:
  21. 1. 每次调用都创建新连接,不跨线程共享
  22. 2. 连接使用完立即关闭(在 get_cursor 上下文管理器中)
  23. 3. 不使用连接池,避免线程间共享同一连接
  24. - timeout=30.0 防止死锁,如果锁等待超过30秒会抛出异常
  25. """
  26. conn = sqlite3.connect(
  27. self.db_path,
  28. check_same_thread=False,
  29. timeout=30.0
  30. )
  31. conn.row_factory = sqlite3.Row
  32. return conn
  33. @contextmanager
  34. def get_cursor(self):
  35. """上下文管理器,自动提交和错误处理"""
  36. conn = self._get_conn()
  37. cursor = conn.cursor()
  38. try:
  39. yield cursor
  40. conn.commit()
  41. except Exception as e:
  42. conn.rollback()
  43. raise e
  44. finally:
  45. conn.close() # 关闭连接
  46. def _init_db(self):
  47. """初始化数据库表"""
  48. with self.get_cursor() as cursor:
  49. cursor.execute('''
  50. CREATE TABLE IF NOT EXISTS tasks (
  51. task_id TEXT PRIMARY KEY,
  52. file_name TEXT NOT NULL,
  53. file_path TEXT,
  54. status TEXT DEFAULT 'pending',
  55. priority INTEGER DEFAULT 0,
  56. backend TEXT DEFAULT 'pipeline',
  57. options TEXT,
  58. result_path TEXT,
  59. error_message TEXT,
  60. created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
  61. started_at TIMESTAMP,
  62. completed_at TIMESTAMP,
  63. worker_id TEXT,
  64. retry_count INTEGER DEFAULT 0
  65. )
  66. ''')
  67. # 创建索引加速查询
  68. cursor.execute('CREATE INDEX IF NOT EXISTS idx_status ON tasks(status)')
  69. cursor.execute('CREATE INDEX IF NOT EXISTS idx_priority ON tasks(priority DESC)')
  70. cursor.execute('CREATE INDEX IF NOT EXISTS idx_created_at ON tasks(created_at)')
  71. cursor.execute('CREATE INDEX IF NOT EXISTS idx_worker_id ON tasks(worker_id)')
  72. def create_task(self, file_name: str, file_path: str,
  73. backend: str = 'pipeline', options: dict = None,
  74. priority: int = 0) -> str:
  75. """
  76. 创建新任务
  77. Args:
  78. file_name: 文件名
  79. file_path: 文件路径
  80. backend: 处理后端 (pipeline/vlm-transformers/vlm-vllm-engine)
  81. options: 处理选项 (dict)
  82. priority: 优先级,数字越大越优先
  83. Returns:
  84. task_id: 任务ID
  85. """
  86. task_id = str(uuid.uuid4())
  87. with self.get_cursor() as cursor:
  88. cursor.execute('''
  89. INSERT INTO tasks (task_id, file_name, file_path, backend, options, priority)
  90. VALUES (?, ?, ?, ?, ?, ?)
  91. ''', (task_id, file_name, file_path, backend, json.dumps(options or {}), priority))
  92. return task_id
  93. def get_next_task(self, worker_id: str, max_retries: int = 3) -> Optional[Dict]:
  94. """
  95. 获取下一个待处理任务(原子操作,防止并发冲突)
  96. Args:
  97. worker_id: Worker ID
  98. max_retries: 当任务被其他 worker 抢走时的最大重试次数(默认3次)
  99. Returns:
  100. task: 任务字典,如果没有任务返回 None
  101. 并发安全说明:
  102. 1. 使用 BEGIN IMMEDIATE 立即获取写锁
  103. 2. UPDATE 时检查 status = 'pending' 防止重复拉取
  104. 3. 检查 rowcount 确保更新成功
  105. 4. 如果任务被抢走,立即重试而不是返回 None(避免不必要的等待)
  106. """
  107. for attempt in range(max_retries):
  108. with self.get_cursor() as cursor:
  109. # 使用事务确保原子性
  110. cursor.execute('BEGIN IMMEDIATE')
  111. # 按优先级和创建时间获取任务
  112. cursor.execute('''
  113. SELECT * FROM tasks
  114. WHERE status = 'pending'
  115. ORDER BY priority DESC, created_at ASC
  116. LIMIT 1
  117. ''')
  118. task = cursor.fetchone()
  119. if task:
  120. # 立即标记为 processing,并确保状态仍是 pending
  121. cursor.execute('''
  122. UPDATE tasks
  123. SET status = 'processing',
  124. started_at = CURRENT_TIMESTAMP,
  125. worker_id = ?
  126. WHERE task_id = ? AND status = 'pending'
  127. ''', (worker_id, task['task_id']))
  128. # 检查是否更新成功(防止被其他 worker 抢走)
  129. if cursor.rowcount == 0:
  130. # 任务被其他进程抢走了,立即重试
  131. # 因为队列中可能还有其他待处理任务
  132. continue
  133. return dict(task)
  134. else:
  135. # 队列中没有待处理任务,返回 None
  136. return None
  137. # 重试次数用尽,仍未获取到任务(高并发场景)
  138. return None
  139. def _build_update_clauses(self, status: str, result_path: str = None,
  140. error_message: str = None, worker_id: str = None,
  141. task_id: str = None):
  142. """
  143. 构建 UPDATE 和 WHERE 子句的辅助方法
  144. Args:
  145. status: 新状态
  146. result_path: 结果路径(可选)
  147. error_message: 错误信息(可选)
  148. worker_id: Worker ID(可选)
  149. task_id: 任务ID(可选)
  150. Returns:
  151. tuple: (update_clauses, update_params, where_clauses, where_params)
  152. """
  153. update_clauses = ['status = ?']
  154. update_params = [status]
  155. where_clauses = []
  156. where_params = []
  157. # 添加 task_id 条件(如果提供)
  158. if task_id:
  159. where_clauses.append('task_id = ?')
  160. where_params.append(task_id)
  161. # 处理 completed 状态
  162. if status == 'completed':
  163. update_clauses.append('completed_at = CURRENT_TIMESTAMP')
  164. if result_path:
  165. update_clauses.append('result_path = ?')
  166. update_params.append(result_path)
  167. # 只更新正在处理的任务
  168. where_clauses.append("status = 'processing'")
  169. if worker_id:
  170. where_clauses.append('worker_id = ?')
  171. where_params.append(worker_id)
  172. # 处理 failed 状态
  173. elif status == 'failed':
  174. update_clauses.append('completed_at = CURRENT_TIMESTAMP')
  175. if error_message:
  176. update_clauses.append('error_message = ?')
  177. update_params.append(error_message)
  178. # 只更新正在处理的任务
  179. where_clauses.append("status = 'processing'")
  180. if worker_id:
  181. where_clauses.append('worker_id = ?')
  182. where_params.append(worker_id)
  183. return update_clauses, update_params, where_clauses, where_params
  184. def update_task_status(self, task_id: str, status: str,
  185. result_path: str = None, error_message: str = None,
  186. worker_id: str = None):
  187. """
  188. 更新任务状态
  189. Args:
  190. task_id: 任务ID
  191. status: 新状态 (pending/processing/completed/failed/cancelled)
  192. result_path: 结果路径(可选)
  193. error_message: 错误信息(可选)
  194. worker_id: Worker ID(可选,用于并发检查)
  195. Returns:
  196. bool: 更新是否成功
  197. 并发安全说明:
  198. 1. 更新为 completed/failed 时会检查状态是 processing
  199. 2. 如果提供 worker_id,会检查任务是否属于该 worker
  200. 3. 返回 False 表示任务被其他进程修改了
  201. """
  202. with self.get_cursor() as cursor:
  203. # 使用辅助方法构建 UPDATE 和 WHERE 子句
  204. update_clauses, update_params, where_clauses, where_params = \
  205. self._build_update_clauses(status, result_path, error_message, worker_id, task_id)
  206. # 合并参数:先 UPDATE 部分,再 WHERE 部分
  207. all_params = update_params + where_params
  208. sql = f'''
  209. UPDATE tasks
  210. SET {', '.join(update_clauses)}
  211. WHERE {' AND '.join(where_clauses)}
  212. '''
  213. cursor.execute(sql, all_params)
  214. # 检查更新是否成功
  215. success = cursor.rowcount > 0
  216. # 调试日志(仅在失败时)
  217. if not success and status in ['completed', 'failed']:
  218. from loguru import logger
  219. logger.debug(
  220. f"Status update failed: task_id={task_id}, status={status}, "
  221. f"worker_id={worker_id}, SQL: {sql}, params: {all_params}"
  222. )
  223. return success
  224. def get_task(self, task_id: str) -> Optional[Dict]:
  225. """
  226. 查询任务详情
  227. Args:
  228. task_id: 任务ID
  229. Returns:
  230. task: 任务字典,如果不存在返回 None
  231. """
  232. with self.get_cursor() as cursor:
  233. cursor.execute('SELECT * FROM tasks WHERE task_id = ?', (task_id,))
  234. task = cursor.fetchone()
  235. return dict(task) if task else None
  236. def get_queue_stats(self) -> Dict[str, int]:
  237. """
  238. 获取队列统计信息
  239. Returns:
  240. stats: 各状态的任务数量
  241. """
  242. with self.get_cursor() as cursor:
  243. cursor.execute('''
  244. SELECT status, COUNT(*) as count
  245. FROM tasks
  246. GROUP BY status
  247. ''')
  248. stats = {row['status']: row['count'] for row in cursor.fetchall()}
  249. return stats
  250. def get_tasks_by_status(self, status: str, limit: int = 100) -> List[Dict]:
  251. """
  252. 根据状态获取任务列表
  253. Args:
  254. status: 任务状态
  255. limit: 返回数量限制
  256. Returns:
  257. tasks: 任务列表
  258. """
  259. with self.get_cursor() as cursor:
  260. cursor.execute('''
  261. SELECT * FROM tasks
  262. WHERE status = ?
  263. ORDER BY created_at DESC
  264. LIMIT ?
  265. ''', (status, limit))
  266. return [dict(row) for row in cursor.fetchall()]
  267. def cleanup_old_task_files(self, days: int = 7):
  268. """
  269. 清理旧任务的结果文件(保留数据库记录)
  270. Args:
  271. days: 清理多少天前的任务文件
  272. Returns:
  273. int: 删除的文件目录数
  274. 注意:
  275. - 只删除结果文件,保留数据库记录
  276. - 数据库中的 result_path 字段会被清空
  277. - 用户仍可查询任务状态和历史记录
  278. """
  279. from pathlib import Path
  280. import shutil
  281. with self.get_cursor() as cursor:
  282. # 查询要清理文件的任务
  283. cursor.execute('''
  284. SELECT task_id, result_path FROM tasks
  285. WHERE completed_at < datetime('now', '-' || ? || ' days')
  286. AND status IN ('completed', 'failed')
  287. AND result_path IS NOT NULL
  288. ''', (days,))
  289. old_tasks = cursor.fetchall()
  290. file_count = 0
  291. # 删除结果文件
  292. for task in old_tasks:
  293. if task['result_path']:
  294. result_path = Path(task['result_path'])
  295. if result_path.exists() and result_path.is_dir():
  296. try:
  297. shutil.rmtree(result_path)
  298. file_count += 1
  299. # 清空数据库中的 result_path,表示文件已被清理
  300. cursor.execute('''
  301. UPDATE tasks
  302. SET result_path = NULL
  303. WHERE task_id = ?
  304. ''', (task['task_id'],))
  305. except Exception as e:
  306. from loguru import logger
  307. logger.warning(f"Failed to delete result files for task {task['task_id']}: {e}")
  308. return file_count
  309. def cleanup_old_task_records(self, days: int = 30):
  310. """
  311. 清理极旧的任务记录(可选功能)
  312. Args:
  313. days: 删除多少天前的任务记录
  314. Returns:
  315. int: 删除的记录数
  316. 注意:
  317. - 这个方法会永久删除数据库记录
  318. - 建议设置较长的保留期(如30-90天)
  319. - 一般情况下不需要调用此方法
  320. """
  321. with self.get_cursor() as cursor:
  322. cursor.execute('''
  323. DELETE FROM tasks
  324. WHERE completed_at < datetime('now', '-' || ? || ' days')
  325. AND status IN ('completed', 'failed')
  326. ''', (days,))
  327. deleted_count = cursor.rowcount
  328. return deleted_count
  329. def reset_stale_tasks(self, timeout_minutes: int = 60):
  330. """
  331. 重置超时的 processing 任务为 pending
  332. Args:
  333. timeout_minutes: 超时时间(分钟)
  334. """
  335. with self.get_cursor() as cursor:
  336. cursor.execute('''
  337. UPDATE tasks
  338. SET status = 'pending',
  339. worker_id = NULL,
  340. retry_count = retry_count + 1
  341. WHERE status = 'processing'
  342. AND started_at < datetime('now', '-' || ? || ' minutes')
  343. ''', (timeout_minutes,))
  344. reset_count = cursor.rowcount
  345. return reset_count
  346. if __name__ == '__main__':
  347. # 测试代码
  348. db = TaskDB('test_tianshu.db')
  349. # 创建测试任务
  350. task_id = db.create_task(
  351. file_name='test.pdf',
  352. file_path='/tmp/test.pdf',
  353. backend='pipeline',
  354. options={'lang': 'ch', 'formula_enable': True},
  355. priority=1
  356. )
  357. print(f"Created task: {task_id}")
  358. # 查询任务
  359. task = db.get_task(task_id)
  360. print(f"Task details: {task}")
  361. # 获取统计
  362. stats = db.get_queue_stats()
  363. print(f"Queue stats: {stats}")
  364. # 清理测试数据库
  365. Path('test_tianshu.db').unlink(missing_ok=True)
  366. print("Test completed!")