# pdf_extract_kit.py

  1. from loguru import logger
  2. import os
  3. import time
  4. try:
  5. import cv2
  6. import yaml
  7. import argparse
  8. import numpy as np
  9. import torch
  10. from paddleocr import draw_ocr
  11. from PIL import Image
  12. from torchvision import transforms
  13. from torch.utils.data import Dataset, DataLoader
  14. from ultralytics import YOLO
  15. from unimernet.common.config import Config
  16. import unimernet.tasks as tasks
  17. from unimernet.processors import load_processor
  18. from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
  19. from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
  20. from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
  21. except ImportError as e:
  22. logger.exception(e)
  23. logger.error('Required dependency not installed, please install by \n"pip install magic-pdf[full-cpu] detectron2 --extra-index-url https://myhloli.github.io/wheels/"')
  24. exit(1)
  25. def mfd_model_init(weight):
  26. mfd_model = YOLO(weight)
  27. return mfd_model
  28. def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
  29. args = argparse.Namespace(cfg_path=cfg_path, options=None)
  30. cfg = Config(args)
  31. cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
  32. cfg.config.model.model_config.model_name = weight_dir
  33. cfg.config.model.tokenizer_config.path = weight_dir
  34. task = tasks.setup_task(cfg)
  35. model = task.build_model(cfg)
  36. model = model.to(_device_)
  37. vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
  38. return model, vis_processor
  39. def layout_model_init(weight, config_file, device):
  40. model = Layoutlmv3_Predictor(weight, config_file, device)
  41. return model
  42. class MathDataset(Dataset):
  43. def __init__(self, image_paths, transform=None):
  44. self.image_paths = image_paths
  45. self.transform = transform
  46. def __len__(self):
  47. return len(self.image_paths)
  48. def __getitem__(self, idx):
  49. # if not pil image, then convert to pil image
  50. if isinstance(self.image_paths[idx], str):
  51. raw_image = Image.open(self.image_paths[idx])
  52. else:
  53. raw_image = self.image_paths[idx]
  54. if self.transform:
  55. image = self.transform(raw_image)
  56. return image
  57. class CustomPEKModel:
  58. def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
  59. """
  60. ======== model init ========
  61. """
  62. # 获取当前文件(即 pdf_extract_kit.py)的绝对路径
  63. current_file_path = os.path.abspath(__file__)
  64. # 获取当前文件所在的目录(model)
  65. current_dir = os.path.dirname(current_file_path)
  66. # 上一级目录(magic_pdf)
  67. root_dir = os.path.dirname(current_dir)
  68. # model_config目录
  69. model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
  70. # 构建 model_configs.yaml 文件的完整路径
  71. config_path = os.path.join(model_config_dir, 'model_configs.yaml')
  72. with open(config_path, "r") as f:
  73. self.configs = yaml.load(f, Loader=yaml.FullLoader)
  74. # 初始化解析配置
  75. self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
  76. self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
  77. self.apply_ocr = ocr
  78. logger.info(
  79. "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
  80. self.apply_layout, self.apply_formula, self.apply_ocr
  81. )
  82. )
  83. assert self.apply_layout, "DocAnalysis must contain layout model."
  84. # 初始化解析方案
  85. self.device = kwargs.get("device", self.configs["config"]["device"])
  86. logger.info("using device: {}".format(self.device))
  87. models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
  88. # 初始化公式识别
  89. if self.apply_formula:
  90. # 初始化公式检测模型
  91. self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
  92. # 初始化公式解析模型
  93. mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
  94. mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
  95. self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
  96. self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
  97. # 初始化layout模型
  98. self.layout_model = Layoutlmv3_Predictor(
  99. str(os.path.join(models_dir, self.configs['weights']['layout'])),
  100. str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
  101. device=self.device
  102. )
  103. # 初始化ocr
  104. if self.apply_ocr:
  105. self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
  106. logger.info('DocAnalysis init done!')
  107. def __call__(self, image):
  108. latex_filling_list = []
  109. mf_image_list = []
  110. # layout检测
  111. layout_start = time.time()
  112. layout_res = self.layout_model(image, ignore_catids=[])
  113. layout_cost = round(time.time() - layout_start, 2)
  114. logger.info(f"layout detection cost: {layout_cost}")
  115. # 公式检测
  116. mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
  117. for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
  118. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  119. new_item = {
  120. 'category_id': 13 + int(cla.item()),
  121. 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  122. 'score': round(float(conf.item()), 2),
  123. 'latex': '',
  124. }
  125. layout_res.append(new_item)
  126. latex_filling_list.append(new_item)
  127. bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
  128. mf_image_list.append(bbox_img)
  129. # 公式识别
  130. mfr_start = time.time()
  131. dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
  132. dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
  133. mfr_res = []
  134. for mf_img in dataloader:
  135. mf_img = mf_img.to(self.device)
  136. output = self.mfr_model.generate({'image': mf_img})
  137. mfr_res.extend(output['pred_str'])
  138. for res, latex in zip(latex_filling_list, mfr_res):
  139. res['latex'] = latex_rm_whitespace(latex)
  140. mfr_cost = round(time.time() - mfr_start, 2)
  141. logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
  142. # ocr识别
  143. if self.apply_ocr:
  144. ocr_start = time.time()
  145. pil_img = Image.fromarray(image)
  146. single_page_mfdetrec_res = []
  147. for res in layout_res:
  148. if int(res['category_id']) in [13, 14]:
  149. xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
  150. xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
  151. single_page_mfdetrec_res.append({
  152. "bbox": [xmin, ymin, xmax, ymax],
  153. })
  154. for res in layout_res:
  155. if int(res['category_id']) in [0, 1, 2, 4, 6, 7]: # 需要进行ocr的类别
  156. xmin, ymin = int(res['poly'][0]), int(res['poly'][1])
  157. xmax, ymax = int(res['poly'][4]), int(res['poly'][5])
  158. crop_box = (xmin, ymin, xmax, ymax)
  159. cropped_img = Image.new('RGB', pil_img.size, 'white')
  160. cropped_img.paste(pil_img.crop(crop_box), crop_box)
  161. cropped_img = cv2.cvtColor(np.asarray(cropped_img), cv2.COLOR_RGB2BGR)
  162. ocr_res = self.ocr_model.ocr(cropped_img, mfd_res=single_page_mfdetrec_res)[0]
  163. if ocr_res:
  164. for box_ocr_res in ocr_res:
  165. p1, p2, p3, p4 = box_ocr_res[0]
  166. text, score = box_ocr_res[1]
  167. layout_res.append({
  168. 'category_id': 15,
  169. 'poly': p1 + p2 + p3 + p4,
  170. 'score': round(score, 2),
  171. 'text': text,
  172. })
  173. ocr_cost = round(time.time() - ocr_start, 2)
  174. logger.info(f"ocr cost: {ocr_cost}")
  175. return layout_res