pdf_extract_kit.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299
  1. # flake8: noqa
  2. import os
  3. import time
  4. import cv2
  5. import numpy as np
  6. import torch
  7. import yaml
  8. from loguru import logger
  9. from PIL import Image
  10. os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
  11. try:
  12. import torchtext
  13. if torchtext.__version__ >= '0.18.0':
  14. torchtext.disable_torchtext_deprecation_warning()
  15. except ImportError:
  16. pass
  17. from magic_pdf.config.constants import *
  18. from magic_pdf.model.model_list import AtomicModel
  19. from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
  20. from magic_pdf.model.sub_modules.model_utils import (
  21. clean_vram, crop_img, get_res_list_from_layout_res)
  22. from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
  23. get_adjusted_mfdetrec_res, get_ocr_result_list)
  24. class CustomPEKModel:
  25. def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
  26. """
  27. ======== model init ========
  28. """
  29. # 获取当前文件(即 pdf_extract_kit.py)的绝对路径
  30. current_file_path = os.path.abspath(__file__)
  31. # 获取当前文件所在的目录(model)
  32. current_dir = os.path.dirname(current_file_path)
  33. # 上一级目录(magic_pdf)
  34. root_dir = os.path.dirname(current_dir)
  35. # model_config目录
  36. model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
  37. # 构建 model_configs.yaml 文件的完整路径
  38. config_path = os.path.join(model_config_dir, 'model_configs.yaml')
  39. with open(config_path, 'r', encoding='utf-8') as f:
  40. self.configs = yaml.load(f, Loader=yaml.FullLoader)
  41. # 初始化解析配置
  42. # layout config
  43. self.layout_config = kwargs.get('layout_config')
  44. self.layout_model_name = self.layout_config.get(
  45. 'model', MODEL_NAME.DocLayout_YOLO
  46. )
  47. # formula config
  48. self.formula_config = kwargs.get('formula_config')
  49. self.mfd_model_name = self.formula_config.get(
  50. 'mfd_model', MODEL_NAME.YOLO_V8_MFD
  51. )
  52. self.mfr_model_name = self.formula_config.get(
  53. 'mfr_model', MODEL_NAME.UniMerNet_v2_Small
  54. )
  55. self.apply_formula = self.formula_config.get('enable', True)
  56. # table config
  57. self.table_config = kwargs.get('table_config')
  58. self.apply_table = self.table_config.get('enable', False)
  59. self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
  60. self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
  61. self.table_sub_model_name = self.table_config.get('sub_model', None)
  62. # ocr config
  63. self.apply_ocr = ocr
  64. self.lang = kwargs.get('lang', None)
  65. logger.info(
  66. 'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
  67. 'apply_table: {}, table_model: {}, lang: {}'.format(
  68. self.layout_model_name,
  69. self.apply_formula,
  70. self.apply_ocr,
  71. self.apply_table,
  72. self.table_model_name,
  73. self.lang,
  74. )
  75. )
  76. # 初始化解析方案
  77. self.device = kwargs.get('device', 'cpu')
  78. logger.info('using device: {}'.format(self.device))
  79. models_dir = kwargs.get(
  80. 'models_dir', os.path.join(root_dir, 'resources', 'models')
  81. )
  82. logger.info('using models_dir: {}'.format(models_dir))
  83. atom_model_manager = AtomModelSingleton()
  84. # 初始化公式识别
  85. if self.apply_formula:
  86. # 初始化公式检测模型
  87. self.mfd_model = atom_model_manager.get_atom_model(
  88. atom_model_name=AtomicModel.MFD,
  89. mfd_weights=str(
  90. os.path.join(
  91. models_dir, self.configs['weights'][self.mfd_model_name]
  92. )
  93. ),
  94. device=self.device,
  95. )
  96. # 初始化公式解析模型
  97. mfr_weight_dir = str(
  98. os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
  99. )
  100. mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
  101. self.mfr_model = atom_model_manager.get_atom_model(
  102. atom_model_name=AtomicModel.MFR,
  103. mfr_weight_dir=mfr_weight_dir,
  104. mfr_cfg_path=mfr_cfg_path,
  105. device='cpu' if str(self.device).startswith("mps") else self.device,
  106. )
  107. # 初始化layout模型
  108. if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
  109. self.layout_model = atom_model_manager.get_atom_model(
  110. atom_model_name=AtomicModel.Layout,
  111. layout_model_name=MODEL_NAME.LAYOUTLMv3,
  112. layout_weights=str(
  113. os.path.join(
  114. models_dir, self.configs['weights'][self.layout_model_name]
  115. )
  116. ),
  117. layout_config_file=str(
  118. os.path.join(
  119. model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
  120. )
  121. ),
  122. device='cpu' if str(self.device).startswith("mps") else self.device,
  123. )
  124. elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
  125. self.layout_model = atom_model_manager.get_atom_model(
  126. atom_model_name=AtomicModel.Layout,
  127. layout_model_name=MODEL_NAME.DocLayout_YOLO,
  128. doclayout_yolo_weights=str(
  129. os.path.join(
  130. models_dir, self.configs['weights'][self.layout_model_name]
  131. )
  132. ),
  133. device=self.device,
  134. )
  135. # 初始化ocr
  136. self.ocr_model = atom_model_manager.get_atom_model(
  137. atom_model_name=AtomicModel.OCR,
  138. ocr_show_log=show_log,
  139. det_db_box_thresh=0.3,
  140. lang=self.lang
  141. )
  142. # init table model
  143. if self.apply_table:
  144. table_model_dir = self.configs['weights'][self.table_model_name]
  145. self.table_model = atom_model_manager.get_atom_model(
  146. atom_model_name=AtomicModel.Table,
  147. table_model_name=self.table_model_name,
  148. table_model_path=str(os.path.join(models_dir, table_model_dir)),
  149. table_max_time=self.table_max_time,
  150. device=self.device,
  151. ocr_engine=self.ocr_model,
  152. table_sub_model_name=self.table_sub_model_name
  153. )
  154. logger.info('DocAnalysis init done!')
  155. def __call__(self, image):
  156. pil_img = Image.fromarray(image)
  157. width, height = pil_img.size
  158. # logger.info(f'width: {width}, height: {height}')
  159. # layout检测
  160. layout_start = time.time()
  161. layout_res = []
  162. if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
  163. # layoutlmv3
  164. layout_res = self.layout_model(image, ignore_catids=[])
  165. elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
  166. # doclayout_yolo
  167. # if height > width:
  168. # input_res = {"poly":[0,0,width,0,width,height,0,height]}
  169. # new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
  170. # paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
  171. # layout_res = self.layout_model.predict(new_image)
  172. # for res in layout_res:
  173. # p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
  174. # p1 = p1 - paste_x + xmin
  175. # p2 = p2 - paste_y + ymin
  176. # p3 = p3 - paste_x + xmin
  177. # p4 = p4 - paste_y + ymin
  178. # p5 = p5 - paste_x + xmin
  179. # p6 = p6 - paste_y + ymin
  180. # p7 = p7 - paste_x + xmin
  181. # p8 = p8 - paste_y + ymin
  182. # res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
  183. # else:
  184. layout_res = self.layout_model.predict(image)
  185. layout_cost = round(time.time() - layout_start, 2)
  186. logger.info(f'layout detection time: {layout_cost}')
  187. if self.apply_formula:
  188. # 公式检测
  189. mfd_start = time.time()
  190. mfd_res = self.mfd_model.predict(image)
  191. logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')
  192. # 公式识别
  193. mfr_start = time.time()
  194. formula_list = self.mfr_model.predict(mfd_res, image)
  195. layout_res.extend(formula_list)
  196. mfr_cost = round(time.time() - mfr_start, 2)
  197. logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
  198. # 清理显存
  199. clean_vram(self.device, vram_threshold=6)
  200. # 从layout_res中获取ocr区域、表格区域、公式区域
  201. ocr_res_list, table_res_list, single_page_mfdetrec_res = (
  202. get_res_list_from_layout_res(layout_res)
  203. )
  204. # ocr识别
  205. ocr_start = time.time()
  206. # Process each area that requires OCR processing
  207. for res in ocr_res_list:
  208. new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
  209. adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
  210. # OCR recognition
  211. new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
  212. if self.apply_ocr:
  213. ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
  214. else:
  215. ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
  216. # Integration results
  217. if ocr_res:
  218. ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
  219. layout_res.extend(ocr_result_list)
  220. ocr_cost = round(time.time() - ocr_start, 2)
  221. if self.apply_ocr:
  222. logger.info(f"ocr time: {ocr_cost}")
  223. else:
  224. logger.info(f"det time: {ocr_cost}")
  225. # 表格识别 table recognition
  226. if self.apply_table:
  227. table_start = time.time()
  228. for res in table_res_list:
  229. new_image, _ = crop_img(res, pil_img)
  230. single_table_start_time = time.time()
  231. html_code = None
  232. if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
  233. with torch.no_grad():
  234. table_result = self.table_model.predict(new_image, 'html')
  235. if len(table_result) > 0:
  236. html_code = table_result[0]
  237. elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
  238. html_code = self.table_model.img2html(new_image)
  239. elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
  240. html_code, table_cell_bboxes, logic_points, elapse = self.table_model.predict(
  241. new_image
  242. )
  243. run_time = time.time() - single_table_start_time
  244. if run_time > self.table_max_time:
  245. logger.warning(
  246. f'table recognition processing exceeds max time {self.table_max_time}s'
  247. )
  248. # 判断是否返回正常
  249. if html_code:
  250. expected_ending = html_code.strip().endswith(
  251. '</html>'
  252. ) or html_code.strip().endswith('</table>')
  253. if expected_ending:
  254. res['html'] = html_code
  255. else:
  256. logger.warning(
  257. 'table recognition processing fails, not found expected HTML table end'
  258. )
  259. else:
  260. logger.warning(
  261. 'table recognition processing fails, not get html return'
  262. )
  263. logger.info(f'table time: {round(time.time() - table_start, 2)}')
  264. return layout_res