litserve_worker.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363
  1. """
  2. MinerU Tianshu - LitServe Worker
  3. 天枢 LitServe Worker
  4. 使用 LitServe 实现 GPU 资源的自动负载均衡
  5. 从 SQLite 队列拉取任务并处理
  6. """
  7. import os
  8. import json
  9. import sys
  10. from pathlib import Path
  11. import litserve as ls
  12. from loguru import logger
  13. # 添加父目录到路径以导入 MinerU
  14. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  15. from task_db import TaskDB
  16. from mineru.cli.common import do_parse, read_fn
  17. from mineru.utils.config_reader import get_device
  18. from mineru.utils.model_utils import get_vram
  19. # 尝试导入 markitdown
  20. try:
  21. from markitdown import MarkItDown
  22. MARKITDOWN_AVAILABLE = True
  23. except ImportError:
  24. MARKITDOWN_AVAILABLE = False
  25. logger.warning("⚠️ markitdown not available, Office format parsing will be disabled")
  26. class MinerUWorkerAPI(ls.LitAPI):
  27. """
  28. LitServe API Worker
  29. 从 SQLite 队列拉取任务,利用 LitServe 的自动 GPU 负载均衡
  30. 支持两种解析方式:
  31. - PDF/图片 -> MinerU 解析(GPU 加速)
  32. - 其他所有格式 -> MarkItDown 解析(快速处理)
  33. """
  34. # 支持的文件格式定义
  35. # MinerU 专用格式:PDF 和图片
  36. PDF_IMAGE_FORMATS = {'.pdf', '.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif', '.webp'}
  37. # 其他所有格式都使用 MarkItDown 解析
  38. def __init__(self, output_dir='/tmp/mineru_tianshu_output', worker_id_prefix='tianshu'):
  39. super().__init__()
  40. self.output_dir = Path(output_dir)
  41. self.output_dir.mkdir(parents=True, exist_ok=True)
  42. self.worker_id_prefix = worker_id_prefix
  43. self.db = TaskDB()
  44. self.worker_id = None
  45. self.markitdown = None
  46. def setup(self, device):
  47. """
  48. 初始化环境(每个 worker 进程调用一次)
  49. Args:
  50. device: LitServe 分配的设备 (cuda:0, cuda:1, etc.)
  51. """
  52. # 生成唯一的 worker_id
  53. import socket
  54. hostname = socket.gethostname()
  55. pid = os.getpid()
  56. self.worker_id = f"{self.worker_id_prefix}-{hostname}-{device}-{pid}"
  57. logger.info(f"⚙️ Worker {self.worker_id} setting up on device: {device}")
  58. # 配置 MinerU 环境
  59. if os.getenv('MINERU_DEVICE_MODE', None) is None:
  60. os.environ['MINERU_DEVICE_MODE'] = device if device != 'auto' else get_device()
  61. device_mode = os.environ['MINERU_DEVICE_MODE']
  62. # 配置显存
  63. if os.getenv('MINERU_VIRTUAL_VRAM_SIZE', None) is None:
  64. if device_mode.startswith("cuda") or device_mode.startswith("npu"):
  65. try:
  66. vram = round(get_vram(device_mode))
  67. os.environ['MINERU_VIRTUAL_VRAM_SIZE'] = str(vram)
  68. except:
  69. os.environ['MINERU_VIRTUAL_VRAM_SIZE'] = '8' # 默认值
  70. else:
  71. os.environ['MINERU_VIRTUAL_VRAM_SIZE'] = '1'
  72. # 初始化 MarkItDown(如果可用)
  73. if MARKITDOWN_AVAILABLE:
  74. self.markitdown = MarkItDown()
  75. logger.info(f"✅ MarkItDown initialized for Office format parsing")
  76. logger.info(f"✅ Worker {self.worker_id} ready")
  77. logger.info(f" Device: {device_mode}")
  78. logger.info(f" VRAM: {os.environ['MINERU_VIRTUAL_VRAM_SIZE']}GB")
  79. def decode_request(self, request):
  80. """
  81. 解码请求
  82. 接收一个 'poll' 信号来触发从数据库拉取任务
  83. """
  84. return request.get('action', 'poll')
  85. def _get_file_type(self, file_path: str) -> str:
  86. """
  87. 判断文件类型
  88. Args:
  89. file_path: 文件路径
  90. Returns:
  91. 'pdf_image': PDF 或图片格式,使用 MinerU 解析
  92. 'markitdown': 其他所有格式,使用 markitdown 解析
  93. """
  94. suffix = Path(file_path).suffix.lower()
  95. if suffix in self.PDF_IMAGE_FORMATS:
  96. return 'pdf_image'
  97. else:
  98. # 所有非 PDF/图片格式都使用 markitdown
  99. return 'markitdown'
  100. def _parse_with_mineru(self, file_path: Path, file_name: str, task_id: str,
  101. backend: str, options: dict, output_path: Path):
  102. """
  103. 使用 MinerU 解析 PDF 和图片格式
  104. Args:
  105. file_path: 文件路径
  106. file_name: 文件名
  107. task_id: 任务ID
  108. backend: 后端类型
  109. options: 解析选项
  110. output_path: 输出路径
  111. """
  112. logger.info(f"📄 Using MinerU to parse: {file_name}")
  113. # 读取文件
  114. pdf_bytes = read_fn(file_path)
  115. # 执行解析
  116. do_parse(
  117. output_dir=str(output_path),
  118. pdf_file_names=[Path(file_name).stem],
  119. pdf_bytes_list=[pdf_bytes],
  120. p_lang_list=[options.get('lang', 'ch')],
  121. backend=backend,
  122. parse_method=options.get('method', 'auto'),
  123. formula_enable=options.get('formula_enable', True),
  124. table_enable=options.get('table_enable', True),
  125. )
  126. def _parse_with_markitdown(self, file_path: Path, file_name: str,
  127. output_path: Path):
  128. """
  129. 使用 markitdown 解析文档(支持 Office、HTML、文本等多种格式)
  130. Args:
  131. file_path: 文件路径
  132. file_name: 文件名
  133. output_path: 输出路径
  134. """
  135. if not MARKITDOWN_AVAILABLE or self.markitdown is None:
  136. raise RuntimeError("markitdown is not available. Please install it: pip install markitdown")
  137. logger.info(f"📊 Using MarkItDown to parse: {file_name}")
  138. # 使用 markitdown 转换文档
  139. result = self.markitdown.convert(str(file_path))
  140. # 保存为 markdown 文件
  141. output_file = output_path / f"{Path(file_name).stem}.md"
  142. output_file.write_text(result.text_content, encoding='utf-8')
  143. logger.info(f"📝 Markdown saved to: {output_file}")
  144. def predict(self, action):
  145. """
  146. 从数据库拉取任务并处理
  147. 这里是实际的任务处理逻辑,LitServe 会自动管理 GPU 负载均衡
  148. 支持根据文件类型选择不同的解析器:
  149. - PDF/图片 -> MinerU(GPU 加速)
  150. - 其他所有格式 -> MarkItDown(快速处理)
  151. """
  152. if action != 'poll':
  153. return {
  154. 'status': 'error',
  155. 'message': 'Invalid action. Use {"action": "poll"} to trigger task processing.'
  156. }
  157. # 从数据库获取任务
  158. task = self.db.get_next_task(self.worker_id)
  159. if not task:
  160. # 没有任务时返回空闲状态
  161. return {
  162. 'status': 'idle',
  163. 'message': 'No pending tasks in queue',
  164. 'worker_id': self.worker_id
  165. }
  166. # 提取任务信息
  167. task_id = task['task_id']
  168. file_path = task['file_path']
  169. file_name = task['file_name']
  170. backend = task['backend']
  171. options = json.loads(task['options'])
  172. logger.info(f"🔄 Worker {self.worker_id} processing task {task_id}: {file_name}")
  173. try:
  174. # 准备输出目录
  175. output_path = self.output_dir / task_id
  176. output_path.mkdir(parents=True, exist_ok=True)
  177. # 判断文件类型并选择解析方式
  178. file_type = self._get_file_type(file_path)
  179. if file_type == 'pdf_image':
  180. # 使用 MinerU 解析 PDF 和图片
  181. self._parse_with_mineru(
  182. file_path=Path(file_path),
  183. file_name=file_name,
  184. task_id=task_id,
  185. backend=backend,
  186. options=options,
  187. output_path=output_path
  188. )
  189. parse_method = 'MinerU'
  190. else: # file_type == 'markitdown'
  191. # 使用 markitdown 解析所有其他格式
  192. self._parse_with_markitdown(
  193. file_path=Path(file_path),
  194. file_name=file_name,
  195. output_path=output_path
  196. )
  197. parse_method = 'MarkItDown'
  198. # 更新状态为成功
  199. self.db.update_task_status(task_id, 'completed', str(output_path))
  200. logger.info(f"✅ Task {task_id} completed by {self.worker_id}")
  201. logger.info(f" Parser: {parse_method}")
  202. logger.info(f" Output: {output_path}")
  203. return {
  204. 'status': 'completed',
  205. 'task_id': task_id,
  206. 'file_name': file_name,
  207. 'parse_method': parse_method,
  208. 'file_type': file_type,
  209. 'output_path': str(output_path),
  210. 'worker_id': self.worker_id
  211. }
  212. except Exception as e:
  213. logger.error(f"❌ Task {task_id} failed: {e}")
  214. self.db.update_task_status(task_id, 'failed', error_message=str(e))
  215. return {
  216. 'status': 'failed',
  217. 'task_id': task_id,
  218. 'error': str(e),
  219. 'worker_id': self.worker_id
  220. }
  221. finally:
  222. # 清理临时文件
  223. try:
  224. if Path(file_path).exists():
  225. Path(file_path).unlink()
  226. except Exception as e:
  227. logger.warning(f"Failed to clean up temp file {file_path}: {e}")
  228. def encode_response(self, response):
  229. """编码响应"""
  230. return response
  231. def start_litserve_workers(
  232. output_dir='/tmp/mineru_tianshu_output',
  233. accelerator='auto',
  234. devices='auto',
  235. workers_per_device=1,
  236. port=9000
  237. ):
  238. """
  239. 启动 LitServe Worker Pool
  240. Args:
  241. output_dir: 输出目录
  242. accelerator: 加速器类型 (auto/cuda/cpu/mps)
  243. devices: 使用的设备 (auto/[0,1,2])
  244. workers_per_device: 每个 GPU 的 worker 数量
  245. port: 服务端口
  246. """
  247. logger.info("=" * 60)
  248. logger.info("🚀 Starting MinerU Tianshu LitServe Worker Pool")
  249. logger.info("=" * 60)
  250. logger.info(f"📂 Output Directory: {output_dir}")
  251. logger.info(f"🎮 Accelerator: {accelerator}")
  252. logger.info(f"💾 Devices: {devices}")
  253. logger.info(f"👷 Workers per Device: {workers_per_device}")
  254. logger.info(f"🔌 Port: {port}")
  255. logger.info("=" * 60)
  256. # 创建 LitServe 服务器
  257. api = MinerUWorkerAPI(output_dir=output_dir)
  258. server = ls.LitServer(
  259. api,
  260. accelerator=accelerator,
  261. devices=devices,
  262. workers_per_device=workers_per_device,
  263. timeout=False, # 不设置超时
  264. )
  265. logger.info(f"✅ LitServe worker pool initialized")
  266. logger.info(f"📡 Listening on: http://0.0.0.0:{port}/predict")
  267. logger.info(f"🔄 Workers will poll SQLite queue for tasks")
  268. logger.info("=" * 60)
  269. # 启动服务器
  270. server.run(port=port, generate_client_file=False)
  271. if __name__ == '__main__':
  272. import argparse
  273. parser = argparse.ArgumentParser(description='MinerU Tianshu LitServe Worker Pool')
  274. parser.add_argument('--output-dir', type=str, default='/tmp/mineru_tianshu_output',
  275. help='Output directory for processed files')
  276. parser.add_argument('--accelerator', type=str, default='auto',
  277. choices=['auto', 'cuda', 'cpu', 'mps'],
  278. help='Accelerator type')
  279. parser.add_argument('--devices', type=str, default='auto',
  280. help='Devices to use (auto or comma-separated list like 0,1,2)')
  281. parser.add_argument('--workers-per-device', type=int, default=1,
  282. help='Number of workers per device')
  283. parser.add_argument('--port', type=int, default=9000,
  284. help='Server port')
  285. args = parser.parse_args()
  286. # 处理 devices 参数
  287. devices = args.devices
  288. if devices != 'auto':
  289. try:
  290. devices = [int(d) for d in devices.split(',')]
  291. except:
  292. logger.warning(f"Invalid devices format: {devices}, using 'auto'")
  293. devices = 'auto'
  294. start_litserve_workers(
  295. output_dir=args.output_dir,
  296. accelerator=args.accelerator,
  297. devices=devices,
  298. workers_per_device=args.workers_per_device,
  299. port=args.port
  300. )