task_scheduler.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197
  1. """
  2. MinerU Tianshu - Task Scheduler
  3. 天枢任务调度器
  4. 定期检查任务队列,触发 LitServe Workers 拉取和处理任务
  5. """
  6. import asyncio
  7. import aiohttp
  8. from loguru import logger
  9. from task_db import TaskDB
  10. import signal
  11. import sys
  12. class TaskScheduler:
  13. """
  14. 任务调度器
  15. 职责:
  16. 1. 监控 SQLite 任务队列
  17. 2. 当有待处理任务时,触发 LitServe Workers
  18. 3. 管理调度策略(轮询间隔、并发控制等)
  19. """
  20. def __init__(
  21. self,
  22. litserve_url='http://localhost:9000/predict',
  23. poll_interval=2,
  24. max_concurrent_polls=10
  25. ):
  26. """
  27. 初始化调度器
  28. Args:
  29. litserve_url: LitServe Worker 的 URL
  30. poll_interval: 轮询间隔(秒)
  31. max_concurrent_polls: 最大并发轮询数
  32. """
  33. self.litserve_url = litserve_url
  34. self.poll_interval = poll_interval
  35. self.max_concurrent_polls = max_concurrent_polls
  36. self.db = TaskDB()
  37. self.running = True
  38. self.active_polls = 0
  39. async def trigger_worker_poll(self, session: aiohttp.ClientSession):
  40. """
  41. 触发一个 worker 拉取任务
  42. """
  43. self.active_polls += 1
  44. try:
  45. async with session.post(
  46. self.litserve_url,
  47. json={'action': 'poll'},
  48. timeout=aiohttp.ClientTimeout(total=600) # 10分钟超时
  49. ) as resp:
  50. if resp.status == 200:
  51. result = await resp.json()
  52. if result.get('status') == 'completed':
  53. logger.info(f"✅ Task completed: {result.get('task_id')} by {result.get('worker_id')}")
  54. elif result.get('status') == 'failed':
  55. logger.error(f"❌ Task failed: {result.get('task_id')} - {result.get('error')}")
  56. elif result.get('status') == 'idle':
  57. # Worker 空闲,没有任务
  58. pass
  59. return result
  60. else:
  61. logger.error(f"Worker poll failed with status {resp.status}")
  62. except asyncio.TimeoutError:
  63. logger.warning("Worker poll timeout")
  64. except Exception as e:
  65. logger.error(f"Worker poll error: {e}")
  66. finally:
  67. self.active_polls -= 1
  68. async def schedule_loop(self):
  69. """
  70. 主调度循环
  71. """
  72. logger.info("🔄 Task scheduler started")
  73. logger.info(f" LitServe URL: {self.litserve_url}")
  74. logger.info(f" Poll Interval: {self.poll_interval}s")
  75. logger.info(f" Max Concurrent Polls: {self.max_concurrent_polls}")
  76. async with aiohttp.ClientSession() as session:
  77. while self.running:
  78. try:
  79. # 获取队列统计
  80. stats = self.db.get_queue_stats()
  81. pending_count = stats.get('pending', 0)
  82. processing_count = stats.get('processing', 0)
  83. if pending_count > 0:
  84. logger.info(f"📋 Queue status: {pending_count} pending, {processing_count} processing")
  85. # 计算需要触发的 worker 数量
  86. # 考虑:待处理任务数、当前处理中的任务数、活跃的轮询数
  87. needed_workers = min(
  88. pending_count, # 待处理任务数
  89. self.max_concurrent_polls - self.active_polls # 剩余并发数
  90. )
  91. if needed_workers > 0:
  92. # 并发触发多个 worker
  93. tasks = [
  94. self.trigger_worker_poll(session)
  95. for _ in range(needed_workers)
  96. ]
  97. await asyncio.gather(*tasks, return_exceptions=True)
  98. # 等待下一次轮询
  99. await asyncio.sleep(self.poll_interval)
  100. except Exception as e:
  101. logger.error(f"Scheduler loop error: {e}")
  102. await asyncio.sleep(self.poll_interval)
  103. logger.info("⏹️ Task scheduler stopped")
  104. def start(self):
  105. """启动调度器"""
  106. logger.info("🚀 Starting MinerU Tianshu Task Scheduler...")
  107. # 设置信号处理
  108. def signal_handler(sig, frame):
  109. logger.info("\n🛑 Received stop signal, shutting down...")
  110. self.running = False
  111. signal.signal(signal.SIGINT, signal_handler)
  112. signal.signal(signal.SIGTERM, signal_handler)
  113. # 运行调度循环
  114. asyncio.run(self.schedule_loop())
  115. def stop(self):
  116. """停止调度器"""
  117. self.running = False
  118. async def health_check(litserve_url: str) -> bool:
  119. """
  120. 健康检查:验证 LitServe Worker 是否可用
  121. """
  122. try:
  123. async with aiohttp.ClientSession() as session:
  124. async with session.get(
  125. litserve_url.replace('/predict', '/health'),
  126. timeout=aiohttp.ClientTimeout(total=5)
  127. ) as resp:
  128. return resp.status == 200
  129. except:
  130. return False
  131. if __name__ == '__main__':
  132. import argparse
  133. parser = argparse.ArgumentParser(description='MinerU Tianshu Task Scheduler')
  134. parser.add_argument('--litserve-url', type=str, default='http://localhost:9000/predict',
  135. help='LitServe worker URL')
  136. parser.add_argument('--poll-interval', type=int, default=2,
  137. help='Poll interval in seconds')
  138. parser.add_argument('--max-concurrent', type=int, default=10,
  139. help='Maximum concurrent worker polls')
  140. parser.add_argument('--wait-for-workers', action='store_true',
  141. help='Wait for workers to be ready before starting')
  142. args = parser.parse_args()
  143. # 等待 workers 就绪(可选)
  144. if args.wait_for_workers:
  145. logger.info("⏳ Waiting for LitServe workers to be ready...")
  146. import time
  147. max_retries = 30
  148. for i in range(max_retries):
  149. if asyncio.run(health_check(args.litserve_url)):
  150. logger.info("✅ LitServe workers are ready!")
  151. break
  152. time.sleep(2)
  153. if i == max_retries - 1:
  154. logger.error("❌ LitServe workers not responding, starting anyway...")
  155. # 创建并启动调度器
  156. scheduler = TaskScheduler(
  157. litserve_url=args.litserve_url,
  158. poll_interval=args.poll_interval,
  159. max_concurrent_polls=args.max_concurrent
  160. )
  161. try:
  162. scheduler.start()
  163. except KeyboardInterrupt:
  164. logger.info("👋 Scheduler interrupted by user")