pipeline_analyze.py

import os
import time

import numpy as np
import torch
from loguru import logger
from pypdfium2 import PdfDocument

from mineru.backend.pipeline.model_init import MineruPipelineModel
from .model_json_to_middle_json import result_to_middle_json
from ...utils.pdf_classify import classify
from ...utils.pdf_image_tools import pdf_page_to_image
from ...utils.model_utils import get_vram, clean_memory
from magic_pdf.libs.config_reader import (get_device, get_formula_config,
                                          get_layout_config,
                                          get_local_models_dir,
                                          get_table_recog_config)

os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'  # allow MPS to fall back to CPU for unsupported ops
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable the albumentations update check


class ModelSingleton:
    _instance = None
    _models = {}

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def get_model(
        self,
        lang=None,
        formula_enable=None,
        table_enable=None,
    ):
        key = (lang, formula_enable, table_enable)
        if key not in self._models:
            self._models[key] = custom_model_init(
                lang=lang,
                formula_enable=formula_enable,
                table_enable=table_enable,
            )
        return self._models[key]
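
# Usage sketch (illustrative, not part of the pipeline): every ModelSingleton()
# call returns the same instance, and models are cached per
# (lang, formula_enable, table_enable) key, so repeated calls with the same
# settings reuse the already-loaded model. The 'en' value below is a placeholder.
#
#   manager_a = ModelSingleton()
#   manager_b = ModelSingleton()
#   assert manager_a is manager_b  # one shared instance per process
#   m1 = manager_a.get_model(lang='en', formula_enable=True, table_enable=True)
#   m2 = manager_b.get_model(lang='en', formula_enable=True, table_enable=True)
#   assert m1 is m2                # served from the _models cache
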

def custom_model_init(
    lang=None,
    formula_enable=None,
    table_enable=None,
):
    model_init_start = time.time()
    # Read the model directory and device from the config file
    local_models_dir = get_local_models_dir()
    device = get_device()
    formula_config = get_formula_config()
    if formula_enable is not None:
        formula_config['enable'] = formula_enable
    table_config = get_table_recog_config()
    if table_enable is not None:
        table_config['enable'] = table_enable
    model_input = {
        'models_dir': local_models_dir,
        'device': device,
        'table_config': table_config,
        'formula_config': formula_config,
        'lang': lang,
    }
    custom_model = MineruPipelineModel(**model_input)
    model_init_cost = time.time() - model_init_start
    logger.info(f'model init cost: {model_init_cost}')
    return custom_model
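
# Example (illustrative): build a model with table recognition turned off while
# formula recognition keeps whatever the config file specifies (None means
# "defer to config"). The 'en' language code is a placeholder.
#
#   model = custom_model_init(lang='en', table_enable=False)
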

def doc_analyze(
    pdf_bytes_list,
    lang_list,
    parse_method: str = 'auto',
    formula_enable=None,
    table_enable=None,
):
    """Analyze one or more PDF documents with the pipeline models.

    Args:
        pdf_bytes_list: raw PDF bytes, one entry per document
        lang_list: language code for each document, aligned with pdf_bytes_list
        parse_method: parsing method, 'auto'/'ocr'/'txt'
        formula_enable: whether to enable formula recognition (None = use config)
        table_enable: whether to enable table recognition (None = use config)

    Returns:
        A (middle_json_list, infer_results) tuple, each with one entry per document.
    """
    MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
    # Collect page info across all documents as tuples of
    # (pdf_index, page_index, pil_img, ocr_flag, lang, scale)
    all_pages_info = []
    for pdf_idx, pdf_bytes in enumerate(pdf_bytes_list):
        # Decide whether this document needs OCR
        _ocr = False
        if parse_method == 'auto':
            if classify(pdf_bytes) == 'ocr':
                _ocr = True
        elif parse_method == 'ocr':
            _ocr = True
        _lang = lang_list[pdf_idx]
        # Render every page of this document to an image
        pdf_doc = PdfDocument(pdf_bytes)
        for page_idx in range(len(pdf_doc)):
            page_data = pdf_doc[page_idx]
            img_dict = pdf_page_to_image(page_data)
            all_pages_info.append((
                pdf_idx, page_idx,
                img_dict['img_pil'], _ocr, _lang,
                img_dict['scale']
            ))
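
    # Illustrative contents of all_pages_info for two 2-page PDFs (the values
    # are placeholders; the real scale comes from pdf_page_to_image):
    #   (0, 0, <PIL.Image>, False, 'en', 2.0), (0, 1, <PIL.Image>, False, 'en', 2.0),
    #   (1, 0, <PIL.Image>, True, 'ch', 2.0), (1, 1, <PIL.Image>, True, 'ch', 2.0)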

    # Split the flat page list into inference batches
    images_with_extra_info = [(info[2], info[3], info[4]) for info in all_pages_info]
    batch_size = MIN_BATCH_INFERENCE_SIZE
    batch_images = [
        images_with_extra_info[i:i + batch_size]
        for i in range(0, len(images_with_extra_info), batch_size)
    ]
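
    # For example, 250 pages with the default batch_size of 100 produce three
    # batches of 100, 100, and 50 pages; the slice in the last chunk simply
    # runs short rather than raising.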

    # Run batched inference
    results = []
    processed_images_count = 0
    for index, batch_image in enumerate(batch_images):
        processed_images_count += len(batch_image)
        logger.info(
            f'Batch {index + 1}/{len(batch_images)}: '
            f'{processed_images_count} pages/{len(images_with_extra_info)} pages'
        )
        batch_results = may_batch_image_analyze(batch_image, formula_enable, table_enable)
        results.extend(batch_results)
    # Group the flat result list back by source document
    infer_results = [[] for _ in pdf_bytes_list]
    for i, page_info in enumerate(all_pages_info):
        pdf_idx, page_idx, pil_img, _, _, _ = page_info
        result = results[i]
        page_info_dict = {'page_no': page_idx, 'width': pil_img.width, 'height': pil_img.height}
        page_dict = {'layout_dets': result, 'page_info': page_info_dict}
        infer_results[pdf_idx].append(page_dict)
    middle_json_list = []
    for model_json in infer_results:
        middle_json = result_to_middle_json(model_json)
        middle_json_list.append(middle_json)
    return middle_json_list, infer_results
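
# Usage sketch (illustrative; file names and language codes are placeholders):
#
#   with open('a.pdf', 'rb') as fa, open('b.pdf', 'rb') as fb:
#       pdf_bytes_list = [fa.read(), fb.read()]
#   middle_jsons, model_jsons = doc_analyze(
#       pdf_bytes_list, lang_list=['en', 'ch'], parse_method='auto')
#   # model_jsons[0] holds one {'layout_dets', 'page_info'} dict per page of a.pdf
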

def may_batch_image_analyze(
        images_with_extra_info: list[tuple[np.ndarray, bool, str]],
        formula_enable=None,
        table_enable=None):
    # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
    from .batch_analyze import BatchAnalyze

    model_manager = ModelSingleton()
    batch_ratio = 1
    device = get_device()

    if str(device).startswith('npu'):
        import torch_npu
        if torch_npu.npu.is_available():
            torch.npu.set_compile_mode(jit_compile=False)
    if str(device).startswith('npu') or str(device).startswith('cuda'):
        vram = get_vram(device)
        if vram is not None:
            # Scale the batch size with available (or VIRTUAL_VRAM_SIZE-capped) VRAM
            gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(vram)))
            if gpu_memory >= 16:
                batch_ratio = 16
            elif gpu_memory >= 12:
                batch_ratio = 8
            elif gpu_memory >= 8:
                batch_ratio = 4
            elif gpu_memory >= 6:
                batch_ratio = 2
            else:
                batch_ratio = 1
            logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
        else:
            # Default batch_ratio when VRAM can't be determined
            batch_ratio = 1
            logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
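
    # Example (illustrative): a 24 GB GPU resolves to batch_ratio 16, while
    # exporting VIRTUAL_VRAM_SIZE=8 before launch would cap it at 4, which can
    # help when the card is shared with other processes.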
    batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable)
    results = batch_model(images_with_extra_info)
    clean_memory(get_device())
    return results