pipeline_analyze.py

import os
import time
from typing import List, Tuple

from PIL import Image
from loguru import logger

from .model_init import MineruPipelineModel
from mineru.utils.config_reader import get_device
from ...utils.enum_class import ImageType
from ...utils.pdf_classify import classify
from ...utils.pdf_image_tools import load_images_from_pdf
from ...utils.model_utils import get_vram, clean_memory

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # let MPS fall back to CPU for unsupported ops
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates


class ModelSingleton:
    """Process-wide cache of MineruPipelineModel instances, keyed by (lang, formula_enable, table_enable)."""
    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        lang=None,
        formula_enable=None,
        table_enable=None,
    ):
        key = (lang, formula_enable, table_enable)
        if key not in self._models:
            self._models[key] = custom_model_init(
                lang=lang,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
        return self._models[key]
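
# Usage sketch (illustrative, not called anywhere in this module): repeated
# lookups with the same (lang, formula_enable, table_enable) key reuse one
# cached MineruPipelineModel instead of re-initializing it.
#
#     manager = ModelSingleton()
#     model_a = manager.get_model(lang='en', formula_enable=True, table_enable=True)
#     model_b = manager.get_model(lang='en', formula_enable=True, table_enable=True)
#     assert model_a is model_b  # same cached instance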


def custom_model_init(
    lang=None,
    formula_enable=True,
    table_enable=True,
):
    model_init_start = time.time()
    # Read model-dir and device from the config file
    device = get_device()
    formula_config = {"enable": formula_enable}
    table_config = {"enable": table_enable}
    model_input = {
        'device': device,
        'table_config': table_config,
        'formula_config': formula_config,
        'lang': lang,
    }
    custom_model = MineruPipelineModel(**model_input)
    model_init_cost = time.time() - model_init_start
    logger.info(f'model init cost: {model_init_cost}')
    return custom_model


def doc_analyze(
    pdf_bytes_list,
    lang_list,
    parse_method: str = 'auto',
    formula_enable=True,
    table_enable=True,
):
    """
    Increasing MIN_BATCH_INFERENCE_SIZE can improve performance, but a larger
    MIN_BATCH_INFERENCE_SIZE consumes more memory. It can be set via the
    environment variable MINERU_MIN_BATCH_INFERENCE_SIZE; the default is 384.
    """
    min_batch_inference_size = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 384))
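    # Illustrative example: launching with MINERU_MIN_BATCH_INFERENCE_SIZE=512
    # trades memory for throughput (512 is an example value, not a recommendation).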
    # Collect page info across all PDFs
    all_pages_info = []  # stores (pdf_idx, page_idx, img_pil, ocr_enable, lang)
    all_image_lists = []
    all_pdf_docs = []
    ocr_enabled_list = []
    for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
        # Decide whether OCR is needed for this PDF
        _ocr_enable = False
        if parse_method == 'auto':
            if classify(pdf_bytes) == 'ocr':
                _ocr_enable = True
        elif parse_method == 'ocr':
            _ocr_enable = True
        ocr_enabled_list.append(_ocr_enable)
        _lang = lang_list[pdf_idx]
        # Collect the pages of this PDF
        images_list, pdf_doc = load_images_from_pdf(pdf_bytes, image_type=ImageType.PIL)
        all_image_lists.append(images_list)
        all_pdf_docs.append(pdf_doc)
        for page_idx in range(len(images_list)):
            img_dict = images_list[page_idx]
            all_pages_info.append((
                pdf_idx, page_idx,
                img_dict['img_pil'], _ocr_enable, _lang,
            ))
    # Prepare batches
    images_with_extra_info = [(info[2], info[3], info[4]) for info in all_pages_info]
    batch_size = min_batch_inference_size
    batch_images = [
        images_with_extra_info[i:i + batch_size]
        for i in range(0, len(images_with_extra_info), batch_size)
    ]
    # Run batched inference
    results = []
    processed_images_count = 0
    for index, batch_image in enumerate(batch_images):
        processed_images_count += len(batch_image)
        logger.info(
            f'Batch {index + 1}/{len(batch_images)}: '
            f'{processed_images_count} pages/{len(images_with_extra_info)} pages'
        )
        batch_results = batch_image_analyze(batch_image, formula_enable, table_enable)
        results.extend(batch_results)
    # Assemble per-PDF results
    infer_results = [[] for _ in range(len(pdf_bytes_list))]
    for i, page_info in enumerate(all_pages_info):
        pdf_idx, page_idx, pil_img, _, _ = page_info
        result = results[i]
        page_info_dict = {'page_no': page_idx, 'width': pil_img.width, 'height': pil_img.height}
        page_dict = {'layout_dets': result, 'page_info': page_info_dict}
        infer_results[pdf_idx].append(page_dict)
    return infer_results, all_image_lists, all_pdf_docs, lang_list, ocr_enabled_list


def batch_image_analyze(
        images_with_extra_info: List[Tuple[Image.Image, bool, str]],
        formula_enable=True,
        table_enable=True):
    from .batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()
    batch_ratio = 1
    device = get_device()

    if str(device).startswith('npu'):
        try:
            import torch_npu
            if torch_npu.npu.is_available():
                torch_npu.npu.set_compile_mode(jit_compile=False)
        except Exception as e:
            raise RuntimeError(
                "NPU is selected as device, but torch_npu is not available. "
                "Please ensure that the torch_npu package is installed correctly."
            ) from e
    if str(device).startswith('npu') or str(device).startswith('cuda'):
        vram = get_vram(device)
        if vram is not None:
            gpu_memory = int(os.getenv('MINERU_VIRTUAL_VRAM_SIZE', round(vram)))
            if gpu_memory >= 16:
                batch_ratio = 16
            elif gpu_memory >= 12:
                batch_ratio = 8
            elif gpu_memory >= 8:
                batch_ratio = 4
            elif gpu_memory >= 6:
                batch_ratio = 2
            else:
                batch_ratio = 1
            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
        else:
            # Default batch_ratio when VRAM can't be determined
            batch_ratio = 1
            logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
    # Check the torch version: batched OCR detection is disabled on
    # torch >= 2.8.0 and on MPS devices
    import torch
    from packaging import version
    if version.parse(torch.__version__) >= version.parse("2.8.0") or str(device).startswith('mps'):
        enable_ocr_det_batch = False
    else:
        enable_ocr_det_batch = True

    batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable, enable_ocr_det_batch)
    results = batch_model(images_with_extra_info)

    clean_memory(get_device())

    return results
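

# Minimal usage sketch (an illustrative addition, not part of the pipeline):
# feed doc_analyze() raw PDF bytes plus one language code per PDF and unpack
# its 5-tuple. The file path is a placeholder; 'ch' is just an example code.
if __name__ == '__main__':
    with open('example.pdf', 'rb') as f:  # placeholder path
        pdf_bytes = f.read()
    infer_results, image_lists, pdf_docs, langs, ocr_flags = doc_analyze(
        [pdf_bytes], ['ch'], parse_method='auto',
    )
    # One list of page dicts per input PDF; each page dict holds
    # 'layout_dets' (model output) and 'page_info' (page_no, width, height).
    logger.info(f'pages analyzed: {len(infer_results[0])}')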