task_scheduler.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. """
  2. MinerU Tianshu - Task Scheduler
  3. 天枢任务调度器
  4. 定期检查任务队列,触发 LitServe Workers 拉取和处理任务
  5. """
  6. import asyncio
  7. import aiohttp
  8. from loguru import logger
  9. from task_db import TaskDB
  10. import signal
  11. class TaskScheduler:
  12. """
  13. 任务调度器
  14. 职责:
  15. 1. 监控 SQLite 任务队列
  16. 2. 当有待处理任务时,触发 LitServe Workers
  17. 3. 管理调度策略(轮询间隔、并发控制等)
  18. """
  19. def __init__(
  20. self,
  21. litserve_url='http://localhost:9000/predict',
  22. poll_interval=2,
  23. max_concurrent_polls=10
  24. ):
  25. """
  26. 初始化调度器
  27. Args:
  28. litserve_url: LitServe Worker 的 URL
  29. poll_interval: 轮询间隔(秒)
  30. max_concurrent_polls: 最大并发轮询数
  31. """
  32. self.litserve_url = litserve_url
  33. self.poll_interval = poll_interval
  34. self.max_concurrent_polls = max_concurrent_polls
  35. self.db = TaskDB()
  36. self.running = True
  37. self.semaphore = asyncio.Semaphore(max_concurrent_polls)
  38. async def trigger_worker_poll(self, session: aiohttp.ClientSession):
  39. """
  40. 触发一个 worker 拉取任务
  41. """
  42. async with self.semaphore:
  43. try:
  44. async with session.post(
  45. self.litserve_url,
  46. json={'action': 'poll'},
  47. timeout=aiohttp.ClientTimeout(total=600) # 10分钟超时
  48. ) as resp:
  49. if resp.status == 200:
  50. result = await resp.json()
  51. if result.get('status') == 'completed':
  52. logger.info(f"✅ Task completed: {result.get('task_id')} by {result.get('worker_id')}")
  53. elif result.get('status') == 'failed':
  54. logger.error(f"❌ Task failed: {result.get('task_id')} - {result.get('error')}")
  55. elif result.get('status') == 'idle':
  56. # Worker 空闲,没有任务
  57. pass
  58. return result
  59. else:
  60. logger.error(f"Worker poll failed with status {resp.status}")
  61. except asyncio.TimeoutError:
  62. logger.warning("Worker poll timeout")
  63. except Exception as e:
  64. logger.error(f"Worker poll error: {e}")
  65. async def schedule_loop(self):
  66. """
  67. 主调度循环
  68. """
  69. logger.info("🔄 Task scheduler started")
  70. logger.info(f" LitServe URL: {self.litserve_url}")
  71. logger.info(f" Poll Interval: {self.poll_interval}s")
  72. logger.info(f" Max Concurrent Polls: {self.max_concurrent_polls}")
  73. async with aiohttp.ClientSession() as session:
  74. while self.running:
  75. try:
  76. # 获取队列统计
  77. stats = self.db.get_queue_stats()
  78. pending_count = stats.get('pending', 0)
  79. processing_count = stats.get('processing', 0)
  80. if pending_count > 0:
  81. logger.info(f"📋 Queue status: {pending_count} pending, {processing_count} processing")
  82. # 计算需要触发的 worker 数量
  83. # 考虑:待处理任务数
  84. needed_workers = min(
  85. pending_count, # 待处理任务数
  86. self.max_concurrent_polls # 最大并发数
  87. )
  88. if needed_workers > 0:
  89. # 并发触发多个 worker
  90. # semaphore 会自动控制实际并发数
  91. tasks = [
  92. self.trigger_worker_poll(session)
  93. for _ in range(needed_workers)
  94. ]
  95. await asyncio.gather(*tasks, return_exceptions=True)
  96. # 等待下一次轮询
  97. await asyncio.sleep(self.poll_interval)
  98. except Exception as e:
  99. logger.error(f"Scheduler loop error: {e}")
  100. await asyncio.sleep(self.poll_interval)
  101. logger.info("⏹️ Task scheduler stopped")
  102. def start(self):
  103. """启动调度器"""
  104. logger.info("🚀 Starting MinerU Tianshu Task Scheduler...")
  105. # 设置信号处理
  106. def signal_handler(sig, frame):
  107. logger.info("\n🛑 Received stop signal, shutting down...")
  108. self.running = False
  109. signal.signal(signal.SIGINT, signal_handler)
  110. signal.signal(signal.SIGTERM, signal_handler)
  111. # 运行调度循环
  112. asyncio.run(self.schedule_loop())
  113. def stop(self):
  114. """停止调度器"""
  115. self.running = False
  116. async def health_check(litserve_url: str) -> bool:
  117. """
  118. 健康检查:验证 LitServe Worker 是否可用
  119. """
  120. try:
  121. async with aiohttp.ClientSession() as session:
  122. async with session.get(
  123. litserve_url.replace('/predict', '/health'),
  124. timeout=aiohttp.ClientTimeout(total=5)
  125. ) as resp:
  126. return resp.status == 200
  127. except:
  128. return False
  129. if __name__ == '__main__':
  130. import argparse
  131. parser = argparse.ArgumentParser(description='MinerU Tianshu Task Scheduler')
  132. parser.add_argument('--litserve-url', type=str, default='http://localhost:9000/predict',
  133. help='LitServe worker URL')
  134. parser.add_argument('--poll-interval', type=int, default=2,
  135. help='Poll interval in seconds')
  136. parser.add_argument('--max-concurrent', type=int, default=10,
  137. help='Maximum concurrent worker polls')
  138. parser.add_argument('--wait-for-workers', action='store_true',
  139. help='Wait for workers to be ready before starting')
  140. args = parser.parse_args()
  141. # 等待 workers 就绪(可选)
  142. if args.wait_for_workers:
  143. logger.info("⏳ Waiting for LitServe workers to be ready...")
  144. import time
  145. max_retries = 30
  146. for i in range(max_retries):
  147. if asyncio.run(health_check(args.litserve_url)):
  148. logger.info("✅ LitServe workers are ready!")
  149. break
  150. time.sleep(2)
  151. if i == max_retries - 1:
  152. logger.error("❌ LitServe workers not responding, starting anyway...")
  153. # 创建并启动调度器
  154. scheduler = TaskScheduler(
  155. litserve_url=args.litserve_url,
  156. poll_interval=args.poll_interval,
  157. max_concurrent_polls=args.max_concurrent
  158. )
  159. try:
  160. scheduler.start()
  161. except KeyboardInterrupt:
  162. logger.info("👋 Scheduler interrupted by user")