|
|
@@ -103,12 +103,13 @@ class TaskDB:
|
|
|
''', (task_id, file_name, file_path, backend, json.dumps(options or {}), priority))
|
|
|
return task_id
|
|
|
|
|
|
- def get_next_task(self, worker_id: str) -> Optional[Dict]:
|
|
|
+ def get_next_task(self, worker_id: str, max_retries: int = 3) -> Optional[Dict]:
|
|
|
"""
|
|
|
获取下一个待处理任务(原子操作,防止并发冲突)
|
|
|
|
|
|
Args:
|
|
|
worker_id: Worker ID
|
|
|
+ max_retries: 当任务被其他 worker 抢走时的最大重试次数(默认3次)
|
|
|
|
|
|
Returns:
|
|
|
task: 任务字典,如果没有任务返回 None
|
|
|
@@ -117,39 +118,97 @@ class TaskDB:
|
|
|
1. 使用 BEGIN IMMEDIATE 立即获取写锁
|
|
|
2. UPDATE 时检查 status = 'pending' 防止重复拉取
|
|
|
3. 检查 rowcount 确保更新成功
|
|
|
+ 4. 如果任务被抢走,立即重试而不是返回 None(避免不必要的等待)
|
|
|
"""
|
|
|
- with self.get_cursor() as cursor:
|
|
|
- # 使用事务确保原子性
|
|
|
- cursor.execute('BEGIN IMMEDIATE')
|
|
|
-
|
|
|
- # 按优先级和创建时间获取任务
|
|
|
- cursor.execute('''
|
|
|
- SELECT * FROM tasks
|
|
|
- WHERE status = 'pending'
|
|
|
- ORDER BY priority DESC, created_at ASC
|
|
|
- LIMIT 1
|
|
|
- ''')
|
|
|
-
|
|
|
- task = cursor.fetchone()
|
|
|
- if task:
|
|
|
- # 立即标记为 processing,并确保状态仍是 pending
|
|
|
+ for attempt in range(max_retries):
|
|
|
+ with self.get_cursor() as cursor:
|
|
|
+ # 使用事务确保原子性
|
|
|
+ cursor.execute('BEGIN IMMEDIATE')
|
|
|
+
|
|
|
+ # 按优先级和创建时间获取任务
|
|
|
cursor.execute('''
|
|
|
- UPDATE tasks
|
|
|
- SET status = 'processing',
|
|
|
- started_at = CURRENT_TIMESTAMP,
|
|
|
- worker_id = ?
|
|
|
- WHERE task_id = ? AND status = 'pending'
|
|
|
- ''', (worker_id, task['task_id']))
|
|
|
+ SELECT * FROM tasks
|
|
|
+ WHERE status = 'pending'
|
|
|
+ ORDER BY priority DESC, created_at ASC
|
|
|
+ LIMIT 1
|
|
|
+ ''')
|
|
|
|
|
|
- # 检查是否更新成功(防止被其他 worker 抢走)
|
|
|
- if cursor.rowcount == 0:
|
|
|
- # 任务被其他进程抢走了,返回 None
|
|
|
- # 调用方会在下一次循环中重新获取
|
|
|
+ task = cursor.fetchone()
|
|
|
+ if task:
|
|
|
+ # 立即标记为 processing,并确保状态仍是 pending
|
|
|
+ cursor.execute('''
|
|
|
+ UPDATE tasks
|
|
|
+ SET status = 'processing',
|
|
|
+ started_at = CURRENT_TIMESTAMP,
|
|
|
+ worker_id = ?
|
|
|
+ WHERE task_id = ? AND status = 'pending'
|
|
|
+ ''', (worker_id, task['task_id']))
|
|
|
+
|
|
|
+ # 检查是否更新成功(防止被其他 worker 抢走)
|
|
|
+ if cursor.rowcount == 0:
|
|
|
+ # 任务被其他进程抢走了,立即重试
|
|
|
+ # 因为队列中可能还有其他待处理任务
|
|
|
+ continue
|
|
|
+
|
|
|
+ return dict(task)
|
|
|
+ else:
|
|
|
+ # 队列中没有待处理任务,返回 None
|
|
|
return None
|
|
|
-
|
|
|
- return dict(task)
|
|
|
|
|
|
- return None
|
|
|
+ # 重试次数用尽,仍未获取到任务(高并发场景)
|
|
|
+ return None
|
|
|
+
|
|
|
+ def _build_update_clauses(self, status: str, result_path: str = None,
|
|
|
+ error_message: str = None, worker_id: str = None,
|
|
|
+ task_id: str = None):
|
|
|
+ """
|
|
|
+ 构建 UPDATE 和 WHERE 子句的辅助方法
|
|
|
+
|
|
|
+ Args:
|
|
|
+ status: 新状态
|
|
|
+ result_path: 结果路径(可选)
|
|
|
+ error_message: 错误信息(可选)
|
|
|
+ worker_id: Worker ID(可选)
|
|
|
+ task_id: 任务ID(可选)
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ tuple: (update_clauses, update_params, where_clauses, where_params)
|
|
|
+ """
|
|
|
+ update_clauses = ['status = ?']
|
|
|
+ update_params = [status]
|
|
|
+ where_clauses = []
|
|
|
+ where_params = []
|
|
|
+
|
|
|
+ # 添加 task_id 条件(如果提供)
|
|
|
+ if task_id:
|
|
|
+ where_clauses.append('task_id = ?')
|
|
|
+ where_params.append(task_id)
|
|
|
+
|
|
|
+ # 处理 completed 状态
|
|
|
+ if status == 'completed':
|
|
|
+ update_clauses.append('completed_at = CURRENT_TIMESTAMP')
|
|
|
+ if result_path:
|
|
|
+ update_clauses.append('result_path = ?')
|
|
|
+ update_params.append(result_path)
|
|
|
+ # 只更新正在处理的任务
|
|
|
+ where_clauses.append("status = 'processing'")
|
|
|
+ if worker_id:
|
|
|
+ where_clauses.append('worker_id = ?')
|
|
|
+ where_params.append(worker_id)
|
|
|
+
|
|
|
+ # 处理 failed 状态
|
|
|
+ elif status == 'failed':
|
|
|
+ update_clauses.append('completed_at = CURRENT_TIMESTAMP')
|
|
|
+ if error_message:
|
|
|
+ update_clauses.append('error_message = ?')
|
|
|
+ update_params.append(error_message)
|
|
|
+ # 只更新正在处理的任务
|
|
|
+ where_clauses.append("status = 'processing'")
|
|
|
+ if worker_id:
|
|
|
+ where_clauses.append('worker_id = ?')
|
|
|
+ where_params.append(worker_id)
|
|
|
+
|
|
|
+ return update_clauses, update_params, where_clauses, where_params
|
|
|
|
|
|
def update_task_status(self, task_id: str, status: str,
|
|
|
result_path: str = None, error_message: str = None,
|
|
|
@@ -173,35 +232,9 @@ class TaskDB:
|
|
|
3. 返回 False 表示任务被其他进程修改了
|
|
|
"""
|
|
|
with self.get_cursor() as cursor:
|
|
|
- # 分离 UPDATE 和 WHERE 的参数,确保顺序正确
|
|
|
- update_clauses = ['status = ?']
|
|
|
- update_params = [status]
|
|
|
- where_clauses = ['task_id = ?']
|
|
|
- where_params = [task_id]
|
|
|
-
|
|
|
- # 处理 completed 状态
|
|
|
- if status == 'completed':
|
|
|
- update_clauses.append('completed_at = CURRENT_TIMESTAMP')
|
|
|
- if result_path:
|
|
|
- update_clauses.append('result_path = ?')
|
|
|
- update_params.append(result_path)
|
|
|
- # 只更新正在处理的任务
|
|
|
- where_clauses.append("status = 'processing'")
|
|
|
- if worker_id:
|
|
|
- where_clauses.append('worker_id = ?')
|
|
|
- where_params.append(worker_id)
|
|
|
-
|
|
|
- # 处理 failed 状态
|
|
|
- elif status == 'failed':
|
|
|
- update_clauses.append('completed_at = CURRENT_TIMESTAMP')
|
|
|
- if error_message:
|
|
|
- update_clauses.append('error_message = ?')
|
|
|
- update_params.append(error_message)
|
|
|
- # 只更新正在处理的任务
|
|
|
- where_clauses.append("status = 'processing'")
|
|
|
- if worker_id:
|
|
|
- where_clauses.append('worker_id = ?')
|
|
|
- where_params.append(worker_id)
|
|
|
+ # 使用辅助方法构建 UPDATE 和 WHERE 子句
|
|
|
+ update_clauses, update_params, where_clauses, where_params = \
|
|
|
+ self._build_update_clauses(status, result_path, error_message, worker_id, task_id)
|
|
|
|
|
|
# 合并参数:先 UPDATE 部分,再 WHERE 部分
|
|
|
all_params = update_params + where_params
|