pdf_extract_kit.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304
  1. # flake8: noqa
  2. import os
  3. import time
  4. import cv2
  5. import numpy as np
  6. import torch
  7. import yaml
  8. from loguru import logger
  9. from PIL import Image
  10. os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
  11. try:
  12. import torchtext
  13. if torchtext.__version__ >= '0.18.0':
  14. torchtext.disable_torchtext_deprecation_warning()
  15. except ImportError:
  16. pass
  17. from magic_pdf.config.constants import *
  18. from magic_pdf.model.model_list import AtomicModel
  19. from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
  20. from magic_pdf.model.sub_modules.model_utils import (
  21. clean_vram, crop_img, get_res_list_from_layout_res)
  22. from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
  23. get_adjusted_mfdetrec_res, get_ocr_result_list)
  24. class CustomPEKModel:
  25. def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
  26. """
  27. ======== model init ========
  28. """
  29. # 获取当前文件(即 pdf_extract_kit.py)的绝对路径
  30. current_file_path = os.path.abspath(__file__)
  31. # 获取当前文件所在的目录(model)
  32. current_dir = os.path.dirname(current_file_path)
  33. # 上一级目录(magic_pdf)
  34. root_dir = os.path.dirname(current_dir)
  35. # model_config目录
  36. model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
  37. # 构建 model_configs.yaml 文件的完整路径
  38. config_path = os.path.join(model_config_dir, 'model_configs.yaml')
  39. with open(config_path, 'r', encoding='utf-8') as f:
  40. self.configs = yaml.load(f, Loader=yaml.FullLoader)
  41. # 初始化解析配置
  42. # layout config
  43. self.layout_config = kwargs.get('layout_config')
  44. self.layout_model_name = self.layout_config.get(
  45. 'model', MODEL_NAME.DocLayout_YOLO
  46. )
  47. # formula config
  48. self.formula_config = kwargs.get('formula_config')
  49. self.mfd_model_name = self.formula_config.get(
  50. 'mfd_model', MODEL_NAME.YOLO_V8_MFD
  51. )
  52. self.mfr_model_name = self.formula_config.get(
  53. 'mfr_model', MODEL_NAME.UniMerNet_v2_Small
  54. )
  55. self.apply_formula = self.formula_config.get('enable', True)
  56. # table config
  57. self.table_config = kwargs.get('table_config')
  58. self.apply_table = self.table_config.get('enable', False)
  59. self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
  60. self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
  61. # ocr config
  62. self.apply_ocr = ocr
  63. self.lang = kwargs.get('lang', None)
  64. logger.info(
  65. 'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
  66. 'apply_table: {}, table_model: {}, lang: {}'.format(
  67. self.layout_model_name,
  68. self.apply_formula,
  69. self.apply_ocr,
  70. self.apply_table,
  71. self.table_model_name,
  72. self.lang,
  73. )
  74. )
  75. # 初始化解析方案
  76. self.device = kwargs.get('device', 'cpu')
  77. if str(self.device).startswith("npu"):
  78. import torch_npu
  79. os.environ['FLAGS_npu_jit_compile'] = '0'
  80. os.environ['FLAGS_use_stride_kernel'] = '0'
  81. elif str(self.device).startswith("mps"):
  82. os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
  83. logger.info('using device: {}'.format(self.device))
  84. models_dir = kwargs.get(
  85. 'models_dir', os.path.join(root_dir, 'resources', 'models')
  86. )
  87. logger.info('using models_dir: {}'.format(models_dir))
  88. atom_model_manager = AtomModelSingleton()
  89. # 初始化公式识别
  90. if self.apply_formula:
  91. # 初始化公式检测模型
  92. self.mfd_model = atom_model_manager.get_atom_model(
  93. atom_model_name=AtomicModel.MFD,
  94. mfd_weights=str(
  95. os.path.join(
  96. models_dir, self.configs['weights'][self.mfd_model_name]
  97. )
  98. ),
  99. device=self.device,
  100. )
  101. # 初始化公式解析模型
  102. mfr_weight_dir = str(
  103. os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
  104. )
  105. mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
  106. self.mfr_model = atom_model_manager.get_atom_model(
  107. atom_model_name=AtomicModel.MFR,
  108. mfr_weight_dir=mfr_weight_dir,
  109. mfr_cfg_path=mfr_cfg_path,
  110. device='cpu' if str(self.device).startswith("mps") else self.device,
  111. )
  112. # 初始化layout模型
  113. if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
  114. self.layout_model = atom_model_manager.get_atom_model(
  115. atom_model_name=AtomicModel.Layout,
  116. layout_model_name=MODEL_NAME.LAYOUTLMv3,
  117. layout_weights=str(
  118. os.path.join(
  119. models_dir, self.configs['weights'][self.layout_model_name]
  120. )
  121. ),
  122. layout_config_file=str(
  123. os.path.join(
  124. model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
  125. )
  126. ),
  127. device='cpu' if str(self.device).startswith("mps") else self.device,
  128. )
  129. elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
  130. self.layout_model = atom_model_manager.get_atom_model(
  131. atom_model_name=AtomicModel.Layout,
  132. layout_model_name=MODEL_NAME.DocLayout_YOLO,
  133. doclayout_yolo_weights=str(
  134. os.path.join(
  135. models_dir, self.configs['weights'][self.layout_model_name]
  136. )
  137. ),
  138. device=self.device,
  139. )
  140. # 初始化ocr
  141. self.ocr_model = atom_model_manager.get_atom_model(
  142. atom_model_name=AtomicModel.OCR,
  143. ocr_show_log=show_log,
  144. det_db_box_thresh=0.3,
  145. lang=self.lang
  146. )
  147. # init table model
  148. if self.apply_table:
  149. table_model_dir = self.configs['weights'][self.table_model_name]
  150. self.table_model = atom_model_manager.get_atom_model(
  151. atom_model_name=AtomicModel.Table,
  152. table_model_name=self.table_model_name,
  153. table_model_path=str(os.path.join(models_dir, table_model_dir)),
  154. table_max_time=self.table_max_time,
  155. device=self.device,
  156. ocr_engine=self.ocr_model,
  157. )
  158. logger.info('DocAnalysis init done!')
  159. def __call__(self, image):
  160. pil_img = Image.fromarray(image)
  161. width, height = pil_img.size
  162. # logger.info(f'width: {width}, height: {height}')
  163. # layout检测
  164. layout_start = time.time()
  165. layout_res = []
  166. if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
  167. # layoutlmv3
  168. layout_res = self.layout_model(image, ignore_catids=[])
  169. elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
  170. # doclayout_yolo
  171. # if height > width:
  172. # input_res = {"poly":[0,0,width,0,width,height,0,height]}
  173. # new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
  174. # paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
  175. # layout_res = self.layout_model.predict(new_image)
  176. # for res in layout_res:
  177. # p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
  178. # p1 = p1 - paste_x + xmin
  179. # p2 = p2 - paste_y + ymin
  180. # p3 = p3 - paste_x + xmin
  181. # p4 = p4 - paste_y + ymin
  182. # p5 = p5 - paste_x + xmin
  183. # p6 = p6 - paste_y + ymin
  184. # p7 = p7 - paste_x + xmin
  185. # p8 = p8 - paste_y + ymin
  186. # res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
  187. # else:
  188. layout_res = self.layout_model.predict(image)
  189. layout_cost = round(time.time() - layout_start, 2)
  190. logger.info(f'layout detection time: {layout_cost}')
  191. if self.apply_formula:
  192. # 公式检测
  193. mfd_start = time.time()
  194. mfd_res = self.mfd_model.predict(image)
  195. logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')
  196. # 公式识别
  197. mfr_start = time.time()
  198. formula_list = self.mfr_model.predict(mfd_res, image)
  199. layout_res.extend(formula_list)
  200. mfr_cost = round(time.time() - mfr_start, 2)
  201. logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
  202. # 清理显存
  203. clean_vram(self.device, vram_threshold=8)
  204. # 从layout_res中获取ocr区域、表格区域、公式区域
  205. ocr_res_list, table_res_list, single_page_mfdetrec_res = (
  206. get_res_list_from_layout_res(layout_res)
  207. )
  208. # ocr识别
  209. ocr_start = time.time()
  210. # Process each area that requires OCR processing
  211. for res in ocr_res_list:
  212. new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
  213. adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
  214. # OCR recognition
  215. new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
  216. if self.apply_ocr:
  217. ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
  218. else:
  219. ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
  220. # Integration results
  221. if ocr_res:
  222. ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
  223. layout_res.extend(ocr_result_list)
  224. ocr_cost = round(time.time() - ocr_start, 2)
  225. if self.apply_ocr:
  226. logger.info(f"ocr time: {ocr_cost}")
  227. else:
  228. logger.info(f"det time: {ocr_cost}")
  229. # 表格识别 table recognition
  230. if self.apply_table:
  231. table_start = time.time()
  232. for res in table_res_list:
  233. new_image, _ = crop_img(res, pil_img)
  234. single_table_start_time = time.time()
  235. html_code = None
  236. if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
  237. with torch.no_grad():
  238. table_result = self.table_model.predict(new_image, 'html')
  239. if len(table_result) > 0:
  240. html_code = table_result[0]
  241. elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
  242. html_code = self.table_model.img2html(new_image)
  243. elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
  244. html_code, table_cell_bboxes, elapse = self.table_model.predict(
  245. new_image
  246. )
  247. run_time = time.time() - single_table_start_time
  248. if run_time > self.table_max_time:
  249. logger.warning(
  250. f'table recognition processing exceeds max time {self.table_max_time}s'
  251. )
  252. # 判断是否返回正常
  253. if html_code:
  254. expected_ending = html_code.strip().endswith(
  255. '</html>'
  256. ) or html_code.strip().endswith('</table>')
  257. if expected_ending:
  258. res['html'] = html_code
  259. else:
  260. logger.warning(
  261. 'table recognition processing fails, not found expected HTML table end'
  262. )
  263. else:
  264. logger.warning(
  265. 'table recognition processing fails, not get html return'
  266. )
  267. logger.info(f'table time: {round(time.time() - table_start, 2)}')
  268. return layout_res