Magic_yuan преди 1 месец
родител
ревизия
e7d8bf097a
променени са 4 файла, в които са добавени 137 реда и са изтрити 62 реда
  1. 3 1
      projects/mineru_tianshu/README.md
  2. 42 2
      projects/mineru_tianshu/litserve_worker.py
  3. 91 58
      projects/mineru_tianshu/task_db.py
  4. 1 1
      projects/mineru_tianshu/task_scheduler.py

+ 3 - 1
projects/mineru_tianshu/README.md

@@ -532,11 +532,13 @@ python start_all.py --cleanup-old-files-days 0
 
 | 指标 | v1.x | v2.0 | 提升 |
 |-----|------|------|-----|
-| 任务响应延迟 | 5-10秒 (调度器触发) | 0.5秒 (Worker主动拉取) | **10-20倍** |
+| 任务响应延迟<sup>※</sup> | 5-10秒 (调度器轮询) | 0.5秒 (Worker主动拉取) | **10-20倍** |
 | 并发安全性 | 基础锁机制 | 原子操作 + 状态检查 | **可靠性提升** |
 | 多GPU效率 | 有时会出现显存冲突 | 完全隔离,无冲突 | **稳定性提升** |
 | 系统开销 | 调度器持续运行 | 可选监控(5分钟) | **资源节省** |
 
+※ 任务响应延迟指任务添加到被 Worker 开始处理的时间间隔。v1.x 主要受调度器轮询间隔影响,非测量端到端处理时间。实际端到端响应时间还包括任务类型和系统负载所有因子。
+
 ## 📝 核心依赖
 
 ```txt

+ 42 - 2
projects/mineru_tianshu/litserve_worker.py

@@ -10,6 +10,8 @@ import json
 import sys
 import time
 import threading
+import signal
+import atexit
 from pathlib import Path
 import litserve as ls
 from loguru import logger
@@ -87,7 +89,7 @@ class MinerUWorkerAPI(ls.LitAPI):
             os.environ['CUDA_VISIBLE_DEVICES'] = device_id
             # 设置为 cuda:0,因为对进程来说只能看到一张卡(逻辑ID变为0)
             os.environ['MINERU_DEVICE_MODE'] = 'cuda:0'
-            device_mode = 'cuda:0'
+            device_mode = os.environ['MINERU_DEVICE_MODE']
             logger.info(f"🔒 CUDA_VISIBLE_DEVICES={device_id} (Physical GPU {device_id} → Logical GPU 0)")
         else:
             # 配置 MinerU 环境
@@ -126,6 +128,26 @@ class MinerUWorkerAPI(ls.LitAPI):
             self.worker_thread.start()
             logger.info(f"🔄 Worker loop started (poll_interval={self.poll_interval}s)")
     
+    def teardown(self):
+        """
+        优雅关闭 Worker
+        
+        设置 running 标志为 False,等待 worker 线程完成当前任务后退出。
+        这避免了守护线程可能导致的任务处理不完整或数据库操作不一致问题。
+        """
+        if self.enable_worker_loop and self.worker_thread and self.worker_thread.is_alive():
+            logger.info(f"🛑 Shutting down worker {self.worker_id}...")
+            self.running = False
+            
+            # 等待线程完成当前任务(最多等待 poll_interval * 2 秒)
+            timeout = self.poll_interval * 2
+            self.worker_thread.join(timeout=timeout)
+            
+            if self.worker_thread.is_alive():
+                logger.warning(f"⚠️  Worker thread did not stop within {timeout}s, forcing exit")
+            else:
+                logger.info(f"✅ Worker {self.worker_id} shut down gracefully")
+    
     def _worker_loop(self):
         """
         Worker 主循环:持续拉取并处理任务
@@ -309,7 +331,7 @@ class MinerUWorkerAPI(ls.LitAPI):
             try:
                 clean_memory()
             except Exception as e:
-                logger.debug(f"Memory cleanup: {e}")
+                logger.debug(f"Memory cleanup failed for task {task_id}: {e}")
     
     def _parse_with_markitdown(self, file_path: Path, file_name: str, 
                                output_path: Path):
@@ -450,6 +472,24 @@ def start_litserve_workers(
         timeout=False,  # 不设置超时
     )
     
+    # 注册优雅关闭处理器
+    def graceful_shutdown(signum=None, frame=None):
+        """处理关闭信号,优雅地停止 worker"""
+        logger.info("🛑 Received shutdown signal, gracefully stopping workers...")
+        # 注意:LitServe 会为每个设备创建多个 worker 实例
+        # 这里的 api 只是模板,实际的 worker 实例由 LitServe 管理
+        # teardown 会在每个 worker 进程中被调用
+        if hasattr(api, 'teardown'):
+            api.teardown()
+        sys.exit(0)
+    
+    # 注册信号处理器(Ctrl+C 等)
+    signal.signal(signal.SIGINT, graceful_shutdown)
+    signal.signal(signal.SIGTERM, graceful_shutdown)
+    
+    # 注册 atexit 处理器(正常退出时调用)
+    atexit.register(lambda: api.teardown() if hasattr(api, 'teardown') else None)
+    
     logger.info(f"✅ LitServe worker pool initialized")
     logger.info(f"📡 Listening on: http://0.0.0.0:{port}/predict")
     if enable_worker_loop:

+ 91 - 58
projects/mineru_tianshu/task_db.py

@@ -103,12 +103,13 @@ class TaskDB:
             ''', (task_id, file_name, file_path, backend, json.dumps(options or {}), priority))
         return task_id
     
-    def get_next_task(self, worker_id: str) -> Optional[Dict]:
+    def get_next_task(self, worker_id: str, max_retries: int = 3) -> Optional[Dict]:
         """
         获取下一个待处理任务(原子操作,防止并发冲突)
         
         Args:
             worker_id: Worker ID
+            max_retries: 当任务被其他 worker 抢走时的最大重试次数(默认3次)
             
         Returns:
             task: 任务字典,如果没有任务返回 None
@@ -117,39 +118,97 @@ class TaskDB:
             1. 使用 BEGIN IMMEDIATE 立即获取写锁
             2. UPDATE 时检查 status = 'pending' 防止重复拉取
             3. 检查 rowcount 确保更新成功
+            4. 如果任务被抢走,立即重试而不是返回 None(避免不必要的等待)
         """
-        with self.get_cursor() as cursor:
-            # 使用事务确保原子性
-            cursor.execute('BEGIN IMMEDIATE')
-            
-            # 按优先级和创建时间获取任务
-            cursor.execute('''
-                SELECT * FROM tasks 
-                WHERE status = 'pending' 
-                ORDER BY priority DESC, created_at ASC 
-                LIMIT 1
-            ''')
-            
-            task = cursor.fetchone()
-            if task:
-                # 立即标记为 processing,并确保状态仍是 pending
+        for attempt in range(max_retries):
+            with self.get_cursor() as cursor:
+                # 使用事务确保原子性
+                cursor.execute('BEGIN IMMEDIATE')
+                
+                # 按优先级和创建时间获取任务
                 cursor.execute('''
-                    UPDATE tasks 
-                    SET status = 'processing', 
-                        started_at = CURRENT_TIMESTAMP, 
-                        worker_id = ?
-                    WHERE task_id = ? AND status = 'pending'
-                ''', (worker_id, task['task_id']))
+                    SELECT * FROM tasks 
+                    WHERE status = 'pending' 
+                    ORDER BY priority DESC, created_at ASC 
+                    LIMIT 1
+                ''')
                 
-                # 检查是否更新成功(防止被其他 worker 抢走)
-                if cursor.rowcount == 0:
-                    # 任务被其他进程抢走了,返回 None
-                    # 调用方会在下一次循环中重新获取
+                task = cursor.fetchone()
+                if task:
+                    # 立即标记为 processing,并确保状态仍是 pending
+                    cursor.execute('''
+                        UPDATE tasks 
+                        SET status = 'processing', 
+                            started_at = CURRENT_TIMESTAMP, 
+                            worker_id = ?
+                        WHERE task_id = ? AND status = 'pending'
+                    ''', (worker_id, task['task_id']))
+                    
+                    # 检查是否更新成功(防止被其他 worker 抢走)
+                    if cursor.rowcount == 0:
+                        # 任务被其他进程抢走了,立即重试
+                        # 因为队列中可能还有其他待处理任务
+                        continue
+                    
+                    return dict(task)
+                else:
+                    # 队列中没有待处理任务,返回 None
                     return None
-                
-                return dict(task)
             
-            return None
+        # 重试次数用尽,仍未获取到任务(高并发场景)
+        return None
+    
+    def _build_update_clauses(self, status: str, result_path: str = None, 
+                             error_message: str = None, worker_id: str = None, 
+                             task_id: str = None):
+        """
+        构建 UPDATE 和 WHERE 子句的辅助方法
+        
+        Args:
+            status: 新状态
+            result_path: 结果路径(可选)
+            error_message: 错误信息(可选)
+            worker_id: Worker ID(可选)
+            task_id: 任务ID(可选)
+            
+        Returns:
+            tuple: (update_clauses, update_params, where_clauses, where_params)
+        """
+        update_clauses = ['status = ?']
+        update_params = [status]
+        where_clauses = []
+        where_params = []
+        
+        # 添加 task_id 条件(如果提供)
+        if task_id:
+            where_clauses.append('task_id = ?')
+            where_params.append(task_id)
+        
+        # 处理 completed 状态
+        if status == 'completed':
+            update_clauses.append('completed_at = CURRENT_TIMESTAMP')
+            if result_path:
+                update_clauses.append('result_path = ?')
+                update_params.append(result_path)
+            # 只更新正在处理的任务
+            where_clauses.append("status = 'processing'")
+            if worker_id:
+                where_clauses.append('worker_id = ?')
+                where_params.append(worker_id)
+        
+        # 处理 failed 状态
+        elif status == 'failed':
+            update_clauses.append('completed_at = CURRENT_TIMESTAMP')
+            if error_message:
+                update_clauses.append('error_message = ?')
+                update_params.append(error_message)
+            # 只更新正在处理的任务
+            where_clauses.append("status = 'processing'")
+            if worker_id:
+                where_clauses.append('worker_id = ?')
+                where_params.append(worker_id)
+        
+        return update_clauses, update_params, where_clauses, where_params
     
     def update_task_status(self, task_id: str, status: str, 
                           result_path: str = None, error_message: str = None,
@@ -173,35 +232,9 @@ class TaskDB:
             3. 返回 False 表示任务被其他进程修改了
         """
         with self.get_cursor() as cursor:
-            # 分离 UPDATE 和 WHERE 的参数,确保顺序正确
-            update_clauses = ['status = ?']
-            update_params = [status]
-            where_clauses = ['task_id = ?']
-            where_params = [task_id]
-            
-            # 处理 completed 状态
-            if status == 'completed':
-                update_clauses.append('completed_at = CURRENT_TIMESTAMP')
-                if result_path:
-                    update_clauses.append('result_path = ?')
-                    update_params.append(result_path)
-                # 只更新正在处理的任务
-                where_clauses.append("status = 'processing'")
-                if worker_id:
-                    where_clauses.append('worker_id = ?')
-                    where_params.append(worker_id)
-            
-            # 处理 failed 状态
-            elif status == 'failed':
-                update_clauses.append('completed_at = CURRENT_TIMESTAMP')
-                if error_message:
-                    update_clauses.append('error_message = ?')
-                    update_params.append(error_message)
-                # 只更新正在处理的任务
-                where_clauses.append("status = 'processing'")
-                if worker_id:
-                    where_clauses.append('worker_id = ?')
-                    where_params.append(worker_id)
+            # 使用辅助方法构建 UPDATE 和 WHERE 子句
+            update_clauses, update_params, where_clauses, where_params = \
+                self._build_update_clauses(status, result_path, error_message, worker_id, task_id)
             
             # 合并参数:先 UPDATE 部分,再 WHERE 部分
             all_params = update_params + where_params

+ 1 - 1
projects/mineru_tianshu/task_scheduler.py

@@ -151,7 +151,7 @@ class TaskScheduler:
                     
                     # 4. 定期清理旧任务文件和记录
                     cleanup_counter += 1
-                    # 每24小时清理一次(假设 monitor_interval = 300s
+                    # 每24小时清理一次(基于当前监控间隔计算
                     cleanup_interval_cycles = (24 * 3600) / self.monitor_interval
                     if cleanup_counter >= cleanup_interval_cycles:
                         cleanup_counter = 0