pdf_extract_kit.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import os
  2. import numpy as np
  3. import yaml
  4. from ultralytics import YOLO
  5. from loguru import logger
  6. from magic_pdf.model.pek_sub_modules.layoutlmv3.model_init import Layoutlmv3_Predictor
  7. from unimernet.common.config import Config
  8. import unimernet.tasks as tasks
  9. from unimernet.processors import load_processor
  10. import argparse
  11. from torchvision import transforms
  12. from magic_pdf.model.pek_sub_modules.self_modify import ModifiedPaddleOCR
  13. def layout_model_init(weight, config_file):
  14. model = Layoutlmv3_Predictor(weight, config_file)
  15. return model
  16. def mfr_model_init(weight_dir, cfg_path, device='cpu'):
  17. args = argparse.Namespace(cfg_path=cfg_path, options=None)
  18. cfg = Config(args)
  19. cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.bin")
  20. cfg.config.model.model_config.model_name = weight_dir
  21. cfg.config.model.tokenizer_config.path = weight_dir
  22. task = tasks.setup_task(cfg)
  23. model = task.build_model(cfg)
  24. model = model.to(device)
  25. vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
  26. return model, vis_processor
  27. class CustomPEKModel:
  28. def __init__(self, ocr: bool = False, show_log: bool = False, **kwargs):
  29. """
  30. ======== model init ========
  31. """
  32. # 获取当前文件(即 pdf_extract_kit.py)的绝对路径
  33. current_file_path = os.path.abspath(__file__)
  34. # 获取当前文件所在的目录(model)
  35. current_dir = os.path.dirname(current_file_path)
  36. # 上一级目录(magic_pdf)
  37. root_dir = os.path.dirname(current_dir)
  38. # model_config目录
  39. model_config_dir = os.path.join(root_dir, 'resources', 'model_config')
  40. # 构建 model_configs.yaml 文件的完整路径
  41. config_path = os.path.join(model_config_dir, 'model_configs.yaml')
  42. with open(config_path, "r") as f:
  43. self.configs = yaml.load(f, Loader=yaml.FullLoader)
  44. # 初始化解析配置
  45. self.apply_layout = kwargs.get("apply_layout", self.configs["config"]["layout"])
  46. self.apply_formula = kwargs.get("apply_formula", self.configs["config"]["formula"])
  47. self.apply_ocr = ocr
  48. logger.info(
  49. "DocAnalysis init, this may take some times. apply_layout: {}, apply_formula: {}, apply_ocr: {}".format(
  50. self.apply_layout, self.apply_formula, self.apply_ocr
  51. )
  52. )
  53. assert self.apply_layout, "DocAnalysis must contain layout model."
  54. # 初始化解析方案
  55. self.device = self.configs["config"]["device"]
  56. logger.info("using device: {}".format(self.device))
  57. # 初始化layout模型
  58. self.layout_model = layout_model_init(
  59. os.path.join(root_dir, self.configs['weights']['layout']),
  60. os.path.join(model_config_dir, "layoutlmv3", "layoutlmv3_base_inference.yaml")
  61. )
  62. # 初始化公式识别
  63. if self.apply_formula:
  64. # 初始化公式检测模型
  65. self.mfd_model = YOLO(model=str(os.path.join(root_dir, self.configs["weights"]["mfd"])))
  66. # 初始化公式解析模型
  67. mfr_config_path = os.path.join(model_config_dir, 'UniMERNet', 'demo.yaml')
  68. self.mfr_model, mfr_vis_processors = mfr_model_init(
  69. os.path.join(root_dir, self.configs["weights"]["mfr"]), mfr_config_path,
  70. device=self.device)
  71. self.mfr_transform = transforms.Compose([mfr_vis_processors, ])
  72. # 初始化ocr
  73. if self.apply_ocr:
  74. self.ocr_model = ModifiedPaddleOCR(show_log=show_log)
  75. logger.info('DocAnalysis init done!')
  76. def __call__(self, image):
  77. pass