pdf_extract_kit.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421
  1. from loguru import logger
  2. import os
  3. import time
  4. from magic_pdf.libs.Constants import *
  5. from magic_pdf.libs.clean_memory import clean_memory
  6. from magic_pdf.model.model_list import AtomicModel
  7. os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
  8. try:
  9. import cv2
  10. import yaml
  11. import argparse
  12. import numpy as np
  13. import torch
  14. import torchtext
  15. if torchtext.__version__ >= "0.18.0":
  16. torchtext.disable_torchtext_deprecation_warning()
  17. from PIL import Image
  18. from torchvision import transforms
  19. from torch.utils.data import Dataset, DataLoader
  20. from ultralytics import YOLO
  21. from unimernet.common.config import Config
  22. import unimernet.tasks as tasks
  23. from unimernet.processors import load_processor
  24. except ImportError as e:
  25. logger.exception(e)
  26. logger.error(
  27. 'Required dependency not installed, please install by \n'
  28. '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
  29. exit(1)
  30. from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
  31. from magic_pdf.model.pek_sub_modules.post_process import latex_rm_whitespace
  32. from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
  33. from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
  34. from magic_pdf.model.ppTableModel import ppTableModel
  35. def table_model_init(table_model_type, model_path, max_time, _device_='cpu'):
  36. if table_model_type == STRUCT_EQTABLE:
  37. table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
  38. else:
  39. config = {
  40. "model_dir": model_path,
  41. "device": _device_
  42. }
  43. table_model = ppTableModel(config)
  44. return table_model
  45. def mfd_model_init(weight):
  46. mfd_model = YOLO(weight)
  47. return mfd_model
  48. def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
  49. args = argparse.Namespace(cfg_path=cfg_path, options=None)
  50. cfg = Config(args)
  51. cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
  52. cfg.config.model.model_config.model_name = weight_dir
  53. cfg.config.model.tokenizer_config.path = weight_dir
  54. task = tasks.setup_task(cfg)
  55. model = task.build_model(cfg)
  56. model.to(_device_)
  57. model.eval()
  58. vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
  59. mfr_transform = transforms.Compose([vis_processor, ])
  60. return [model, mfr_transform]
  61. def layout_model_init(weight, config_file, device):
  62. model = Layoutlmv3_Predictor(weight, config_file, device)
  63. return model
  64. def ocr_model_init(show_log: bool = False, det_db_box_thresh=0.3, lang=None):
  65. if lang is not None:
  66. model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh, lang=lang)
  67. else:
  68. model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=det_db_box_thresh)
  69. return model
  70. class MathDataset(Dataset):
  71. def __init__(self, image_paths, transform=None):
  72. self.image_paths = image_paths
  73. self.transform = transform
  74. def __len__(self):
  75. return len(self.image_paths)
  76. def __getitem__(self, idx):
  77. # if not pil image, then convert to pil image
  78. if isinstance(self.image_paths[idx], str):
  79. raw_image = Image.open(self.image_paths[idx])
  80. else:
  81. raw_image = self.image_paths[idx]
  82. if self.transform:
  83. image = self.transform(raw_image)
  84. return image
  85. class AtomModelSingleton:
  86. _instance = None
  87. _models = {}
  88. def __new__(cls, *args, **kwargs):
  89. if cls._instance is None:
  90. cls._instance = super().__new__(cls)
  91. return cls._instance
  92. def get_atom_model(self, atom_model_name: str, **kwargs):
  93. if atom_model_name not in self._models:
  94. self._models[atom_model_name] = atom_model_init(model_name=atom_model_name, **kwargs)
  95. return self._models[atom_model_name]
  96. def atom_model_init(model_name: str, **kwargs):
  97. if model_name == AtomicModel.Layout:
  98. atom_model = layout_model_init(
  99. kwargs.get("layout_weights"),
  100. kwargs.get("layout_config_file"),
  101. kwargs.get("device")
  102. )
  103. elif model_name == AtomicModel.MFD:
  104. atom_model = mfd_model_init(
  105. kwargs.get("mfd_weights")
  106. )
  107. elif model_name == AtomicModel.MFR:
  108. atom_model = mfr_model_init(
  109. kwargs.get("mfr_weight_dir"),
  110. kwargs.get("mfr_cfg_path"),
  111. kwargs.get("device")
  112. )
  113. elif model_name == AtomicModel.OCR:
  114. atom_model = ocr_model_init(
  115. kwargs.get("ocr_show_log"),
  116. kwargs.get("det_db_box_thresh"),
  117. kwargs.get("lang")
  118. )
  119. elif model_name == AtomicModel.Table:
  120. atom_model = table_model_init(
  121. kwargs.get("table_model_type"),
  122. kwargs.get("table_model_path"),
  123. kwargs.get("table_max_time"),
  124. kwargs.get("device")
  125. )
  126. else:
  127. logger.error("model name not allow")
  128. exit(1)
  129. return atom_model
  130. # Unified crop img logic
  131. def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
  132. crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
  133. crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
  134. # Create a white background with an additional width and height of 50
  135. crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
  136. crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
  137. return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
  138. # Crop image
  139. crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
  140. cropped_img = input_pil_img.crop(crop_box)
  141. return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
  142. return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
  143. return return_image, return_list
  144. class CustomPEKModel:
  145. def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
  146. """
  147. ======== model init ========
  148. """
  149. # 获取当前文件(即 pdf_extract_kit.py)的绝对路径
  150. current_file_path = os.path.abspath(__file__)
  151. # 获取当前文件所在的目录(model)
  152. current_dir = os.path.dirname(current_file_path)
  153. # 上一级目录(magic_pdf)
  154. root_dir = os.path.dirname(current_dir)
  155. # model_config目录
  156. model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
  157. # 构建 model_configs.yaml 文件的完整路径
  158. config_path = os.path.join(model_config_dir, 'model_configs.yaml')
  159. with open(config_path, "r", encoding='utf-8') as f:
  160. self.configs = yaml.load(f, Loader=yaml.FullLoader)
  161. # 初始化解析配置
  162. self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
  163. self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
  164. # table config
  165. self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
  166. self.apply_table = self.table_config.get("is_table_recog_enable", False)
  167. self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
  168. self.table_model_type = self.table_config.get("model", TABLE_MASTER)
  169. self.apply_ocr = ocr
  170. self.lang = kwargs.get("lang", None)
  171. logger.info(
  172. "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}, lang: {}".format(
  173. self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table, self.lang
  174. )
  175. )
  176. assert self.apply_layout, "DocAnalysis must contain layout model."
  177. # 初始化解析方案
  178. self.device = kwargs.get("device", self.configs["config"]["device"])
  179. logger.info("using device: {}".format(self.device))
  180. models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
  181. logger.info("using models_dir: {}".format(models_dir))
  182. atom_model_manager = AtomModelSingleton()
  183. # 初始化公式识别
  184. if self.apply_formula:
  185. # 初始化公式检测模型
  186. # self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
  187. self.mfd_model = atom_model_manager.get_atom_model(
  188. atom_model_name=AtomicModel.MFD,
  189. mfd_weights=str(os.path.join(models_dir, self.configs["weights"]["mfd"]))
  190. )
  191. # 初始化公式解析模型
  192. mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
  193. mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
  194. # self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
  195. # self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
  196. self.mfr_model, self.mfr_transform = atom_model_manager.get_atom_model(
  197. atom_model_name=AtomicModel.MFR,
  198. mfr_weight_dir=mfr_weight_dir,
  199. mfr_cfg_path=mfr_cfg_path,
  200. device=self.device
  201. )
  202. # 初始化layout模型
  203. # self.layout_model = Layoutlmv3_Predictor(
  204. # str(os.path.join(models_dir, self.configs['weights']['layout'])),
  205. # str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
  206. # device=self.device
  207. # )
  208. self.layout_model = atom_model_manager.get_atom_model(
  209. atom_model_name=AtomicModel.Layout,
  210. layout_weights=str(os.path.join(models_dir, self.configs['weights']['layout'])),
  211. layout_config_file=str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
  212. device=self.device
  213. )
  214. # 初始化ocr
  215. if self.apply_ocr:
  216. # self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
  217. self.ocr_model = atom_model_manager.get_atom_model(
  218. atom_model_name=AtomicModel.OCR,
  219. ocr_show_log=show_log,
  220. det_db_box_thresh=0.3,
  221. lang=self.lang
  222. )
  223. # init table model
  224. if self.apply_table:
  225. table_model_dir = self.configs["weights"][self.table_model_type]
  226. # self.table_model = table_model_init(self.table_model_type, str(os.path.join(models_dir, table_model_dir)),
  227. # max_time=self.table_max_time, _device_=self.device)
  228. self.table_model = atom_model_manager.get_atom_model(
  229. atom_model_name=AtomicModel.Table,
  230. table_model_type=self.table_model_type,
  231. table_model_path=str(os.path.join(models_dir, table_model_dir)),
  232. table_max_time=self.table_max_time,
  233. device=self.device
  234. )
  235. logger.info('DocAnalysis init done!')
  236. def __call__(self, image):
  237. latex_filling_list = []
  238. mf_image_list = []
  239. # layout检测
  240. layout_start = time.time()
  241. layout_res = self.layout_model(image, ignore_catids=[])
  242. layout_cost = round(time.time() - layout_start, 2)
  243. logger.info(f"layout detection cost: {layout_cost}")
  244. pil_img = Image.fromarray(image)
  245. if self.apply_formula:
  246. # 公式检测
  247. mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
  248. for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
  249. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  250. new_item = {
  251. 'category_id': 13 + int(cla.item()),
  252. 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  253. 'score': round(float(conf.item()), 2),
  254. 'latex': '',
  255. }
  256. layout_res.append(new_item)
  257. latex_filling_list.append(new_item)
  258. # bbox_img = get_croped_image(pil_img, [xmin, ymin, xmax, ymax])
  259. bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
  260. mf_image_list.append(bbox_img)
  261. # 公式识别
  262. mfr_start = time.time()
  263. dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
  264. dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
  265. mfr_res = []
  266. for mf_img in dataloader:
  267. mf_img = mf_img.to(self.device)
  268. output = self.mfr_model.generate({'image': mf_img})
  269. mfr_res.extend(output['pred_str'])
  270. for res, latex in zip(latex_filling_list, mfr_res):
  271. res['latex'] = latex_rm_whitespace(latex)
  272. mfr_cost = round(time.time() - mfr_start, 2)
  273. logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
  274. # Select regions for OCR / formula regions / table regions
  275. ocr_res_list = []
  276. table_res_list = []
  277. single_page_mfdetrec_res = []
  278. for res in layout_res:
  279. if int(res['category_id']) in [13, 14]:
  280. single_page_mfdetrec_res.append({
  281. "bbox": [int(res['poly'][0]), int(res['poly'][1]),
  282. int(res['poly'][4]), int(res['poly'][5])],
  283. })
  284. elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
  285. ocr_res_list.append(res)
  286. elif int(res['category_id']) in [5]:
  287. table_res_list.append(res)
  288. clean_memory()
  289. # ocr识别
  290. if self.apply_ocr:
  291. ocr_start = time.time()
  292. # Process each area that requires OCR processing
  293. for res in ocr_res_list:
  294. new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
  295. paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
  296. # Adjust the coordinates of the formula area
  297. adjusted_mfdetrec_res = []
  298. for mf_res in single_page_mfdetrec_res:
  299. mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
  300. # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
  301. x0 = mf_xmin - xmin + paste_x
  302. y0 = mf_ymin - ymin + paste_y
  303. x1 = mf_xmax - xmin + paste_x
  304. y1 = mf_ymax - ymin + paste_y
  305. # Filter formula blocks outside the graph
  306. if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
  307. continue
  308. else:
  309. adjusted_mfdetrec_res.append({
  310. "bbox": [x0, y0, x1, y1],
  311. })
  312. # OCR recognition
  313. new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
  314. ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
  315. # Integration results
  316. if ocr_res:
  317. for box_ocr_res in ocr_res:
  318. p1, p2, p3, p4 = box_ocr_res[0]
  319. text, score = box_ocr_res[1]
  320. # Convert the coordinates back to the original coordinate system
  321. p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
  322. p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
  323. p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
  324. p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
  325. layout_res.append({
  326. 'category_id': 15,
  327. 'poly': p1 + p2 + p3 + p4,
  328. 'score': round(score, 2),
  329. 'text': text,
  330. })
  331. ocr_cost = round(time.time() - ocr_start, 2)
  332. logger.info(f"ocr cost: {ocr_cost}")
  333. # 表格识别 table recognition
  334. if self.apply_table:
  335. table_start = time.time()
  336. for res in table_res_list:
  337. new_image, _ = crop_img(res, pil_img)
  338. single_table_start_time = time.time()
  339. logger.info("------------------table recognition processing begins-----------------")
  340. latex_code = None
  341. html_code = None
  342. if self.table_model_type == STRUCT_EQTABLE:
  343. with torch.no_grad():
  344. latex_code = self.table_model.image2latex(new_image)[0]
  345. else:
  346. html_code = self.table_model.img2html(new_image)
  347. run_time = time.time() - single_table_start_time
  348. logger.info(f"------------table recognition processing ends within {run_time}s-----")
  349. if run_time > self.table_max_time:
  350. logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
  351. # 判断是否返回正常
  352. if latex_code:
  353. expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith(
  354. 'end{table}')
  355. if expected_ending:
  356. res["latex"] = latex_code
  357. else:
  358. logger.warning(f"------------table recognition processing fails----------")
  359. elif html_code:
  360. res["html"] = html_code
  361. else:
  362. logger.warning(f"------------table recognition processing fails----------")
  363. table_cost = round(time.time() - table_start, 2)
  364. logger.info(f"table cost: {table_cost}")
  365. return layout_res