pdf_extract_kit.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278
  1. # flake8: noqa
  2. import os
  3. import time
  4. import cv2
  5. import numpy as np
  6. import torch
  7. import yaml
  8. from loguru import logger
  9. from PIL import Image
  os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # stop albumentations from checking for updates
  os.environ['YOLO_VERBOSE'] = 'False'  # disable yolo logger


  def silence_torchtext_deprecation(module):
      """Turn off torchtext's deprecation warning if this release supports it.

      ``disable_torchtext_deprecation_warning`` only exists in newer torchtext
      releases (0.18.0+). Probing for the attribute is used instead of comparing
      ``__version__`` strings, because string comparison is lexicographic
      (``'0.9.0' >= '0.18.0'`` is True). That would make old releases call a
      missing function and raise AttributeError at import time.

      Args:
          module: the imported ``torchtext`` module (any object is accepted;
              nothing happens if it lacks the switch).
      """
      disable = getattr(module, 'disable_torchtext_deprecation_warning', None)
      if callable(disable):
          disable()


  try:
      import torchtext
  except ImportError:
      # torchtext is optional; nothing to silence when it is not installed.
      pass
  else:
      silence_torchtext_deprecation(torchtext)
  18. from magic_pdf.config.constants import *
  19. from magic_pdf.model.model_list import AtomicModel
  20. from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
  21. from magic_pdf.model.sub_modules.model_utils import (
  22. clean_vram, crop_img, get_res_list_from_layout_res)
  23. from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
  24. get_adjusted_mfdetrec_res, get_ocr_result_list)
  25. from threading import Lock
  class CustomPEKModel:
      """Single-page document analysis: layout, formula, OCR and table models.

      All sub-models are obtained through ``AtomModelSingleton``. Calling an
      instance on a page image returns the list of detected regions.
      """

      # The OCR model is fetched from AtomModelSingleton and may therefore be
      # shared by several instances/threads, so the lock that serialises calls
      # into it must itself be shared. A lock created inside __call__ would be
      # new on every call and could never block another caller.
      _ocr_lock = Lock()

      def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
          """Read the model configs and initialise every enabled sub-model.

          Args:
              ocr: if True, text regions get full OCR (detection + recognition);
                  otherwise only detection is run (``rec=False``).
              show_log: forwarded to the OCR model as ``ocr_show_log``.
              **kwargs: recognised keys:
                  ``layout_config`` / ``formula_config`` / ``table_config``:
                  dicts (a missing or ``None`` value is treated as ``{}``);
                  ``lang``: OCR language (default ``None``);
                  ``device``: default ``'cpu'``;
                  ``models_dir``: default ``<magic_pdf>/resources/models``.
          """
          # This file lives in magic_pdf/model/; resources are under magic_pdf/.
          current_dir = os.path.dirname(os.path.abspath(__file__))
          root_dir = os.path.dirname(current_dir)
          model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
          config_path = os.path.join(model_config_dir, 'model_configs.yaml')
          with open(config_path, 'r', encoding='utf-8') as f:
              self.configs = yaml.load(f, Loader=yaml.FullLoader)

          # Layout config.
          self.layout_config = kwargs.get('layout_config') or {}
          self.layout_model_name = self.layout_config.get(
              'model', MODEL_NAME.DocLayout_YOLO
          )

          # Formula config (detection = MFD, recognition = MFR).
          self.formula_config = kwargs.get('formula_config') or {}
          self.mfd_model_name = self.formula_config.get(
              'mfd_model', MODEL_NAME.YOLO_V8_MFD
          )
          self.mfr_model_name = self.formula_config.get(
              'mfr_model', MODEL_NAME.UniMerNet_v2_Small
          )
          self.apply_formula = self.formula_config.get('enable', True)

          # Table config.
          self.table_config = kwargs.get('table_config') or {}
          self.apply_table = self.table_config.get('enable', False)
          self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
          self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)

          # OCR config.
          self.apply_ocr = ocr
          self.lang = kwargs.get('lang', None)

          logger.info(
              'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
              'apply_table: {}, table_model: {}, lang: {}'.format(
                  self.layout_model_name,
                  self.apply_formula,
                  self.apply_ocr,
                  self.apply_table,
                  self.table_model_name,
                  self.lang,
              )
          )

          self.device = kwargs.get('device', 'cpu')
          logger.info('using device: {}'.format(self.device))
          models_dir = kwargs.get(
              'models_dir', os.path.join(root_dir, 'resources', 'models')
          )
          logger.info('using models_dir: {}'.format(models_dir))

          def weights_path(model_name):
              # model_configs.yaml maps model names to paths relative to models_dir.
              return str(os.path.join(models_dir, self.configs['weights'][model_name]))

          atom_model_manager = AtomModelSingleton()

          # Formula detection + recognition models.
          if self.apply_formula:
              self.mfd_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.MFD,
                  mfd_weights=weights_path(self.mfd_model_name),
                  device=self.device,
              )
              self.mfr_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.MFR,
                  mfr_weight_dir=weights_path(self.mfr_model_name),
                  mfr_cfg_path=str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml')),
                  device=self.device,
              )

          # Layout model.
          if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
              self.layout_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.Layout,
                  layout_model_name=MODEL_NAME.LAYOUTLMv3,
                  layout_weights=weights_path(self.layout_model_name),
                  layout_config_file=str(
                      os.path.join(
                          model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
                      )
                  ),
                  device=self.device,
              )
          elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
              self.layout_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.Layout,
                  layout_model_name=MODEL_NAME.DocLayout_YOLO,
                  doclayout_yolo_weights=weights_path(self.layout_model_name),
                  device=self.device,
              )
          else:
              # __call__ only runs layout detection for the two models above.
              logger.warning(
                  f'unsupported layout model: {self.layout_model_name}, '
                  f'layout detection will be skipped'
              )

          # OCR model (always created; also used for detection-only mode).
          self.ocr_model = atom_model_manager.get_atom_model(
              atom_model_name=AtomicModel.OCR,
              ocr_show_log=show_log,
              det_db_box_thresh=0.3,
              lang=self.lang,
          )

          # Table model.
          if self.apply_table:
              self.table_model = atom_model_manager.get_atom_model(
                  atom_model_name=AtomicModel.Table,
                  table_model_name=self.table_model_name,
                  table_model_path=weights_path(self.table_model_name),
                  table_max_time=self.table_max_time,
                  device=self.device,
              )

          logger.info('DocAnalysis init done!')

      def __call__(self, image):
          """Run the full analysis pipeline on one page image.

          Args:
              image: page image; it is passed to ``Image.fromarray`` and OCR
                  crops are converted with ``COLOR_RGB2BGR``, so presumably an
                  RGB numpy array (TODO confirm against callers).

          Returns:
              list: layout detections, extended with formula results (when
              enabled) and OCR results. Table regions returned by
              ``get_res_list_from_layout_res`` get an ``'html'`` key on
              successful recognition (presumably the same dicts held in the
              returned list; confirm in model_utils).
          """
          # Layout detection.
          layout_start = time.time()
          layout_res = []
          if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
              layout_res = self.layout_model(image, ignore_catids=[])
          elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
              layout_res = self.layout_model.predict(image)
          layout_cost = round(time.time() - layout_start, 2)
          logger.info(f'layout detection time: {layout_cost}')

          pil_img = Image.fromarray(image)

          if self.apply_formula:
              # Formula detection.
              mfd_start = time.time()
              mfd_res = self.mfd_model.predict(image)
              logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')
              # Formula recognition.
              mfr_start = time.time()
              formula_list = self.mfr_model.predict(mfd_res, image)
              layout_res.extend(formula_list)
              mfr_cost = round(time.time() - mfr_start, 2)
              logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
              # Release GPU memory between stages.
              clean_vram(self.device, vram_threshold=8)

          # Split detections into OCR regions, table regions and formula regions.
          ocr_res_list, table_res_list, single_page_mfdetrec_res = (
              get_res_list_from_layout_res(layout_res)
          )

          # OCR (or detection only) on each text region.
          ocr_start = time.time()
          for res in ocr_res_list:
              new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
              adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
              new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
              with self._ocr_lock:
                  if self.apply_ocr:
                      ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
                  else:
                      ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
              if ocr_res:
                  layout_res.extend(get_ocr_result_list(ocr_res, useful_list))
          ocr_cost = round(time.time() - ocr_start, 2)
          if self.apply_ocr:
              logger.info(f"ocr time: {ocr_cost}")
          else:
              logger.info(f"det time: {ocr_cost}")

          # Table recognition.
          if self.apply_table:
              table_start = time.time()
              for res in table_res_list:
                  new_image, _ = crop_img(res, pil_img)
                  single_table_start_time = time.time()
                  html_code = self._predict_table_html(new_image)
                  run_time = time.time() - single_table_start_time
                  if run_time > self.table_max_time:
                      logger.warning(
                          f'table recognition processing exceeds max time {self.table_max_time}s'
                      )
                  if not html_code:
                      logger.warning(
                          'table recognition processing fails, not get html return'
                      )
                  elif self._is_complete_html(html_code):
                      res['html'] = html_code
                  else:
                      logger.warning(
                          'table recognition processing fails, not found expected HTML table end'
                      )
              logger.info(f'table time: {round(time.time() - table_start, 2)}')
          return layout_res

      def _predict_table_html(self, table_image):
          """Run the configured table model on one cropped table image.

          Returns:
              The model's HTML output, or None when the model returned nothing
              or ``table_model_name`` is not one of the handled models.
          """
          if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
              with torch.no_grad():
                  table_result = self.table_model.predict(table_image, 'html')
              return table_result[0] if len(table_result) > 0 else None
          if self.table_model_name == MODEL_NAME.TABLE_MASTER:
              return self.table_model.img2html(table_image)
          if self.table_model_name == MODEL_NAME.RAPID_TABLE:
              # RapidTable returns (html, cell_bboxes, elapsed); only html is used.
              html_code, _, _ = self.table_model.predict(table_image)
              return html_code
          return None

      @staticmethod
      def _is_complete_html(html_code):
          """Return True if *html_code* (whitespace-stripped) ends with ``</html>`` or ``</table>``."""
          return html_code.strip().endswith(('</html>', '</table>'))