# flake8: noqa
import os
import time
import cv2
import torch
import yaml
from loguru import logger

# Must be set before albumentations is (transitively) imported:
# disables albumentations' version/update check on import.
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'

from magic_pdf.config.constants import *
from magic_pdf.model.model_list import AtomicModel
from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
from magic_pdf.model.sub_modules.model_utils import (
    clean_vram, crop_img, get_res_list_from_layout_res)
from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
    get_adjusted_mfdetrec_res, get_ocr_result_list)
  16. class CustomPEKModel:
  17. def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
  18. """
  19. ======== model init ========
  20. """
  21. # 获取当前文件(即 pdf_extract_kit.py)的绝对路径
  22. current_file_path = os.path.abspath(__file__)
  23. # 获取当前文件所在的目录(model)
  24. current_dir = os.path.dirname(current_file_path)
  25. # 上一级目录(magic_pdf)
  26. root_dir = os.path.dirname(current_dir)
  27. # model_config目录
  28. model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
  29. # 构建 model_configs.yaml 文件的完整路径
  30. config_path = os.path.join(model_config_dir, 'model_configs.yaml')
  31. with open(config_path, 'r', encoding='utf-8') as f:
  32. self.configs = yaml.load(f, Loader=yaml.FullLoader)
  33. # 初始化解析配置
  34. # layout config
  35. self.layout_config = kwargs.get('layout_config')
  36. self.layout_model_name = self.layout_config.get(
  37. 'model', MODEL_NAME.DocLayout_YOLO
  38. )
  39. # formula config
  40. self.formula_config = kwargs.get('formula_config')
  41. self.mfd_model_name = self.formula_config.get(
  42. 'mfd_model', MODEL_NAME.YOLO_V8_MFD
  43. )
  44. self.mfr_model_name = self.formula_config.get(
  45. 'mfr_model', MODEL_NAME.UniMerNet_v2_Small
  46. )
  47. self.apply_formula = self.formula_config.get('enable', True)
  48. # table config
  49. self.table_config = kwargs.get('table_config')
  50. self.apply_table = self.table_config.get('enable', False)
  51. self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
  52. self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
  53. self.table_sub_model_name = self.table_config.get('sub_model', None)
  54. # ocr config
  55. self.apply_ocr = ocr
  56. self.lang = kwargs.get('lang', None)
  57. logger.info(
  58. 'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
  59. 'apply_table: {}, table_model: {}, lang: {}'.format(
  60. self.layout_model_name,
  61. self.apply_formula,
  62. self.apply_ocr,
  63. self.apply_table,
  64. self.table_model_name,
  65. self.lang,
  66. )
  67. )
  68. # 初始化解析方案
  69. self.device = kwargs.get('device', 'cpu')
  70. logger.info('using device: {}'.format(self.device))
  71. models_dir = kwargs.get(
  72. 'models_dir', os.path.join(root_dir, 'resources', 'models')
  73. )
  74. logger.info('using models_dir: {}'.format(models_dir))
  75. atom_model_manager = AtomModelSingleton()
  76. # 初始化公式识别
  77. if self.apply_formula:
  78. # 初始化公式检测模型
  79. self.mfd_model = atom_model_manager.get_atom_model(
  80. atom_model_name=AtomicModel.MFD,
  81. mfd_weights=str(
  82. os.path.join(
  83. models_dir, self.configs['weights'][self.mfd_model_name]
  84. )
  85. ),
  86. device=self.device,
  87. )
  88. # 初始化公式解析模型
  89. mfr_weight_dir = str(
  90. os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
  91. )
  92. mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
  93. self.mfr_model = atom_model_manager.get_atom_model(
  94. atom_model_name=AtomicModel.MFR,
  95. mfr_weight_dir=mfr_weight_dir,
  96. mfr_cfg_path=mfr_cfg_path,
  97. device=self.device,
  98. )
  99. # 初始化layout模型
  100. if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
  101. self.layout_model = atom_model_manager.get_atom_model(
  102. atom_model_name=AtomicModel.Layout,
  103. layout_model_name=MODEL_NAME.LAYOUTLMv3,
  104. layout_weights=str(
  105. os.path.join(
  106. models_dir, self.configs['weights'][self.layout_model_name]
  107. )
  108. ),
  109. layout_config_file=str(
  110. os.path.join(
  111. model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
  112. )
  113. ),
  114. device='cpu' if str(self.device).startswith("mps") else self.device,
  115. )
  116. elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
  117. self.layout_model = atom_model_manager.get_atom_model(
  118. atom_model_name=AtomicModel.Layout,
  119. layout_model_name=MODEL_NAME.DocLayout_YOLO,
  120. doclayout_yolo_weights=str(
  121. os.path.join(
  122. models_dir, self.configs['weights'][self.layout_model_name]
  123. )
  124. ),
  125. device=self.device,
  126. )
  127. # 初始化ocr
  128. self.ocr_model = atom_model_manager.get_atom_model(
  129. atom_model_name=AtomicModel.OCR,
  130. ocr_show_log=show_log,
  131. det_db_box_thresh=0.3,
  132. lang=self.lang
  133. )
  134. # init table model
  135. if self.apply_table:
  136. table_model_dir = self.configs['weights'][self.table_model_name]
  137. self.table_model = atom_model_manager.get_atom_model(
  138. atom_model_name=AtomicModel.Table,
  139. table_model_name=self.table_model_name,
  140. table_model_path=str(os.path.join(models_dir, table_model_dir)),
  141. table_max_time=self.table_max_time,
  142. device=self.device,
  143. ocr_engine=self.ocr_model,
  144. table_sub_model_name=self.table_sub_model_name
  145. )
  146. logger.info('DocAnalysis init done!')
  147. def __call__(self, image):
  148. # layout检测
  149. layout_start = time.time()
  150. layout_res = []
  151. if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
  152. # layoutlmv3
  153. layout_res = self.layout_model(image, ignore_catids=[])
  154. elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
  155. layout_res = self.layout_model.predict(image)
  156. layout_cost = round(time.time() - layout_start, 2)
  157. logger.info(f'layout detection time: {layout_cost}')
  158. if self.apply_formula:
  159. # 公式检测
  160. mfd_start = time.time()
  161. mfd_res = self.mfd_model.predict(image)
  162. logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')
  163. # 公式识别
  164. mfr_start = time.time()
  165. formula_list = self.mfr_model.predict(mfd_res, image)
  166. layout_res.extend(formula_list)
  167. mfr_cost = round(time.time() - mfr_start, 2)
  168. logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
  169. # 清理显存
  170. clean_vram(self.device, vram_threshold=6)
  171. # 从layout_res中获取ocr区域、表格区域、公式区域
  172. ocr_res_list, table_res_list, single_page_mfdetrec_res = (
  173. get_res_list_from_layout_res(layout_res)
  174. )
  175. # ocr识别
  176. ocr_start = time.time()
  177. # Process each area that requires OCR processing
  178. for res in ocr_res_list:
  179. new_image, useful_list = crop_img(res, image, crop_paste_x=50, crop_paste_y=50)
  180. adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
  181. # OCR recognition
  182. new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
  183. if self.apply_ocr:
  184. ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
  185. else:
  186. ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
  187. # Integration results
  188. if ocr_res:
  189. ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
  190. layout_res.extend(ocr_result_list)
  191. ocr_cost = round(time.time() - ocr_start, 2)
  192. if self.apply_ocr:
  193. logger.info(f"ocr time: {ocr_cost}")
  194. else:
  195. logger.info(f"det time: {ocr_cost}")
  196. # 表格识别 table recognition
  197. if self.apply_table:
  198. table_start = time.time()
  199. for res in table_res_list:
  200. new_image, _ = crop_img(res, image)
  201. single_table_start_time = time.time()
  202. html_code = None
  203. if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
  204. with torch.no_grad():
  205. table_result = self.table_model.predict(new_image, 'html')
  206. if len(table_result) > 0:
  207. html_code = table_result[0]
  208. elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
  209. html_code = self.table_model.img2html(new_image)
  210. elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
  211. html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
  212. new_image
  213. )
  214. run_time = time.time() - single_table_start_time
  215. if run_time > self.table_max_time:
  216. logger.warning(
  217. f'table recognition processing exceeds max time {self.table_max_time}s'
  218. )
  219. # 判断是否返回正常
  220. if html_code:
  221. expected_ending = html_code.strip().endswith(
  222. '</html>'
  223. ) or html_code.strip().endswith('</table>')
  224. if expected_ending:
  225. res['html'] = html_code
  226. else:
  227. logger.warning(
  228. 'table recognition processing fails, not found expected HTML table end'
  229. )
  230. else:
  231. logger.warning(
  232. 'table recognition processing fails, not get html return'
  233. )
  234. logger.info(f'table time: {round(time.time() - table_start, 2)}')
  235. return layout_res