# pdf_extract_kit.py (12 KB)
  1. # flake8: noqa
  2. import os
  3. import time
  4. import cv2
  5. import numpy as np
  6. import torch
  7. import yaml
  8. from loguru import logger
  9. from PIL import Image
  10. os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
  11. os.environ['YOLO_VERBOSE'] = 'False' # disable yolo logger
  12. try:
  13. import torchtext
  14. if torchtext.__version__ >= '0.18.0':
  15. torchtext.disable_torchtext_deprecation_warning()
  16. except ImportError:
  17. pass
  18. from magic_pdf.config.constants import *
  19. from magic_pdf.model.model_list import AtomicModel
  20. from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
  21. from magic_pdf.model.sub_modules.model_utils import (
  22. clean_vram, crop_img, get_res_list_from_layout_res)
  23. from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
  24. get_adjusted_mfdetrec_res, get_ocr_result_list)
  25. class CustomPEKModel:
  26. def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
  27. """
  28. ======== model init ========
  29. """
  30. # 获取当前文件(即 pdf_extract_kit.py)的绝对路径
  31. current_file_path = os.path.abspath(__file__)
  32. # 获取当前文件所在的目录(model)
  33. current_dir = os.path.dirname(current_file_path)
  34. # 上一级目录(magic_pdf)
  35. root_dir = os.path.dirname(current_dir)
  36. # model_config目录
  37. model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
  38. # 构建 model_configs.yaml 文件的完整路径
  39. config_path = os.path.join(model_config_dir, 'model_configs.yaml')
  40. with open(config_path, 'r', encoding='utf-8') as f:
  41. self.configs = yaml.load(f, Loader=yaml.FullLoader)
  42. # 初始化解析配置
  43. # layout config
  44. self.layout_config = kwargs.get('layout_config')
  45. self.layout_model_name = self.layout_config.get(
  46. 'model', MODEL_NAME.DocLayout_YOLO
  47. )
  48. # formula config
  49. self.formula_config = kwargs.get('formula_config')
  50. self.mfd_model_name = self.formula_config.get(
  51. 'mfd_model', MODEL_NAME.YOLO_V8_MFD
  52. )
  53. self.mfr_model_name = self.formula_config.get(
  54. 'mfr_model', MODEL_NAME.UniMerNet_v2_Small
  55. )
  56. self.apply_formula = self.formula_config.get('enable', True)
  57. # table config
  58. self.table_config = kwargs.get('table_config')
  59. self.apply_table = self.table_config.get('enable', False)
  60. self.table_max_time = self.table_config.get('max_time', TABLE_MAX_TIME_VALUE)
  61. self.table_model_name = self.table_config.get('model', MODEL_NAME.RAPID_TABLE)
  62. # ocr config
  63. self.apply_ocr = ocr
  64. self.lang = kwargs.get('lang', None)
  65. logger.info(
  66. 'DocAnalysis init, this may take some times, layout_model: {}, apply_formula: {}, apply_ocr: {}, '
  67. 'apply_table: {}, table_model: {}, lang: {}'.format(
  68. self.layout_model_name,
  69. self.apply_formula,
  70. self.apply_ocr,
  71. self.apply_table,
  72. self.table_model_name,
  73. self.lang,
  74. )
  75. )
  76. # 初始化解析方案
  77. self.device = kwargs.get('device', 'cpu')
  78. logger.info('using device: {}'.format(self.device))
  79. models_dir = kwargs.get(
  80. 'models_dir', os.path.join(root_dir, 'resources', 'models')
  81. )
  82. logger.info('using models_dir: {}'.format(models_dir))
  83. atom_model_manager = AtomModelSingleton()
  84. # 初始化公式识别
  85. if self.apply_formula:
  86. # 初始化公式检测模型
  87. self.mfd_model = atom_model_manager.get_atom_model(
  88. atom_model_name=AtomicModel.MFD,
  89. mfd_weights=str(
  90. os.path.join(
  91. models_dir, self.configs['weights'][self.mfd_model_name]
  92. )
  93. ),
  94. device=self.device,
  95. )
  96. # 初始化公式解析模型
  97. mfr_weight_dir = str(
  98. os.path.join(models_dir, self.configs['weights'][self.mfr_model_name])
  99. )
  100. mfr_cfg_path = str(os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml'))
  101. self.mfr_model = atom_model_manager.get_atom_model(
  102. atom_model_name=AtomicModel.MFR,
  103. mfr_weight_dir=mfr_weight_dir,
  104. mfr_cfg_path=mfr_cfg_path,
  105. device=self.device,
  106. )
  107. # 初始化layout模型
  108. if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
  109. self.layout_model = atom_model_manager.get_atom_model(
  110. atom_model_name=AtomicModel.Layout,
  111. layout_model_name=MODEL_NAME.LAYOUTLMv3,
  112. layout_weights=str(
  113. os.path.join(
  114. models_dir, self.configs['weights'][self.layout_model_name]
  115. )
  116. ),
  117. layout_config_file=str(
  118. os.path.join(
  119. model_config_dir, 'layoutlmv3', 'layoutlmv3_base_inference.yaml'
  120. )
  121. ),
  122. device=self.device,
  123. )
  124. elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
  125. self.layout_model = atom_model_manager.get_atom_model(
  126. atom_model_name=AtomicModel.Layout,
  127. layout_model_name=MODEL_NAME.DocLayout_YOLO,
  128. doclayout_yolo_weights=str(
  129. os.path.join(
  130. models_dir, self.configs['weights'][self.layout_model_name]
  131. )
  132. ),
  133. device=self.device,
  134. )
  135. # 初始化ocr
  136. self.ocr_model = atom_model_manager.get_atom_model(
  137. atom_model_name=AtomicModel.OCR,
  138. ocr_show_log=show_log,
  139. det_db_box_thresh=0.3,
  140. lang=self.lang
  141. )
  142. # init table model
  143. if self.apply_table:
  144. table_model_dir = self.configs['weights'][self.table_model_name]
  145. self.table_model = atom_model_manager.get_atom_model(
  146. atom_model_name=AtomicModel.Table,
  147. table_model_name=self.table_model_name,
  148. table_model_path=str(os.path.join(models_dir, table_model_dir)),
  149. table_max_time=self.table_max_time,
  150. device=self.device,
  151. )
  152. logger.info('DocAnalysis init done!')
  153. def __call__(self, image):
  154. pil_img = Image.fromarray(image)
  155. width, height = pil_img.size
  156. # logger.info(f'width: {width}, height: {height}')
  157. # layout检测
  158. layout_start = time.time()
  159. layout_res = []
  160. if self.layout_model_name == MODEL_NAME.LAYOUTLMv3:
  161. # layoutlmv3
  162. layout_res = self.layout_model(image, ignore_catids=[])
  163. elif self.layout_model_name == MODEL_NAME.DocLayout_YOLO:
  164. # doclayout_yolo
  165. if height > width:
  166. input_res = {"poly":[0,0,width,0,width,height,0,height]}
  167. new_image, useful_list = crop_img(input_res, pil_img, crop_paste_x=width//2, crop_paste_y=0)
  168. paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
  169. layout_res = self.layout_model.predict(new_image)
  170. for res in layout_res:
  171. p1, p2, p3, p4, p5, p6, p7, p8 = res['poly']
  172. p1 = p1 - paste_x + xmin
  173. p2 = p2 - paste_y + ymin
  174. p3 = p3 - paste_x + xmin
  175. p4 = p4 - paste_y + ymin
  176. p5 = p5 - paste_x + xmin
  177. p6 = p6 - paste_y + ymin
  178. p7 = p7 - paste_x + xmin
  179. p8 = p8 - paste_y + ymin
  180. res['poly'] = [p1, p2, p3, p4, p5, p6, p7, p8]
  181. else:
  182. layout_res = self.layout_model.predict(image)
  183. layout_cost = round(time.time() - layout_start, 2)
  184. logger.info(f'layout detection time: {layout_cost}')
  185. if self.apply_formula:
  186. # 公式检测
  187. mfd_start = time.time()
  188. mfd_res = self.mfd_model.predict(image)
  189. logger.info(f'mfd time: {round(time.time() - mfd_start, 2)}')
  190. # 公式识别
  191. mfr_start = time.time()
  192. formula_list = self.mfr_model.predict(mfd_res, image)
  193. layout_res.extend(formula_list)
  194. mfr_cost = round(time.time() - mfr_start, 2)
  195. logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
  196. # 清理显存
  197. clean_vram(self.device, vram_threshold=8)
  198. # 从layout_res中获取ocr区域、表格区域、公式区域
  199. ocr_res_list, table_res_list, single_page_mfdetrec_res = (
  200. get_res_list_from_layout_res(layout_res)
  201. )
  202. # ocr识别
  203. ocr_start = time.time()
  204. # Process each area that requires OCR processing
  205. for res in ocr_res_list:
  206. new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
  207. adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(single_page_mfdetrec_res, useful_list)
  208. # OCR recognition
  209. new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
  210. if self.apply_ocr:
  211. ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
  212. else:
  213. ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res, rec=False)[0]
  214. # Integration results
  215. if ocr_res:
  216. ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
  217. layout_res.extend(ocr_result_list)
  218. ocr_cost = round(time.time() - ocr_start, 2)
  219. if self.apply_ocr:
  220. logger.info(f"ocr time: {ocr_cost}")
  221. else:
  222. logger.info(f"det time: {ocr_cost}")
  223. # 表格识别 table recognition
  224. if self.apply_table:
  225. table_start = time.time()
  226. for res in table_res_list:
  227. new_image, _ = crop_img(res, pil_img)
  228. single_table_start_time = time.time()
  229. html_code = None
  230. if self.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
  231. with torch.no_grad():
  232. table_result = self.table_model.predict(new_image, 'html')
  233. if len(table_result) > 0:
  234. html_code = table_result[0]
  235. elif self.table_model_name == MODEL_NAME.TABLE_MASTER:
  236. html_code = self.table_model.img2html(new_image)
  237. elif self.table_model_name == MODEL_NAME.RAPID_TABLE:
  238. html_code, table_cell_bboxes, elapse = self.table_model.predict(
  239. new_image
  240. )
  241. run_time = time.time() - single_table_start_time
  242. if run_time > self.table_max_time:
  243. logger.warning(
  244. f'table recognition processing exceeds max time {self.table_max_time}s'
  245. )
  246. # 判断是否返回正常
  247. if html_code:
  248. expected_ending = html_code.strip().endswith(
  249. '</html>'
  250. ) or html_code.strip().endswith('</table>')
  251. if expected_ending:
  252. res['html'] = html_code
  253. else:
  254. logger.warning(
  255. 'table recognition processing fails, not found expected HTML table end'
  256. )
  257. else:
  258. logger.warning(
  259. 'table recognition processing fails, not get html return'
  260. )
  261. logger.info(f'table time: {round(time.time() - table_start, 2)}')
  262. return layout_res