pdf_extract_kit.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. # flake8: noqa
  2. import os
  3. import time
  4. import cv2
  5. import torch
  6. import yaml
  7. from loguru import logger
  os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # disable albumentations' update check


  def parse_version_tuple(version):
      """Return the leading numeric release segment of *version* as an int tuple.

      Comparing version strings directly is lexicographic and wrong
      (``'0.6.0' >= '0.18.0'`` is True); comparing these tuples is numeric.
      Pre-release/local suffixes after the numeric run are ignored:
      ``'0.18.0+cpu'`` -> ``(0, 18, 0)``, ``'0.6.0a0'`` -> ``(0, 6, 0)``.
      A string with no leading digits yields ``()``.
      """
      import re

      match = re.match(r'\d+(?:\.\d+)*', str(version))
      if match is None:
          return ()
      return tuple(int(part) for part in match.group(0).split('.'))


  # torchtext is optional; when present (>= 0.18.0, where the helper exists)
  # silence its deprecation warning.
  try:
      import torchtext
      if parse_version_tuple(torchtext.__version__) >= (0, 18, 0):
          torchtext.disable_torchtext_deprecation_warning()
  except ImportError:
      pass
  15. from magic_pdf.config.constants import *
  16. from magic_pdf.model.model_list import AtomicModel
  17. from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
  18. from magic_pdf.model.sub_modules.model_utils import (
  19. clean_vram, crop_img, get_res_list_from_layout_res)
  20. from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
  21. get_adjusted_mfdetrec_res, get_ocr_result_list)
  class CustomPEKModel:
      """Single-page document analysis pipeline built from MinerU atomic models.

      Calling an instance on a page image runs, in order: layout detection,
      formula detection + recognition (when enabled), OCR over the text
      regions (text recognition only when ``ocr=True``; otherwise detection
      only), and table recognition (when enabled). All results are merged into
      the list returned by the layout model.
      """

      def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
          """Load ``model_configs.yaml`` and initialise every sub-model.

          Args:
              ocr: run text recognition as well as detection on text regions.
              show_log: forwarded to the OCR model as ``ocr_show_log``.
              **kwargs: recognised keys:
                  ``layout_config`` (dict, key ``model``),
                  ``formula_config`` (dict, keys ``mfd_model``, ``mfr_model``,
                  ``enable``), ``table_config`` (dict, keys ``enable``,
                  ``max_time``, ``model``, ``sub_model``), ``lang``,
                  ``device`` (default ``'cpu'``) and ``models_dir``
                  (default ``<package root>/resources/models``).
                  A missing or ``None`` config dict is treated as empty so
                  every key falls back to its default.
          """
          # This file sits in <package root>/model/, so the package root
          # (magic_pdf) is two dirname() steps up from its absolute path.
          current_dir = os.path.dirname(os.path.abspath(__file__))
          root_dir = os.path.dirname(current_dir)
          model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
          config_path = os.path.join(model_config_dir, 'model_configs.yaml')
          # Bundled package resource, not user-supplied input.
          with open(config_path, 'r', encoding='utf-8') as f:
              self.configs = yaml.load(f, Loader=yaml.FullLoader)

          # Layout config. `or {}` lets a missing/None dict fall back to the
          # per-key defaults instead of raising AttributeError on `.get`.
          self.layout_config = kwargs.get('layout_config') or {}
          self.layout_model_name = self.layout_config.get(
              'model', MODEL_NAME.DocLayout_YOLO
          )

          # Formula config.
          self.formula_config = kwargs.get('formula_config') or {}
          self.mfd_model_name = self.formula_config.get(
              'mfd_model', MODEL_NAME.YOLO_V8_MFD
          )
          self.mfr_model_name = self.formula_config.get(
              'mfr_model', MODEL_NAME.UniMerNet_v2_Small
          )
          self.apply_formula = self.formula_config.get('enable', True)

          # Table config.
          self.table_config = kwargs.get('table_config') or {}
          self.apply_table = self.table_config.get('enable', False)
          self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
          self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
          self.table_sub_model_name = self.table_config.get('sub_model', None)

          # OCR config.
          self.apply_ocr = ocr
          self.lang = kwargs.get('lang', None)

          logger.info(
              'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
              'apply_table: {}, table_model: {}, lang: {}'.format(
                  self.layout_model_name,
                  self.apply_formula,
                  self.apply_ocr,
                  self.apply_table,
                  self.table_model_name,
                  self.lang,
              )
          )

          self.device = kwargs.get('device', 'cpu')
          logger.info(f'using device: {self.device}')
          models_dir = kwargs.get(
              'models_dir', os.path.join(root_dir, 'resources', 'models')
          )
          logger.info(f'using models_dir: {models_dir}')

          atom_model_manager = AtomModelSingleton()
          if self.apply_formula:
              self._init_formula_models(atom_model_manager, models_dir, model_config_dir)
          self._init_layout_model(atom_model_manager, models_dir, model_config_dir)
          self.ocr_model = atom_model_manager.get_atom_model(
              atom_model_name=AtomicModel.OCR,
              ocr_show_log=show_log,
              det_db_box_thresh=0.3,
              lang=self.lang,
          )
          if self.apply_table:
              # Must follow OCR init: the OCR model is passed as `ocr_engine`.
              self._init_table_model(atom_model_manager, models_dir)
          logger.info('DocAnalysis init done!')

      @staticmethod
      def _cpu_if_mps(device):
          """Return ``'cpu'`` when *device* names an MPS device, else *device* unchanged.

          Applied to the MFR and LayoutLMv3 models, which this pipeline never
          places on MPS (presumably unsupported there — TODO confirm).
          """
          return 'cpu' if str(device).startswith('mps') else device

      def _init_formula_models(self, atom_model_manager, models_dir, model_config_dir):
          """Create the formula detection (MFD) and recognition (MFR) models."""
          self.mfd_model = atom_model_manager.get_atom_model(
              atom_model_name=AtomicModel.MFD,
              mfd_weights=str(
                  os.path.join(models_dir, self.configs['weights'][self.mfd_model_name])
              ),
              device=self.device,
          )
          mfr_weight_dir = str(
              os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
          )
          mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
          self.mfr_model = atom_model_manager.get_atom_model(
              atom_model_name=AtomicModel.MFR,
              mfr_weight_dir=mfr_weight_dir,
              mfr_cfg_path=mfr_cfg_path,
              device=self._cpu_if_mps(self.device),
          )

      def _init_layout_model(self, atom_model_manager, models_dir, model_config_dir):
          """Create the layout model selected by ``self.layout_model_name``.

          Unsupported names leave ``self.layout_model`` unset; ``__call__``
          then yields no layout detections, so a warning is logged here.
          """
          if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
              self.layout_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.Layout,
                  layout_model_name=MODEL_NAME.LAYOUTLMv3,
                  layout_weights=str(
                      os.path.join(models_dir, self.configs['weights'][self.layout_model_name])
                  ),
                  layout_config_file=str(
                      os.path.join(
                          model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
                      )
                  ),
                  device=self._cpu_if_mps(self.device),
              )
          elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
              self.layout_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.Layout,
                  layout_model_name=MODEL_NAME.DocLayout_YOLO,
                  doclayout_yolo_weights=str(
                      os.path.join(models_dir, self.configs['weights'][self.layout_model_name])
                  ),
                  device=self.device,
              )
          else:
              logger.warning(
                  f'unsupported layout model: {self.layout_model_name}, '
                  'layout detection will be skipped'
              )

      def _init_table_model(self, atom_model_manager, models_dir):
          """Create the table recognition model; requires ``self.ocr_model``."""
          table_model_dir = self.configs['weights'][self.table_model_name]
          self.table_model = atom_model_manager.get_atom_model(
              atom_model_name=AtomicModel.Table,
              table_model_name=self.table_model_name,
              table_model_path=str(os.path.join(models_dir, table_model_dir)),
              table_max_time=self.table_max_time,
              device=self.device,
              ocr_engine=self.ocr_model,
              table_sub_model_name=self.table_sub_model_name,
          )

      def __call__(self, image):
          """Analyse one page image and return its list of region results.

          Args:
              image: the page image. Crops of it are converted with
                  ``cv2.COLOR_RGB2BGR`` before OCR, so it is presumably an RGB
                  ndarray — TODO confirm against callers.

          Returns:
              The layout model's result list, extended in place with formula
              results and OCR results; table regions gain an ``'html'`` key
              when recognition succeeds.
          """
          layout_res = self._detect_layout(image)
          if self.apply_formula:
              self._recognize_formulas(image, layout_res)

          # Split the layout results into OCR regions, table regions and
          # formula regions.
          ocr_res_list, table_res_list, single_page_mfdetrec_res = (
              get_res_list_from_layout_res(layout_res)
          )
          self._run_ocr(image, ocr_res_list, single_page_mfdetrec_res, layout_res)
          if self.apply_table:
              self._recognize_tables(image, table_res_list)
          return layout_res

      def _detect_layout(self, image):
          """Run the configured layout model; unsupported models yield ``[]``."""
          layout_start = time.time()
          if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
              layout_res = self.layout_model(image, ignore_catids=[])
          elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
              layout_res = self.layout_model.predict(image)
          else:
              layout_res = []
          layout_cost = round(time.time() - layout_start, 2)
          logger.info(f'layout detection time: {layout_cost}')
          return layout_res

      def _recognize_formulas(self, image, layout_res):
          """Detect and recognise formulas, appending the results to *layout_res*."""
          mfd_start = time.time()
          mfd_res = self.mfd_model.predict(image)
          logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')

          mfr_start = time.time()
          formula_list = self.mfr_model.predict(mfd_res, image)
          layout_res.extend(formula_list)
          mfr_cost = round(time.time() - mfr_start, 2)
          logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')

          # Release GPU memory once the formula models are done
          # (threshold semantics are defined by clean_vram).
          clean_vram(self.device, vram_threshold=6)

      def _run_ocr(self, image, ocr_res_list, single_page_mfdetrec_res, layout_res):
          """OCR each text region (detection only unless ``apply_ocr``) into *layout_res*."""
          ocr_start = time.time()
          for res in ocr_res_list:
              new_image, useful_list = crop_img(res, image, crop_paste_x=50, crop_paste_y=50)
              # Formula boxes re-expressed relative to this crop.
              adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
              new_image = cv2.cvtColor(new_image, cv2.COLOR_RGB2BGR)
              if self.apply_ocr:
                  ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
              else:
                  ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
              if ocr_res:
                  layout_res.extend(get_ocr_result_list(ocr_res, useful_list))
          ocr_cost = round(time.time() - ocr_start, 2)
          if self.apply_ocr:
              logger.info(f"ocr time: {ocr_cost}")
          else:
              logger.info(f"det time: {ocr_cost}")

      def _recognize_tables(self, image, table_res_list):
          """Recognise each table region, storing valid HTML in ``res['html']``."""
          table_start = time.time()
          for res in table_res_list:
              new_image, _ = crop_img(res, image)
              single_table_start_time = time.time()
              html_code = self._predict_table_html(new_image)
              run_time = time.time() - single_table_start_time
              # The limit is only reported after the fact, not enforced.
              if run_time > self.table_max_time:
                  logger.warning(
                      f'table recognition processing exceeds max time {self.table_max_time}s'
                  )
              if not html_code:
                  logger.warning(
                      'table recognition processing fails, not get html return'
                  )
              elif self._has_expected_html_ending(html_code):
                  res['html'] = html_code
              else:
                  logger.warning(
                      'table recognition processing fails, not found expected HTML table end'
                  )
          logger.info(f'table time: {round(time.time() - table_start, 2)}')

      def _predict_table_html(self, table_image):
          """Return the HTML produced by the configured table model, or ``None``."""
          if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
              with torch.no_grad():
                  table_result = self.table_model.predict(table_image, 'html')
              if table_result is not None and len(table_result) > 0:
                  return table_result[0]
              return None
          if self.table_model_name == MODEL_NAME.TABLE_MASTER:
              return self.table_model.img2html(table_image)
          if self.table_model_name == MODEL_NAME.RAPID_TABLE:
              # predict() returns a 4-tuple (html, cell bboxes, logic points,
              # elapsed time); only the HTML is used here.
              html_code, _, _, _ = self.table_model.predict(table_image)
              return html_code
          return None

      @staticmethod
      def _has_expected_html_ending(html_code):
          """Return True if *html_code*, stripped, ends with ``</html>`` or ``</table>``."""
          return html_code.strip().endswith(('</html>', '</table>'))