doc_analyze_by_custom_model.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. import os
  2. import time
  3. import numpy as np
  4. import torch
  5. from mineru.backend.pipeline.model_init import MineruPipelineModel
  6. os.environ['FLAGS_npu_jit_compile'] = '0' # 关闭paddle的jit编译
  7. os.environ['FLAGS_use_stride_kernel'] = '0'
  8. os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # 让mps可以fallback
  9. os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
  10. from loguru import logger
  11. from ...utils.model_utils import get_vram, clean_memory
  12. from magic_pdf.libs.config_reader import (get_device, get_formula_config,
  13. get_layout_config,
  14. get_local_models_dir,
  15. get_table_recog_config)
  16. class ModelSingleton:
  17. _instance = None
  18. _models = {}
  19. def __new__(cls, *args, **kwargs):
  20. if cls._instance is None:
  21. cls._instance = super().__new__(cls)
  22. return cls._instance
  23. def get_model(
  24. self,
  25. lang=None,
  26. formula_enable=None,
  27. table_enable=None,
  28. ):
  29. key = (lang, formula_enable, table_enable)
  30. if key not in self._models:
  31. self._models[key] = custom_model_init(
  32. lang=lang,
  33. formula_enable=formula_enable,
  34. table_enable=table_enable,
  35. )
  36. return self._models[key]
  37. def custom_model_init(
  38. lang=None,
  39. formula_enable=None,
  40. table_enable=None,
  41. ):
  42. model_init_start = time.time()
  43. # 从配置文件读取model-dir和device
  44. local_models_dir = get_local_models_dir()
  45. device = get_device()
  46. formula_config = get_formula_config()
  47. if formula_enable is not None:
  48. formula_config['enable'] = formula_enable
  49. table_config = get_table_recog_config()
  50. if table_enable is not None:
  51. table_config['enable'] = table_enable
  52. model_input = {
  53. 'models_dir': local_models_dir,
  54. 'device': device,
  55. 'table_config': table_config,
  56. 'formula_config': formula_config,
  57. 'lang': lang,
  58. }
  59. custom_model = MineruPipelineModel(**model_input)
  60. model_init_cost = time.time() - model_init_start
  61. logger.info(f'model init cost: {model_init_cost}')
  62. return custom_model
  63. def doc_analyze(
  64. dataset: Dataset,
  65. ocr: bool = False,
  66. start_page_id=0,
  67. end_page_id=None,
  68. lang=None,
  69. formula_enable=None,
  70. table_enable=None,
  71. ):
  72. end_page_id = (
  73. end_page_id
  74. if end_page_id is not None and end_page_id >= 0
  75. else len(dataset) - 1
  76. )
  77. MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
  78. images = []
  79. page_wh_list = []
  80. for index in range(len(dataset)):
  81. if start_page_id <= index <= end_page_id:
  82. page_data = dataset.get_page(index)
  83. img_dict = page_data.get_image()
  84. images.append(img_dict['img'])
  85. page_wh_list.append((img_dict['width'], img_dict['height']))
  86. images_with_extra_info = [(images[index], ocr, dataset._lang) for index in range(len(images))]
  87. if len(images) >= MIN_BATCH_INFERENCE_SIZE:
  88. batch_size = MIN_BATCH_INFERENCE_SIZE
  89. batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
  90. else:
  91. batch_images = [images_with_extra_info]
  92. results = []
  93. processed_images_count = 0
  94. for index, batch_image in enumerate(batch_images):
  95. processed_images_count += len(batch_image)
  96. logger.info(f'Batch {index + 1}/{len(batch_images)}: {processed_images_count} pages/{len(images_with_extra_info)} pages')
  97. result = may_batch_image_analyze(batch_image, formula_enable, table_enable)
  98. results.extend(result)
  99. model_json = []
  100. for index in range(len(dataset)):
  101. if start_page_id <= index <= end_page_id:
  102. result = results.pop(0)
  103. page_width, page_height = page_wh_list.pop(0)
  104. else:
  105. result = []
  106. page_height = 0
  107. page_width = 0
  108. page_info = {'page_no': index, 'width': page_width, 'height': page_height}
  109. page_dict = {'layout_dets': result, 'page_info': page_info}
  110. model_json.append(page_dict)
  111. return model_json
  112. def batch_doc_analyze(
  113. datasets: list[Dataset],
  114. parse_method: str = 'auto',
  115. lang=None,
  116. formula_enable=None,
  117. table_enable=None,
  118. ):
  119. MIN_BATCH_INFERENCE_SIZE = int(os.environ.get('MINERU_MIN_BATCH_INFERENCE_SIZE', 100))
  120. batch_size = MIN_BATCH_INFERENCE_SIZE
  121. page_wh_list = []
  122. images_with_extra_info = []
  123. for dataset in datasets:
  124. ocr = False
  125. if parse_method == 'auto':
  126. if dataset.classify() == 'txt':
  127. ocr = False
  128. elif dataset.classify() == 'ocr':
  129. ocr = True
  130. elif parse_method == 'ocr':
  131. ocr = True
  132. elif parse_method == 'txt':
  133. ocr = False
  134. _lang = dataset._lang
  135. for index in range(len(dataset)):
  136. page_data = dataset.get_page(index)
  137. img_dict = page_data.get_image()
  138. page_wh_list.append((img_dict['width'], img_dict['height']))
  139. images_with_extra_info.append((img_dict['img'], ocr, _lang))
  140. batch_images = [images_with_extra_info[i:i+batch_size] for i in range(0, len(images_with_extra_info), batch_size)]
  141. results = []
  142. processed_images_count = 0
  143. for index, batch_image in enumerate(batch_images):
  144. processed_images_count += len(batch_image)
  145. logger.info(f'Batch {index + 1}/{len(batch_images)}: {processed_images_count} pages/{len(images_with_extra_info)} pages')
  146. result = may_batch_image_analyze(batch_image, formula_enable, table_enable)
  147. results.extend(result)
  148. infer_results = []
  149. for index in range(len(datasets)):
  150. dataset = datasets[index]
  151. model_json = []
  152. for i in range(len(dataset)):
  153. result = results.pop(0)
  154. page_width, page_height = page_wh_list.pop(0)
  155. page_info = {'page_no': i, 'width': page_width, 'height': page_height}
  156. page_dict = {'layout_dets': result, 'page_info': page_info}
  157. model_json.append(page_dict)
  158. infer_results.append(model_json)
  159. return infer_results
  160. def may_batch_image_analyze(
  161. images_with_extra_info: list[(np.ndarray, bool, str)],
  162. formula_enable=None,
  163. table_enable=None):
  164. # os.environ['CUDA_VISIBLE_DEVICES'] = str(idx)
  165. from .batch_analyze import BatchAnalyze
  166. model_manager = ModelSingleton()
  167. batch_ratio = 1
  168. device = get_device()
  169. if str(device).startswith('npu'):
  170. import torch_npu
  171. if torch_npu.npu.is_available():
  172. torch.npu.set_compile_mode(jit_compile=False)
  173. if str(device).startswith('npu') or str(device).startswith('cuda'):
  174. vram = get_vram(device)
  175. if vram is not None:
  176. gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(vram)))
  177. if gpu_memory >= 16:
  178. batch_ratio = 16
  179. elif gpu_memory >= 12:
  180. batch_ratio = 8
  181. elif gpu_memory >= 8:
  182. batch_ratio = 4
  183. elif gpu_memory >= 6:
  184. batch_ratio = 2
  185. else:
  186. batch_ratio = 1
  187. logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
  188. else:
  189. # Default batch_ratio when VRAM can't be determined
  190. batch_ratio = 1
  191. logger.info(f'Could not determine GPU memory, using default batch_ratio: {batch_ratio}')
  192. batch_model = BatchAnalyze(model_manager, batch_ratio, formula_enable, table_enable)
  193. results = batch_model(images_with_extra_info)
  194. clean_memory(get_device())
  195. return results