litserve_worker.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364
  1. """
  2. MinerU Tianshu - LitServe Worker
  3. 天枢 LitServe Worker
  4. 使用 LitServe 实现 GPU 资源的自动负载均衡
  5. 从 SQLite 队列拉取任务并处理
  6. """
  7. import os
  8. import json
  9. import sys
  10. from pathlib import Path
  11. import litserve as ls
  12. from loguru import logger
  13. from typing import Optional
  14. # 添加父目录到路径以导入 MinerU
  15. sys.path.insert(0, str(Path(__file__).parent.parent.parent))
  16. from task_db import TaskDB
  17. from mineru.cli.common import do_parse, read_fn
  18. from mineru.utils.config_reader import get_device
  19. from mineru.utils.model_utils import get_vram
  20. # 尝试导入 markitdown
  21. try:
  22. from markitdown import MarkItDown
  23. MARKITDOWN_AVAILABLE = True
  24. except ImportError:
  25. MARKITDOWN_AVAILABLE = False
  26. logger.warning("⚠️ markitdown not available, Office format parsing will be disabled")
  27. class MinerUWorkerAPI(ls.LitAPI):
  28. """
  29. LitServe API Worker
  30. 从 SQLite 队列拉取任务,利用 LitServe 的自动 GPU 负载均衡
  31. 支持两种解析方式:
  32. - PDF/图片 -> MinerU 解析(GPU 加速)
  33. - 其他所有格式 -> MarkItDown 解析(快速处理)
  34. """
  35. # 支持的文件格式定义
  36. # MinerU 专用格式:PDF 和图片
  37. PDF_IMAGE_FORMATS = {'.pdf', '.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif', '.webp'}
  38. # 其他所有格式都使用 MarkItDown 解析
  39. def __init__(self, output_dir='/tmp/mineru_tianshu_output', worker_id_prefix='tianshu'):
  40. super().__init__()
  41. self.output_dir = Path(output_dir)
  42. self.output_dir.mkdir(parents=True, exist_ok=True)
  43. self.worker_id_prefix = worker_id_prefix
  44. self.db = TaskDB()
  45. self.worker_id = None
  46. self.markitdown = None
  47. def setup(self, device):
  48. """
  49. 初始化环境(每个 worker 进程调用一次)
  50. Args:
  51. device: LitServe 分配的设备 (cuda:0, cuda:1, etc.)
  52. """
  53. # 生成唯一的 worker_id
  54. import socket
  55. hostname = socket.gethostname()
  56. pid = os.getpid()
  57. self.worker_id = f"{self.worker_id_prefix}-{hostname}-{device}-{pid}"
  58. logger.info(f"⚙️ Worker {self.worker_id} setting up on device: {device}")
  59. # 配置 MinerU 环境
  60. if os.getenv('MINERU_DEVICE_MODE', None) is None:
  61. os.environ['MINERU_DEVICE_MODE'] = device if device != 'auto' else get_device()
  62. device_mode = os.environ['MINERU_DEVICE_MODE']
  63. # 配置显存
  64. if os.getenv('MINERU_VIRTUAL_VRAM_SIZE', None) is None:
  65. if device_mode.startswith("cuda") or device_mode.startswith("npu"):
  66. try:
  67. vram = round(get_vram(device_mode))
  68. os.environ['MINERU_VIRTUAL_VRAM_SIZE'] = str(vram)
  69. except:
  70. os.environ['MINERU_VIRTUAL_VRAM_SIZE'] = '8' # 默认值
  71. else:
  72. os.environ['MINERU_VIRTUAL_VRAM_SIZE'] = '1'
  73. # 初始化 MarkItDown(如果可用)
  74. if MARKITDOWN_AVAILABLE:
  75. self.markitdown = MarkItDown()
  76. logger.info(f"✅ MarkItDown initialized for Office format parsing")
  77. logger.info(f"✅ Worker {self.worker_id} ready")
  78. logger.info(f" Device: {device_mode}")
  79. logger.info(f" VRAM: {os.environ['MINERU_VIRTUAL_VRAM_SIZE']}GB")
  80. def decode_request(self, request):
  81. """
  82. 解码请求
  83. 接收一个 'poll' 信号来触发从数据库拉取任务
  84. """
  85. return request.get('action', 'poll')
  86. def _get_file_type(self, file_path: str) -> str:
  87. """
  88. 判断文件类型
  89. Args:
  90. file_path: 文件路径
  91. Returns:
  92. 'pdf_image': PDF 或图片格式,使用 MinerU 解析
  93. 'markitdown': 其他所有格式,使用 markitdown 解析
  94. """
  95. suffix = Path(file_path).suffix.lower()
  96. if suffix in self.PDF_IMAGE_FORMATS:
  97. return 'pdf_image'
  98. else:
  99. # 所有非 PDF/图片格式都使用 markitdown
  100. return 'markitdown'
  101. def _parse_with_mineru(self, file_path: Path, file_name: str, task_id: str,
  102. backend: str, options: dict, output_path: Path):
  103. """
  104. 使用 MinerU 解析 PDF 和图片格式
  105. Args:
  106. file_path: 文件路径
  107. file_name: 文件名
  108. task_id: 任务ID
  109. backend: 后端类型
  110. options: 解析选项
  111. output_path: 输出路径
  112. """
  113. logger.info(f"📄 Using MinerU to parse: {file_name}")
  114. # 读取文件
  115. pdf_bytes = read_fn(file_path)
  116. # 执行解析
  117. do_parse(
  118. output_dir=str(output_path),
  119. pdf_file_names=[Path(file_name).stem],
  120. pdf_bytes_list=[pdf_bytes],
  121. p_lang_list=[options.get('lang', 'ch')],
  122. backend=backend,
  123. parse_method=options.get('method', 'auto'),
  124. formula_enable=options.get('formula_enable', True),
  125. table_enable=options.get('table_enable', True),
  126. )
  127. def _parse_with_markitdown(self, file_path: Path, file_name: str,
  128. output_path: Path):
  129. """
  130. 使用 markitdown 解析文档(支持 Office、HTML、文本等多种格式)
  131. Args:
  132. file_path: 文件路径
  133. file_name: 文件名
  134. output_path: 输出路径
  135. """
  136. if not MARKITDOWN_AVAILABLE or self.markitdown is None:
  137. raise RuntimeError("markitdown is not available. Please install it: pip install markitdown")
  138. logger.info(f"📊 Using MarkItDown to parse: {file_name}")
  139. # 使用 markitdown 转换文档
  140. result = self.markitdown.convert(str(file_path))
  141. # 保存为 markdown 文件
  142. output_file = output_path / f"{Path(file_name).stem}.md"
  143. output_file.write_text(result.text_content, encoding='utf-8')
  144. logger.info(f"📝 Markdown saved to: {output_file}")
  145. def predict(self, action):
  146. """
  147. 从数据库拉取任务并处理
  148. 这里是实际的任务处理逻辑,LitServe 会自动管理 GPU 负载均衡
  149. 支持根据文件类型选择不同的解析器:
  150. - PDF/图片 -> MinerU(GPU 加速)
  151. - 其他所有格式 -> MarkItDown(快速处理)
  152. """
  153. if action != 'poll':
  154. return {
  155. 'status': 'error',
  156. 'message': 'Invalid action. Use {"action": "poll"} to trigger task processing.'
  157. }
  158. # 从数据库获取任务
  159. task = self.db.get_next_task(self.worker_id)
  160. if not task:
  161. # 没有任务时返回空闲状态
  162. return {
  163. 'status': 'idle',
  164. 'message': 'No pending tasks in queue',
  165. 'worker_id': self.worker_id
  166. }
  167. # 提取任务信息
  168. task_id = task['task_id']
  169. file_path = task['file_path']
  170. file_name = task['file_name']
  171. backend = task['backend']
  172. options = json.loads(task['options'])
  173. logger.info(f"🔄 Worker {self.worker_id} processing task {task_id}: {file_name}")
  174. try:
  175. # 准备输出目录
  176. output_path = self.output_dir / task_id
  177. output_path.mkdir(parents=True, exist_ok=True)
  178. # 判断文件类型并选择解析方式
  179. file_type = self._get_file_type(file_path)
  180. if file_type == 'pdf_image':
  181. # 使用 MinerU 解析 PDF 和图片
  182. self._parse_with_mineru(
  183. file_path=Path(file_path),
  184. file_name=file_name,
  185. task_id=task_id,
  186. backend=backend,
  187. options=options,
  188. output_path=output_path
  189. )
  190. parse_method = 'MinerU'
  191. else: # file_type == 'markitdown'
  192. # 使用 markitdown 解析所有其他格式
  193. self._parse_with_markitdown(
  194. file_path=Path(file_path),
  195. file_name=file_name,
  196. output_path=output_path
  197. )
  198. parse_method = 'MarkItDown'
  199. # 更新状态为成功
  200. self.db.update_task_status(task_id, 'completed', str(output_path))
  201. logger.info(f"✅ Task {task_id} completed by {self.worker_id}")
  202. logger.info(f" Parser: {parse_method}")
  203. logger.info(f" Output: {output_path}")
  204. return {
  205. 'status': 'completed',
  206. 'task_id': task_id,
  207. 'file_name': file_name,
  208. 'parse_method': parse_method,
  209. 'file_type': file_type,
  210. 'output_path': str(output_path),
  211. 'worker_id': self.worker_id
  212. }
  213. except Exception as e:
  214. logger.error(f"❌ Task {task_id} failed: {e}")
  215. self.db.update_task_status(task_id, 'failed', error_message=str(e))
  216. return {
  217. 'status': 'failed',
  218. 'task_id': task_id,
  219. 'error': str(e),
  220. 'worker_id': self.worker_id
  221. }
  222. finally:
  223. # 清理临时文件
  224. try:
  225. if Path(file_path).exists():
  226. Path(file_path).unlink()
  227. except Exception as e:
  228. logger.warning(f"Failed to clean up temp file {file_path}: {e}")
  229. def encode_response(self, response):
  230. """编码响应"""
  231. return response
  232. def start_litserve_workers(
  233. output_dir='/tmp/mineru_tianshu_output',
  234. accelerator='auto',
  235. devices='auto',
  236. workers_per_device=1,
  237. port=9000
  238. ):
  239. """
  240. 启动 LitServe Worker Pool
  241. Args:
  242. output_dir: 输出目录
  243. accelerator: 加速器类型 (auto/cuda/cpu/mps)
  244. devices: 使用的设备 (auto/[0,1,2])
  245. workers_per_device: 每个 GPU 的 worker 数量
  246. port: 服务端口
  247. """
  248. logger.info("=" * 60)
  249. logger.info("🚀 Starting MinerU Tianshu LitServe Worker Pool")
  250. logger.info("=" * 60)
  251. logger.info(f"📂 Output Directory: {output_dir}")
  252. logger.info(f"🎮 Accelerator: {accelerator}")
  253. logger.info(f"💾 Devices: {devices}")
  254. logger.info(f"👷 Workers per Device: {workers_per_device}")
  255. logger.info(f"🔌 Port: {port}")
  256. logger.info("=" * 60)
  257. # 创建 LitServe 服务器
  258. api = MinerUWorkerAPI(output_dir=output_dir)
  259. server = ls.LitServer(
  260. api,
  261. accelerator=accelerator,
  262. devices=devices,
  263. workers_per_device=workers_per_device,
  264. timeout=False, # 不设置超时
  265. )
  266. logger.info(f"✅ LitServe worker pool initialized")
  267. logger.info(f"📡 Listening on: http://0.0.0.0:{port}/predict")
  268. logger.info(f"🔄 Workers will poll SQLite queue for tasks")
  269. logger.info("=" * 60)
  270. # 启动服务器
  271. server.run(port=port, generate_client_file=False)
  272. if __name__ == '__main__':
  273. import argparse
  274. parser = argparse.ArgumentParser(description='MinerU Tianshu LitServe Worker Pool')
  275. parser.add_argument('--output-dir', type=str, default='/tmp/mineru_tianshu_output',
  276. help='Output directory for processed files')
  277. parser.add_argument('--accelerator', type=str, default='auto',
  278. choices=['auto', 'cuda', 'cpu', 'mps'],
  279. help='Accelerator type')
  280. parser.add_argument('--devices', type=str, default='auto',
  281. help='Devices to use (auto or comma-separated list like 0,1,2)')
  282. parser.add_argument('--workers-per-device', type=int, default=1,
  283. help='Number of workers per device')
  284. parser.add_argument('--port', type=int, default=9000,
  285. help='Server port')
  286. args = parser.parse_args()
  287. # 处理 devices 参数
  288. devices = args.devices
  289. if devices != 'auto':
  290. try:
  291. devices = [int(d) for d in devices.split(',')]
  292. except:
  293. logger.warning(f"Invalid devices format: {devices}, using 'auto'")
  294. devices = 'auto'
  295. start_litserve_workers(
  296. output_dir=args.output_dir,
  297. accelerator=args.accelerator,
  298. devices=devices,
  299. workers_per_device=args.workers_per_device,
  300. port=args.port
  301. )