pdf_extract_kit.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296
  1. from loguru import logger
  2. import os
  3. import time
  4. from magic_pdf.libs.Constants import TABLE_MAX_TIME_VALUE
  5. os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1' # 禁止albumentations检查更新
  6. try:
  7. import cv2
  8. import yaml
  9. import argparse
  10. import numpy as np
  11. import torch
  12. import torchtext
  13. if torchtext.__version__ >= "0.18.0":
  14. torchtext.disable_torchtext_deprecation_warning()
  15. from PIL import Image
  16. from torchvision import transforms
  17. from torch.utils.data import Dataset, DataLoader
  18. from ultralytics import YOLO
  19. from unimernet.common.config import Config
  20. import unimernet.tasks as tasks
  21. from unimernet.processors import load_processor
  22. except ImportError as e:
  23. logger.exception(e)
  24. logger.error(
  25. 'Required dependency not installed, please install by \n'
  26. '"pip install magic-pdf[full] --extra-index-url https://myhloli.github.io/wheels/"')
  27. exit(1)
  28. from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
  29. from magic_pdf.model.pek_sub_modules.post_process import get_croped_image, latex_rm_whitespace
  30. from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
  31. from magic_pdf.model.pek_sub_modules.structeqtable.StructTableModel import StructTableModel
  32. def table_model_init(model_path, max_time, _device_='cpu'):
  33. table_model = StructTableModel(model_path, max_time=max_time, device=_device_)
  34. return table_model
  35. def mfd_model_init(weight):
  36. mfd_model = YOLO(weight)
  37. return mfd_model
  38. def mfr_model_init(weight_dir, cfg_path, _device_='cpu'):
  39. args = argparse.Namespace(cfg_path=cfg_path, options=None)
  40. cfg = Config(args)
  41. cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
  42. cfg.config.model.model_config.model_name = weight_dir
  43. cfg.config.model.tokenizer_config.path = weight_dir
  44. task = tasks.setup_task(cfg)
  45. model = task.build_model(cfg)
  46. model = model.to(_device_)
  47. vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
  48. return model, vis_processor
  49. def layout_model_init(weight, config_file, device):
  50. model = Layoutlmv3_Predictor(weight, config_file, device)
  51. return model
  52. class MathDataset(Dataset):
  53. def __init__(self, image_paths, transform=None):
  54. self.image_paths = image_paths
  55. self.transform = transform
  56. def __len__(self):
  57. return len(self.image_paths)
  58. def __getitem__(self, idx):
  59. # if not pil image, then convert to pil image
  60. if isinstance(self.image_paths[idx], str):
  61. raw_image = Image.open(self.image_paths[idx])
  62. else:
  63. raw_image = self.image_paths[idx]
  64. if self.transform:
  65. image = self.transform(raw_image)
  66. return image
  67. class CustomPEKModel:
  68. def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
  69. """
  70. ======== model init ========
  71. """
  72. # 获取当前文件(即 pdf_extract_kit.py)的绝对路径
  73. current_file_path = os.path.abspath(__file__)
  74. # 获取当前文件所在的目录(model)
  75. current_dir = os.path.dirname(current_file_path)
  76. # 上一级目录(magic_pdf)
  77. root_dir = os.path.dirname(current_dir)
  78. # model_config目录
  79. model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
  80. # 构建 model_configs.yaml 文件的完整路径
  81. config_path = os.path.join(model_config_dir, 'model_configs.yaml')
  82. with open(config_path, "r", encoding='utf-8') as f:
  83. self.configs = yaml.load(f, Loader=yaml.FullLoader)
  84. # 初始化解析配置
  85. self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
  86. self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
  87. self.table_config = kwargs.get("table_config", self.configs["config"]["table_config"])
  88. self.apply_table = self.table_config.get("is_table_recog_enable", False)
  89. self.table_max_time = self.table_config.get("max_time", TABLE_MAX_TIME_VALUE)
  90. self.apply_ocr = ocr
  91. logger.info(
  92. "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}, apply_table: {}".format(
  93. self.apply_layout, self.apply_formula, self.apply_ocr, self.apply_table
  94. )
  95. )
  96. assert self.apply_layout, "DocAnalysis must contain layout model."
  97. # 初始化解析方案
  98. self.device = kwargs.get("device", self.configs["config"]["device"])
  99. logger.info("using device: {}".format(self.device))
  100. models_dir = kwargs.get("models_dir", os.path.join(root_dir, "resources", "models"))
  101. logger.info("using models_dir: {}".format(models_dir))
  102. # 初始化公式识别
  103. if self.apply_formula:
  104. # 初始化公式检测模型
  105. self.mfd_model = mfd_model_init(str(os.path.join(models_dir, self.configs["weights"]["mfd"])))
  106. # 初始化公式解析模型
  107. mfr_weight_dir = str(os.path.join(models_dir, self.configs["weights"]["mfr"]))
  108. mfr_cfg_path = str(os.path.join(model_config_dir, "UniMERNet", "demo.yaml"))
  109. self.mfr_model, mfr_vis_processors = mfr_model_init(mfr_weight_dir, mfr_cfg_path, _device_=self.device)
  110. self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
  111. # 初始化layout模型
  112. self.layout_model = Layoutlmv3_Predictor(
  113. str(os.path.join(models_dir, self.configs['weights']['layout'])),
  114. str(os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")),
  115. device=self.device
  116. )
  117. # 初始化ocr
  118. if self.apply_ocr:
  119. self.ocr_model = ModifiedPaddleOCR(show_log=show_log, det_db_box_thresh=0.3)
  120. # init structeqtable
  121. if self.apply_table:
  122. self.table_model = table_model_init(str(os.path.join(models_dir, self.configs["weights"]["table"])),
  123. max_time = self.table_max_time, _device_=self.device)
  124. logger.info('DocAnalysis init done!')
  125. def __call__(self, image):
  126. latex_filling_list = []
  127. mf_image_list = []
  128. # layout检测
  129. layout_start = time.time()
  130. layout_res = self.layout_model(image, ignore_catids=[])
  131. layout_cost = round(time.time() - layout_start, 2)
  132. logger.info(f"layout detection cost: {layout_cost}")
  133. if self.apply_formula:
  134. # 公式检测
  135. mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True)[0]
  136. for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
  137. xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
  138. new_item = {
  139. 'category_id': 13 + int(cla.item()),
  140. 'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
  141. 'score': round(float(conf.item()), 2),
  142. 'latex': '',
  143. }
  144. layout_res.append(new_item)
  145. latex_filling_list.append(new_item)
  146. bbox_img = get_croped_image(Image.fromarray(image), [xmin, ymin, xmax, ymax])
  147. mf_image_list.append(bbox_img)
  148. # 公式识别
  149. mfr_start = time.time()
  150. dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
  151. dataloader = DataLoader(dataset, batch_size=64, num_workers=0)
  152. mfr_res = []
  153. for mf_img in dataloader:
  154. mf_img = mf_img.to(self.device)
  155. output = self.mfr_model.generate({'image': mf_img})
  156. mfr_res.extend(output['pred_str'])
  157. for res, latex in zip(latex_filling_list, mfr_res):
  158. res['latex'] = latex_rm_whitespace(latex)
  159. mfr_cost = round(time.time() - mfr_start, 2)
  160. logger.info(f"formula nums: {len(mf_image_list)}, mfr time: {mfr_cost}")
  161. # Select regions for OCR / formula regions / table regions
  162. ocr_res_list = []
  163. table_res_list = []
  164. single_page_mfdetrec_res = []
  165. for res in layout_res:
  166. if int(res['category_id']) in [13, 14]:
  167. single_page_mfdetrec_res.append({
  168. "bbox": [int(res['poly'][0]), int(res['poly'][1]),
  169. int(res['poly'][4]), int(res['poly'][5])],
  170. })
  171. elif int(res['category_id']) in [0, 1, 2, 4, 6, 7]:
  172. ocr_res_list.append(res)
  173. elif int(res['category_id']) in [5]:
  174. table_res_list.append(res)
  175. # Unified crop img logic
  176. def crop_img(input_res, input_pil_img, crop_paste_x=0, crop_paste_y=0):
  177. crop_xmin, crop_ymin = int(input_res['poly'][0]), int(input_res['poly'][1])
  178. crop_xmax, crop_ymax = int(input_res['poly'][4]), int(input_res['poly'][5])
  179. # Create a white background with an additional width and height of 50
  180. crop_new_width = crop_xmax - crop_xmin + crop_paste_x * 2
  181. crop_new_height = crop_ymax - crop_ymin + crop_paste_y * 2
  182. return_image = Image.new('RGB', (crop_new_width, crop_new_height), 'white')
  183. # Crop image
  184. crop_box = (crop_xmin, crop_ymin, crop_xmax, crop_ymax)
  185. cropped_img = input_pil_img.crop(crop_box)
  186. return_image.paste(cropped_img, (crop_paste_x, crop_paste_y))
  187. return_list = [crop_paste_x, crop_paste_y, crop_xmin, crop_ymin, crop_xmax, crop_ymax, crop_new_width, crop_new_height]
  188. return return_image, return_list
  189. pil_img = Image.fromarray(image)
  190. # ocr识别
  191. if self.apply_ocr:
  192. ocr_start = time.time()
  193. # Process each area that requires OCR processing
  194. for res in ocr_res_list:
  195. new_image, useful_list = crop_img(res, pil_img, crop_paste_x=50, crop_paste_y=50)
  196. paste_x, paste_y, xmin, ymin, xmax, ymax, new_width, new_height = useful_list
  197. # Adjust the coordinates of the formula area
  198. adjusted_mfdetrec_res = []
  199. for mf_res in single_page_mfdetrec_res:
  200. mf_xmin, mf_ymin, mf_xmax, mf_ymax = mf_res["bbox"]
  201. # Adjust the coordinates of the formula area to the coordinates relative to the cropping area
  202. x0 = mf_xmin - xmin + paste_x
  203. y0 = mf_ymin - ymin + paste_y
  204. x1 = mf_xmax - xmin + paste_x
  205. y1 = mf_ymax - ymin + paste_y
  206. # Filter formula blocks outside the graph
  207. if any([x1 < 0, y1 < 0]) or any([x0 > new_width, y0 > new_height]):
  208. continue
  209. else:
  210. adjusted_mfdetrec_res.append({
  211. "bbox": [x0, y0, x1, y1],
  212. })
  213. # OCR recognition
  214. new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
  215. ocr_res = self.ocr_model.ocr(new_image, mfd_res=adjusted_mfdetrec_res)[0]
  216. # Integration results
  217. if ocr_res:
  218. for box_ocr_res in ocr_res:
  219. p1, p2, p3, p4 = box_ocr_res[0]
  220. text, score = box_ocr_res[1]
  221. # Convert the coordinates back to the original coordinate system
  222. p1 = [p1[0] - paste_x + xmin, p1[1] - paste_y + ymin]
  223. p2 = [p2[0] - paste_x + xmin, p2[1] - paste_y + ymin]
  224. p3 = [p3[0] - paste_x + xmin, p3[1] - paste_y + ymin]
  225. p4 = [p4[0] - paste_x + xmin, p4[1] - paste_y + ymin]
  226. layout_res.append({
  227. 'category_id': 15,
  228. 'poly': p1 + p2 + p3 + p4,
  229. 'score': round(score, 2),
  230. 'text': text,
  231. })
  232. ocr_cost = round(time.time() - ocr_start, 2)
  233. logger.info(f"ocr cost: {ocr_cost}")
  234. # 表格识别 table recognition
  235. if self.apply_table:
  236. table_start = time.time()
  237. for res in table_res_list:
  238. new_image, _ = crop_img(res, pil_img)
  239. single_table_start_time = time.time()
  240. logger.info("------------------table recognition processing begins-----------------")
  241. with torch.no_grad():
  242. latex_code = self.table_model.image2latex(new_image)[0]
  243. run_time = time.time() - single_table_start_time
  244. logger.info(f"------------table recognition processing ends within {run_time}s-----")
  245. if run_time > self.table_max_time:
  246. logger.warning(f"------------table recognition processing exceeds max time {self.table_max_time}s----------")
  247. # 判断是否返回正常
  248. expected_ending = latex_code.strip().endswith('end{tabular}') or latex_code.strip().endswith('end{table}')
  249. if latex_code and expected_ending:
  250. res["latex"] = latex_code
  251. else:
  252. logger.warning(f"------------table recognition processing fails----------")
  253. table_cost = round(time.time() - table_start, 2)
  254. logger.info(f"table cost: {table_cost}")
  255. return layout_res