pdf_extract_kit.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228
  1. import numpy as np
  2. import torch
  3. from loguru import logger
  4. import os
  5. import time
  6. import cv2
  7. import yaml
  8. from PIL import Image
  # Environment switches must be set before the corresponding libraries are imported.
  os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates
  os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger

  # torchtext is optional; when it is installed, silence its deprecation warning if the
  # installed version provides the helper for that (torchtext 0.18+).
  # Feature-detect the helper instead of comparing version strings: a lexicographic
  # comparison gets e.g. "0.9.0" >= "0.18.0" wrong and would call a missing function.
  try:
      import torchtext
  except ImportError:
      pass
  else:
      _disable_tt_warning = getattr(torchtext, "disable_torchtext_deprecation_warning", None)
      if callable(_disable_tt_warning):
          _disable_tt_warning()
  17. from magic_pdf.libs.Constants import *
  18. from magic_pdf.model.model_list import AtomicModel
  19. from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
  20. from magic_pdf.model.sub_modules.model_utils import get_res_list_from_layout_res, crop_img, clean_vram
  21. from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import get_adjusted_mfdetrec_res, get_ocr_result_list
  class CustomPEKModel:
      """Page-level document analysis pipeline built from magic_pdf sub-models.

      The constructor loads the layout model plus the optional formula (MFD + MFR),
      OCR and table models through ``AtomModelSingleton``. Calling the instance on a
      page image runs layout -> formula -> OCR -> table and returns the merged list
      of detection results.
      """

      def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
          """Read ``model_configs.yaml`` and initialise the enabled sub-models.

          Args:
              ocr: run OCR on the text regions returned by the layout model.
              show_log: forwarded to the OCR model as ``ocr_show_log``.
              **kwargs: optional settings read here:
                  ``layout_config`` / ``formula_config`` / ``table_config``: dicts of
                  per-stage options. A missing or ``None`` value is treated as an empty
                  dict, so every per-key default below applies.
                  ``lang``: OCR language (default ``None``).
                  ``device``: device string passed to the models (default ``"cpu"``).
                  ``models_dir``: weights root (default ``<magic_pdf>/resources/models``).
          """
          # This file lives in magic_pdf/model; root_dir is the magic_pdf package directory.
          current_dir = os.path.dirname(os.path.abspath(__file__))
          root_dir = os.path.dirname(current_dir)
          model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
          config_path = os.path.join(model_config_dir, 'model_configs.yaml')
          # NOTE(review): bundled resource, not user input; yaml.safe_load would be the
          # stricter loader if the file only holds plain mappings.
          with open(config_path, "r", encoding='utf-8') as f:
              self.configs = yaml.load(f, Loader=yaml.FullLoader)

          # Stage configuration. `or {}` lets callers omit a config (or pass None).
          # Previously that raised AttributeError on `.get`.
          self.layout_config = kwargs.get("layout_config") or {}
          self.layout_model_name = self.layout_config.get("model", MODEL_NAME.DocLayout_YOLO)

          self.formula_config = kwargs.get("formula_config") or {}
          self.mfd_model_name = self.formula_config.get("mfd_model", MODEL_NAME.YOLO_V8_MFD)
          self.mfr_model_name = self.formula_config.get("mfr_model", MODEL_NAME.UniMerNet_v2_Small)
          self.apply_formula = self.formula_config.get("enable", True)

          self.table_config = kwargs.get("table_config") or {}
          self.apply_table = self.table_config.get("enable", False)
          self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
          self.table_model_name = self.table_config.get("model", MODEL_NAME.RAPID_TABLE)

          self.apply_ocr = ocr
          self.lang = kwargs.get("lang", None)
          logger.info(
              "DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, "
              "apply_table: {}, table_model: {}, lang: {}".format(
                  self.layout_model_name, self.apply_formula, self.apply_ocr, self.apply_table, self.table_model_name,
                  self.lang
              )
          )

          self.device = kwargs.get("device", "cpu")
          logger.info("using device: {}".format(self.device))
          models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
          logger.info("using models_dir: {}".format(models_dir))
          atom_model_manager = AtomModelSingleton()

          if self.apply_formula:
              # Formula detection (MFD) model.
              self.mfd_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.MFD,
                  mfd_weights=str(os.path.join(models_dir, self.configs["weights"][self.mfd_model_name])),
                  device=self.device
              )
              # Formula recognition (MFR) model. Its config file is always UniMERNet/demo.yaml.
              mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"][self.mfr_model_name]))
              mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
              self.mfr_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.MFR,
                  mfr_weight_dir=mfr_weight_dir,
                  mfr_cfg_path=mfr_cfg_path,
                  device=self.device
              )

          # Layout model.
          if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
              self.layout_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.Layout,
                  layout_model_name=MODEL_NAME.LAYOUTLMv3,
                  layout_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
                  layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
                  device=self.device
              )
          elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
              self.layout_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.Layout,
                  layout_model_name=MODEL_NAME.DocLayout_YOLO,
                  doclayout_yolo_weights=str(os.path.join(models_dir, self.configs['weights'][self.layout_model_name])),
                  device=self.device
              )
          else:
              # No layout model is loaded, so __call__ produces no layout detections.
              logger.warning(f"unsupported layout model: {self.layout_model_name}, layout detection will be skipped")

          if self.apply_ocr:
              self.ocr_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.OCR,
                  ocr_show_log=show_log,
                  det_db_box_thresh=0.3,
                  lang=self.lang
              )

          if self.apply_table:
              table_model_dir = self.configs["weights"][self.table_model_name]
              self.table_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.Table,
                  table_model_name=self.table_model_name,
                  table_model_path=str(os.path.join(models_dir, table_model_dir)),
                  table_max_time=self.table_max_time,
                  device=self.device
              )
          logger.info('DocAnalysis init done!')

      def __call__(self, image):
          """Analyse one page image and return the merged detection list.

          Args:
              image: page image, passed to the layout/formula models and to
                  ``Image.fromarray``. It is later converted with COLOR_RGB2BGR for OCR,
                  so presumably an RGB numpy array (TODO confirm with callers).

          Returns:
              The layout detections, extended in place with formula and OCR results.
              Recognised table HTML is stored under ``"html"`` on the table entries.
          """
          page_start = time.time()
          layout_res = self._detect_layout(image)
          pil_img = Image.fromarray(image)
          if self.apply_formula:
              self._recognize_formulas(image, layout_res)
          # Split layout_res into OCR regions, table regions and formula regions.
          ocr_res_list, table_res_list, single_page_mfdetrec_res = get_res_list_from_layout_res(layout_res)
          if self.apply_ocr:
              self._run_ocr(pil_img, ocr_res_list, single_page_mfdetrec_res, layout_res)
          if self.apply_table:
              self._recognize_tables(pil_img, table_res_list)
          logger.info(f"-----page total time: {round(time.time() - page_start, 2)}-----")
          return layout_res

      def _detect_layout(self, image):
          """Run the configured layout model. Returns [] for an unsupported model name."""
          layout_start = time.time()
          layout_res = []
          if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
              layout_res = self.layout_model(image, ignore_catids=[])
          elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
              layout_res = self.layout_model.predict(image)
          layout_cost = round(time.time() - layout_start, 2)
          logger.info(f"layout detection time: {layout_cost}")
          return layout_res

      def _recognize_formulas(self, image, layout_res):
          """Detect (MFD) and recognise (MFR) formulas, appending the results to layout_res."""
          mfd_start = time.time()
          mfd_res = self.mfd_model.predict(image)
          logger.info(f"mfd time: {round(time.time() - mfd_start, 2)}")

          mfr_start = time.time()
          formula_list = self.mfr_model.predict(mfd_res, image)
          layout_res.extend(formula_list)
          mfr_cost = round(time.time() - mfr_start, 2)
          logger.info(f"formula nums: {len(formula_list)}, mfr time: {mfr_cost}")

          # Release GPU memory; the threshold's meaning is defined by clean_vram.
          clean_vram(self.device, vram_threshold=8)

      def _run_ocr(self, pil_img, ocr_res_list, single_page_mfdetrec_res, layout_res):
          """OCR every region in ocr_res_list and append the recognised lines to layout_res."""
          ocr_start = time.time()
          for res in ocr_res_list:
              new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
              # Formula boxes re-expressed relative to this crop (presumably so OCR can
              # account for formula regions; see get_adjusted_mfdetrec_res).
              adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
              # The OCR model is fed BGR, so convert from the PIL (RGB) crop.
              new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
              ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
              if ocr_res:
                  layout_res.extend(get_ocr_result_list(ocr_res, useful_list))
          ocr_cost = round(time.time() - ocr_start, 2)
          logger.info(f"ocr time: {ocr_cost}")

      def _recognize_tables(self, pil_img, table_res_list):
          """Recognise each table region and store valid HTML under res["html"]."""
          table_start = time.time()
          for res in table_res_list:
              new_image, _ = crop_img(res, pil_img)
              single_table_start_time = time.time()
              html_code = self._table_to_html(new_image)
              run_time = time.time() - single_table_start_time
              if run_time > self.table_max_time:
                  logger.warning(f"table recognition processing exceeds max time {self.table_max_time}s")
              # Only accept output that ends like a complete HTML document or table.
              if not html_code:
                  logger.warning("table recognition processing fails, not get html return")
              elif html_code.strip().endswith(('</html>', '</table>')):
                  res["html"] = html_code
              else:
                  logger.warning("table recognition processing fails, not found expected HTML table end")
          logger.info(f"table time: {round(time.time() - table_start, 2)}")

      def _table_to_html(self, table_image):
          """Run the configured table model on one crop. Returns HTML, or None if unavailable."""
          if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
              with torch.no_grad():
                  table_result = self.table_model.predict(table_image, "html")
              return table_result[0] if len(table_result) > 0 else None
          if self.table_model_name == MODEL_NAME.TABLE_MASTER:
              return self.table_model.img2html(table_image)
          if self.table_model_name == MODEL_NAME.RAPID_TABLE:
              # RAPID_TABLE also returns cell bboxes and elapsed time; only the HTML is used.
              html_code, _, _ = self.table_model.predict(table_image)
              return html_code
          return None