|
|
@@ -9,7 +9,6 @@ import aiohttp
|
|
|
from loguru import logger
|
|
|
from task_db import TaskDB
|
|
|
import signal
|
|
|
-import sys
|
|
|
|
|
|
|
|
|
class TaskScheduler:
|
|
|
@@ -41,40 +40,38 @@ class TaskScheduler:
|
|
|
self.max_concurrent_polls = max_concurrent_polls
|
|
|
self.db = TaskDB()
|
|
|
self.running = True
|
|
|
- self.active_polls = 0
|
|
|
+ self.semaphore = asyncio.Semaphore(max_concurrent_polls)
|
|
|
|
|
|
async def trigger_worker_poll(self, session: aiohttp.ClientSession):
    """
    Trigger a single worker to pull a task.

    Sends ``{'action': 'poll'}`` as a POST to ``self.litserve_url`` while
    holding ``self.semaphore``; the semaphore is released when the request
    finishes, whether it succeeded or not.

    Args:
        session: HTTP client session used to send the poll request.

    Returns:
        The decoded JSON body when the worker replies with HTTP 200.
        ``None`` for any other status code, on timeout, or on any other
        exception; those cases are logged instead of being raised.
    """
    async with self.semaphore:
        try:
            async with session.post(
                self.litserve_url,
                json={'action': 'poll'},
                timeout=aiohttp.ClientTimeout(total=600),  # 10-minute timeout
            ) as resp:
                # Guard clause: anything but 200 is logged and yields None.
                if resp.status != 200:
                    logger.error(f"Worker poll failed with status {resp.status}")
                    return None

                result = await resp.json()
                outcome = result.get('status')

                if outcome == 'completed':
                    logger.info(f"✅ Task completed: {result.get('task_id')} by {result.get('worker_id')}")
                elif outcome == 'failed':
                    logger.error(f"❌ Task failed: {result.get('task_id')} - {result.get('error')}")
                # 'idle' means the worker had no task to pick up: nothing to log.

                return result

        except asyncio.TimeoutError:
            logger.warning("Worker poll timeout")
        except Exception as e:
            logger.error(f"Worker poll error: {e}")

    # Reached only after a timeout or another handled exception.
    return None
|
|
|
|
|
|
async def schedule_loop(self):
|
|
|
"""
|
|
|
@@ -97,14 +94,15 @@ class TaskScheduler:
|
|
|
logger.info(f"📋 Queue status: {pending_count} pending, {processing_count} processing")
|
|
|
|
|
|
# 计算需要触发的 worker 数量
|
|
|
- # 考虑:待处理任务数、当前处理中的任务数、活跃的轮询数
|
|
|
+ # 考虑:待处理任务数
|
|
|
needed_workers = min(
|
|
|
pending_count, # 待处理任务数
|
|
|
- self.max_concurrent_polls - self.active_polls # 剩余并发数
|
|
|
+ self.max_concurrent_polls # 最大并发数
|
|
|
)
|
|
|
|
|
|
if needed_workers > 0:
|
|
|
# 并发触发多个 worker
|
|
|
+ # semaphore 会自动控制实际并发数
|
|
|
tasks = [
|
|
|
self.trigger_worker_poll(session)
|
|
|
for _ in range(needed_workers)
|