pdf_extract_kit.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274
  1. # flake8: noqa
  2. import os
  3. import time
  4. import cv2
  5. import numpy as np
  6. import torch
  7. import yaml
  8. from loguru import logger
  9. from PIL import Image
  10. os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
  11. os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
  12. try:
  13. import torchtext
  14. if torchtext.__version__ >= '0.18.0':
  15. torchtext.disable_torchtext_deprecation_warning()
  16. except ImportError:
  17. pass
  18. from magic_pdf.config.constants import *
  19. from magic_pdf.model.model_list import AtomicModel
  20. from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
  21. from magic_pdf.model.sub_modules.model_utils import (
  22. clean_vram, crop_img, get_res_list_from_layout_res)
  23. from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
  24. get_adjusted_mfdetrec_res, get_ocr_result_list)
  25. class CustomPEKModel:
  26. def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
  27. """
  28. ======== model init ========
  29. """
  30. # 获取当前文件(即 pdf_extract_kit.py)的绝对路径
  31. current_file_path = os.path.abspath(__file__)
  32. # 获取当前文件所在的目录(model)
  33. current_dir = os.path.dirname(current_file_path)
  34. # 上一级目录(magic_pdf)
  35. root_dir = os.path.dirname(current_dir)
  36. # model_config目录
  37. model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
  38. # 构建 model_configs.yaml 文件的完整路径
  39. config_path = os.path.join(model_config_dir, 'model_configs.yaml')
  40. with open(config_path, 'r', encoding='utf-8') as f:
  41. self.configs = yaml.load(f, Loader=yaml.FullLoader)
  42. # 初始化解析配置
  43. # layout config
  44. self.layout_config = kwargs.get('layout_config')
  45. self.layout_model_name = self.layout_config.get(
  46. 'model', MODEL_NAME.DocLayout_YOLO
  47. )
  48. # formula config
  49. self.formula_config = kwargs.get('formula_config')
  50. self.mfd_model_name = self.formula_config.get(
  51. 'mfd_model', MODEL_NAME.YOLO_V8_MFD
  52. )
  53. self.mfr_model_name = self.formula_config.get(
  54. 'mfr_model', MODEL_NAME.UniMerNet_v2_Small
  55. )
  56. self.apply_formula = self.formula_config.get('enable', True)
  57. # table config
  58. self.table_config = kwargs.get('table_config')
  59. self.apply_table = self.table_config.get('enable', False)
  60. self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
  61. self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
  62. # ocr config
  63. self.apply_ocr = ocr
  64. self.lang = kwargs.get('lang', None)
  65. logger.info(
  66. 'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
  67. 'apply_table: {}, table_model: {}, lang: {}'.format(
  68. self.layout_model_name,
  69. self.apply_formula,
  70. self.apply_ocr,
  71. self.apply_table,
  72. self.table_model_name,
  73. self.lang,
  74. )
  75. )
  76. # 初始化解析方案
  77. self.device = kwargs.get('device', 'cpu')
  78. logger.info('using device: {}'.format(self.device))
  79. models_dir = kwargs.get(
  80. 'models_dir', os.path.join(root_dir, 'resources', 'models')
  81. )
  82. logger.info('using models_dir: {}'.format(models_dir))
  83. atom_model_manager = AtomModelSingleton()
  84. # 初始化公式识别
  85. if self.apply_formula:
  86. # 初始化公式检测模型
  87. self.mfd_model = atom_model_manager.get_atom_model(
  88. atom_model_name=AtomicModel.MFD,
  89. mfd_weights=str(
  90. os.path.join(
  91. models_dir, self.configs['weights'][self.mfd_model_name]
  92. )
  93. ),
  94. device=self.device,
  95. )
  96. # 初始化公式解析模型
  97. mfr_weight_dir = str(
  98. os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
  99. )
  100. mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
  101. self.mfr_model = atom_model_manager.get_atom_model(
  102. atom_model_name=AtomicModel.MFR,
  103. mfr_weight_dir=mfr_weight_dir,
  104. mfr_cfg_path=mfr_cfg_path,
  105. device=self.device,
  106. )
  107. # 初始化layout模型
  108. if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
  109. self.layout_model = atom_model_manager.get_atom_model(
  110. atom_model_name=AtomicModel.Layout,
  111. layout_model_name=MODEL_NAME.LAYOUTLMv3,
  112. layout_weights=str(
  113. os.path.join(
  114. models_dir, self.configs['weights'][self.layout_model_name]
  115. )
  116. ),
  117. layout_config_file=str(
  118. os.path.join(
  119. model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
  120. )
  121. ),
  122. device=self.device,
  123. )
  124. elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
  125. self.layout_model = atom_model_manager.get_atom_model(
  126. atom_model_name=AtomicModel.Layout,
  127. layout_model_name=MODEL_NAME.DocLayout_YOLO,
  128. doclayout_yolo_weights=str(
  129. os.path.join(
  130. models_dir, self.configs['weights'][self.layout_model_name]
  131. )
  132. ),
  133. device=self.device,
  134. )
  135. # 初始化ocr
  136. self.ocr_model = atom_model_manager.get_atom_model(
  137. atom_model_name=AtomicModel.OCR,
  138. ocr_show_log=show_log,
  139. det_db_box_thresh=0.3,
  140. lang=self.lang
  141. )
  142. # init table model
  143. if self.apply_table:
  144. table_model_dir = self.configs['weights'][self.table_model_name]
  145. self.table_model = atom_model_manager.get_atom_model(
  146. atom_model_name=AtomicModel.Table,
  147. table_model_name=self.table_model_name,
  148. table_model_path=str(os.path.join(models_dir, table_model_dir)),
  149. table_max_time=self.table_max_time,
  150. device=self.device,
  151. )
  152. logger.info('DocAnalysis init done!')
  153. def __call__(self, image):
  154. # layout检测
  155. layout_start = time.time()
  156. layout_res = []
  157. if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
  158. # layoutlmv3
  159. layout_res = self.layout_model(image, ignore_catids=[])
  160. elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
  161. # doclayout_yolo
  162. layout_res = self.layout_model.predict(image)
  163. layout_cost = round(time.time() - layout_start, 2)
  164. logger.info(f'layout detection time: {layout_cost}')
  165. pil_img = Image.fromarray(image)
  166. if self.apply_formula:
  167. # 公式检测
  168. mfd_start = time.time()
  169. mfd_res = self.mfd_model.predict(image)
  170. logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')
  171. # 公式识别
  172. mfr_start = time.time()
  173. formula_list = self.mfr_model.predict(mfd_res, image)
  174. layout_res.extend(formula_list)
  175. mfr_cost = round(time.time() - mfr_start, 2)
  176. logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
  177. # 清理显存
  178. clean_vram(self.device, vram_threshold=8)
  179. # 从layout_res中获取ocr区域、表格区域、公式区域
  180. ocr_res_list, table_res_list, single_page_mfdetrec_res = (
  181. get_res_list_from_layout_res(layout_res)
  182. )
  183. # ocr识别
  184. ocr_start = time.time()
  185. # Process each area that requires OCR processing
  186. for res in ocr_res_list:
  187. new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
  188. adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
  189. # OCR recognition
  190. new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
  191. if self.apply_ocr:
  192. ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
  193. else:
  194. ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
  195. # Integration results
  196. if ocr_res:
  197. ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
  198. layout_res.extend(ocr_result_list)
  199. ocr_cost = round(time.time() - ocr_start, 2)
  200. if self.apply_ocr:
  201. logger.info(f"ocr time: {ocr_cost}")
  202. else:
  203. logger.info(f"det time: {ocr_cost}")
  204. # 表格识别 table recognition
  205. if self.apply_table:
  206. table_start = time.time()
  207. for res in table_res_list:
  208. new_image, _ = crop_img(res, pil_img)
  209. single_table_start_time = time.time()
  210. html_code = None
  211. if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
  212. with torch.no_grad():
  213. table_result = self.table_model.predict(new_image, 'html')
  214. if len(table_result) > 0:
  215. html_code = table_result[0]
  216. elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
  217. html_code = self.table_model.img2html(new_image)
  218. elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
  219. html_code, table_cell_bboxes, elapse = self.table_model.predict(
  220. new_image
  221. )
  222. run_time = time.time() - single_table_start_time
  223. if run_time > self.table_max_time:
  224. logger.warning(
  225. f'table recognition processing exceeds max time {self.table_max_time}s'
  226. )
  227. # 判断是否返回正常
  228. if html_code:
  229. expected_ending = html_code.strip().endswith(
  230. '</html>'
  231. ) or html_code.strip().endswith('</table>')
  232. if expected_ending:
  233. res['html'] = html_code
  234. else:
  235. logger.warning(
  236. 'table recognition processing fails, not found expected HTML table end'
  237. )
  238. else:
  239. logger.warning(
  240. 'table recognition processing fails, not get html return'
  241. )
  242. logger.info(f'table time: {round(time.time() - table_start, 2)}')
  243. return layout_res