Ver Fonte

refactor(magic_pdf): support mps device and optimize image processing

- Add support for Apple Silicon (M1 and later) chips via the mps device
- Refactor image processing for better performance and compatibility
- Update model loading and inference for various devices
- Adjust batch processing and memory management
myhloli há 8 meses atrás
pai
commit
af27c0cc81

+ 7 - 6
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -256,27 +256,28 @@ def may_batch_image_analyze(
     batch_ratio = 1
     batch_ratio = 1
     device = get_device()
     device = get_device()
 
 
-    npu_support = False
     if str(device).startswith('npu'):
     if str(device).startswith('npu'):
         import torch_npu
         import torch_npu
         if torch_npu.npu.is_available():
         if torch_npu.npu.is_available():
-            npu_support = True
             torch.npu.set_compile_mode(jit_compile=False)
             torch.npu.set_compile_mode(jit_compile=False)
 
 
-    if torch.cuda.is_available() and device != 'cpu' or npu_support:
+    if str(device).startswith('npu') or str(device).startswith('cuda'):
         gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
         gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
-        if gpu_memory is not None and gpu_memory >= 8:
+        if gpu_memory is not None:
             if gpu_memory >= 20:
             if gpu_memory >= 20:
                 batch_ratio = 16
                 batch_ratio = 16
             elif gpu_memory >= 15:
             elif gpu_memory >= 15:
                 batch_ratio = 8
                 batch_ratio = 8
             elif gpu_memory >= 10:
             elif gpu_memory >= 10:
                 batch_ratio = 4
                 batch_ratio = 4
-            else:
+            elif gpu_memory >= 7:
                 batch_ratio = 2
                 batch_ratio = 2
-
+            else:
+                batch_ratio = 1
             logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
             logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
             batch_analyze = True
             batch_analyze = True
+    elif str(device).startswith('mps'):
+        batch_analyze = True
     doc_analyze_start = time.time()
     doc_analyze_start = time.time()
 
 
     if batch_analyze:
     if batch_analyze:

+ 1 - 1
magic_pdf/model/pdf_extract_kit.py

@@ -118,7 +118,7 @@ class CustomPEKModel:
                 atom_model_name=AtomicModel.MFR,
                 atom_model_name=AtomicModel.MFR,
                 mfr_weight_dir=mfr_weight_dir,
                 mfr_weight_dir=mfr_weight_dir,
                 mfr_cfg_path=mfr_cfg_path,
                 mfr_cfg_path=mfr_cfg_path,
-                device='cpu' if str(self.device).startswith("mps") else self.device,
+                device=self.device,
             )
             )
 
 
         # 初始化layout模型
         # 初始化layout模型

+ 0 - 2
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py

@@ -44,7 +44,6 @@ def split_images(image, result_images=None):
             # 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
             # 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
             if x + new_long_side > width:
             if x + new_long_side > width:
                 continue
                 continue
-            box = (x, 0, x + new_long_side, height)
             sub_image = image[0:height, x:x + new_long_side]
             sub_image = image[0:height, x:x + new_long_side]
             sub_images.append(sub_image)
             sub_images.append(sub_image)
     else:  # 如果高度是较长边
     else:  # 如果高度是较长边
@@ -52,7 +51,6 @@ def split_images(image, result_images=None):
             # 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
             # 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
             if y + new_long_side > height:
             if y + new_long_side > height:
                 continue
                 continue
-            box = (0, y, width, y + new_long_side)
             sub_image = image[y:y + new_long_side, 0:width]
             sub_image = image[y:y + new_long_side, 0:width]
             sub_images.append(sub_image)
             sub_images.append(sub_image)
 
 

+ 2 - 0
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py

@@ -4,6 +4,8 @@ from doclayout_yolo import YOLOv10
 class DocLayoutYOLOModel(object):
 class DocLayoutYOLOModel(object):
     def __init__(self, weight, device):
     def __init__(self, weight, device):
         self.model = YOLOv10(weight)
         self.model = YOLOv10(weight)
+        if not device.startswith("cpu"):
+            self.model.half()
         self.device = device
         self.device = device
 
 
     def predict(self, image):
     def predict(self, image):

+ 2 - 0
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py

@@ -4,6 +4,8 @@ from ultralytics import YOLO
 class YOLOv8MFDModel(object):
 class YOLOv8MFDModel(object):
     def __init__(self, weight, device="cpu"):
     def __init__(self, weight, device="cpu"):
         self.mfd_model = YOLO(weight)
         self.mfd_model = YOLO(weight)
+        if not device.startswith("cpu"):
+            self.mfd_model.half()
         self.device = device
         self.device = device
 
 
     def predict(self, image):
     def predict(self, image):

+ 21 - 47
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

@@ -1,13 +1,5 @@
-import argparse
-import os
-import re
-
 import torch
 import torch
-import unimernet.tasks as tasks
 from torch.utils.data import DataLoader, Dataset
 from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
-from unimernet.common.config import Config
-from unimernet.processors import load_processor
 
 
 
 
 class MathDataset(Dataset):
 class MathDataset(Dataset):
@@ -18,46 +10,26 @@ class MathDataset(Dataset):
     def __len__(self):
     def __len__(self):
         return len(self.image_paths)
         return len(self.image_paths)
 
 
-
-def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code."""
-    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
-    letter = "[a-zA-Z]"
-    noletter = "[\W_^\d]"
-    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
-    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
-    news = s
-    while True:
-        s = news
-        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
-        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
-        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
-        if news == s:
-            break
-    return s
+    def __getitem__(self, idx):
+        raw_image = self.image_paths[idx]
+        if self.transform:
+            image = self.transform(raw_image)
+            return image
 
 
 
 
 class UnimernetModel(object):
 class UnimernetModel(object):
     def __init__(self, weight_dir, cfg_path, _device_="cpu"):
     def __init__(self, weight_dir, cfg_path, _device_="cpu"):
-        args = argparse.Namespace(cfg_path=cfg_path, options=None)
-        cfg = Config(args)
-        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
-        cfg.config.model.model_config.model_name = weight_dir
-        cfg.config.model.tokenizer_config.path = weight_dir
-        task = tasks.setup_task(cfg)
-        self.model = task.build_model(cfg)
+        from .unimernet_hf import UnimernetModel
+        if _device_.startswith("mps"):
+            self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
+        else:
+            self.model = UnimernetModel.from_pretrained(weight_dir)
         self.device = _device_
         self.device = _device_
         self.model.to(_device_)
         self.model.to(_device_)
+        if not _device_.startswith("cpu"):
+            self.model = self.model.to(dtype=torch.float16)
         self.model.eval()
         self.model.eval()
-        vis_processor = load_processor(
-            "formula_image_eval",
-            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
-        )
-        self.mfr_transform = transforms.Compose(
-            [
-                vis_processor,
-            ]
-        )
+
 
 
     def predict(self, mfd_res, image):
     def predict(self, mfd_res, image):
         formula_list = []
         formula_list = []
@@ -76,16 +48,17 @@ class UnimernetModel(object):
             bbox_img = image[ymin:ymax, xmin:xmax]
             bbox_img = image[ymin:ymax, xmin:xmax]
             mf_image_list.append(bbox_img)
             mf_image_list.append(bbox_img)
 
 
-        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        dataset = MathDataset(mf_image_list, transform=self.model.transform)
         dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
         dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
         mfr_res = []
         mfr_res = []
         for mf_img in dataloader:
         for mf_img in dataloader:
+            mf_img = mf_img.to(dtype=self.model.dtype)
             mf_img = mf_img.to(self.device)
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
             with torch.no_grad():
                 output = self.model.generate({"image": mf_img})
                 output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["pred_str"])
+            mfr_res.extend(output["fixed_str"])
         for res, latex in zip(formula_list, mfr_res):
         for res, latex in zip(formula_list, mfr_res):
-            res["latex"] = latex_rm_whitespace(latex)
+            res["latex"] = latex
         return formula_list
         return formula_list
 
 
     def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
     def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
@@ -130,22 +103,23 @@ class UnimernetModel(object):
         index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
         index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
 
 
         # Create dataset with sorted images
         # Create dataset with sorted images
-        dataset = MathDataset(sorted_images, transform=self.mfr_transform)
+        dataset = MathDataset(sorted_images, transform=self.model.transform)
         dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
         dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
 
 
         # Process batches and store results
         # Process batches and store results
         mfr_res = []
         mfr_res = []
         for mf_img in dataloader:
         for mf_img in dataloader:
+            mf_img = mf_img.to(dtype=self.model.dtype)
             mf_img = mf_img.to(self.device)
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
             with torch.no_grad():
                 output = self.model.generate({"image": mf_img})
                 output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["pred_str"])
+            mfr_res.extend(output["fixed_str"])
 
 
         # Restore original order
         # Restore original order
         unsorted_results = [""] * len(mfr_res)
         unsorted_results = [""] * len(mfr_res)
         for new_idx, latex in enumerate(mfr_res):
         for new_idx, latex in enumerate(mfr_res):
             original_idx = index_mapping[new_idx]
             original_idx = index_mapping[new_idx]
-            unsorted_results[original_idx] = latex_rm_whitespace(latex)
+            unsorted_results[original_idx] = latex
 
 
         # Fill results back
         # Fill results back
         for res, latex in zip(backfill_list, unsorted_results):
         for res, latex in zip(backfill_list, unsorted_results):

+ 70 - 38
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py

@@ -1,17 +1,15 @@
 from transformers.image_processing_utils import BaseImageProcessor
 from transformers.image_processing_utils import BaseImageProcessor
-from PIL import Image, ImageOps
 import numpy as np
 import numpy as np
 import cv2
 import cv2
 import albumentations as alb
 import albumentations as alb
 from albumentations.pytorch import ToTensorV2
 from albumentations.pytorch import ToTensorV2
-from torchvision.transforms.functional import resize
 
 
 
 
 # TODO: dereference cv2 if possible
 # TODO: dereference cv2 if possible
 class UnimerSwinImageProcessor(BaseImageProcessor):
 class UnimerSwinImageProcessor(BaseImageProcessor):
     def __init__(
     def __init__(
             self,
             self,
-            image_size = [192, 672],
+            image_size = (192, 672),
         ):
         ):
         self.input_size = [int(_) for _ in image_size]
         self.input_size = [int(_) for _ in image_size]
         assert len(self.input_size) == 2
         assert len(self.input_size) == 2
@@ -27,56 +25,90 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
 
 
     def __call__(self, item):
     def __call__(self, item):
         image = self.prepare_input(item)
         image = self.prepare_input(item)
-        return self.transform(image=np.array(image))['image'][:1]
+        return self.transform(image=image)['image'][:1]
 
 
     @staticmethod
     @staticmethod
-    def crop_margin(img: Image.Image) -> Image.Image:
-        data = np.array(img.convert("L"))
-        data = data.astype(np.uint8)
-        max_val = data.max()
-        min_val = data.min()
-        if max_val == min_val:
+    def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
+        """Crop margins of image using NumPy operations"""
+        # Convert to grayscale if it's a color image
+        if len(img.shape) == 3 and img.shape[2] == 3:
+            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+        else:
+            gray = img.copy()
+
+        # Normalize and threshold
+        if gray.max() == gray.min():
             return img
             return img
-        data = (data - min_val) / (max_val - min_val) * 255
-        gray = 255 * (data < 200).astype(np.uint8)
 
 
-        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
-        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
-        return img.crop((a, b, w + a, h + b))
+        normalized = (((gray - gray.min()) / (gray.max() - gray.min())) * 255).astype(np.uint8)
+        binary = 255 * (normalized < 200).astype(np.uint8)
+
+        # Find bounding box
+        coords = cv2.findNonZero(binary)  # Find all non-zero points (text)
+        x, y, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
 
 
-    def prepare_input(self, img: Image.Image, random_padding: bool = False):
+        # Return cropped image
+        return img[y:y + h, x:x + w]
+
+    def prepare_input(self, img, random_padding: bool = False):
         """
         """
-        Convert PIL Image to tensor according to specified input_size after following steps below:
-            - resize
-            - rotate (if align_long_axis is True and image is not aligned longer axis with canvas)
-            - pad
+        Convert a numpy array image to a properly sized and padded image after:
+            - crop margins
+            - resize while maintaining aspect ratio
+            - pad to target size
         """
         """
         if img is None:
         if img is None:
-            return
-        # crop margins
+            return None
+
         try:
         try:
-            img = self.crop_margin(img.convert("RGB"))
-        except OSError:
+            img = self.crop_margin_numpy(img)
+        except Exception:
             # might throw an error for broken files
             # might throw an error for broken files
-            return
+            return None
+
+        if img.shape[0] == 0 or img.shape[1] == 0:
+            return None
+
+        # Resize while preserving aspect ratio
+        h, w = img.shape[:2]
+        scale = min(self.input_size[0] / h, self.input_size[1] / w)
+        new_h, new_w = int(h * scale), int(w * scale)
+        resized_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
+
+        # Calculate padding
+        pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
+
+        # Create and apply padding
+        channels = 3 if len(img.shape) == 3 else 1
+        padded_img = np.full((self.input_size[0], self.input_size[1], channels), 255, dtype=np.uint8)
+        padded_img[pad_height:pad_height + new_h, pad_width:pad_width + new_w] = resized_img
+
+        return padded_img
+
+    def _calculate_padding(self, new_w, new_h, random_padding):
+        """Calculate padding values for PIL images"""
+        delta_width = self.input_size[1] - new_w
+        delta_height = self.input_size[0] - new_h
+
+        pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
 
 
-        if img.height == 0 or img.width == 0:
-            return
+        return (
+            pad_width,
+            pad_height,
+            delta_width - pad_width,
+            delta_height - pad_height,
+        )
+
+    def _get_padding_values(self, new_w, new_h, random_padding):
+        """Get padding values based on image dimensions and padding strategy"""
+        delta_width = self.input_size[1] - new_w
+        delta_height = self.input_size[0] - new_h
 
 
-        img = resize(img, min(self.input_size))
-        img.thumbnail((self.input_size[1], self.input_size[0]))
-        delta_width = self.input_size[1] - img.width
-        delta_height = self.input_size[0] - img.height
         if random_padding:
         if random_padding:
             pad_width = np.random.randint(low=0, high=delta_width + 1)
             pad_width = np.random.randint(low=0, high=delta_width + 1)
             pad_height = np.random.randint(low=0, high=delta_height + 1)
             pad_height = np.random.randint(low=0, high=delta_height + 1)
         else:
         else:
             pad_width = delta_width // 2
             pad_width = delta_width // 2
             pad_height = delta_height // 2
             pad_height = delta_height // 2
-        padding = (
-            pad_width,
-            pad_height,
-            delta_width - pad_width,
-            delta_height - pad_height,
-        )
-        return ImageOps.expand(img, padding)
+
+        return pad_width, pad_height

+ 1 - 1
magic_pdf/pdf_parse_union_core_v2.py

@@ -492,7 +492,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
     else:
     else:
         return [[x0, y0, x1, y1]]
         return [[x0, y0, x1, y1]]
 
 
-# @measure_time
+
 def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
 def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
     page_line_list = []
     page_line_list = []
 
 

+ 1 - 1
magic_pdf/resources/model_config/model_configs.yaml

@@ -2,7 +2,7 @@ weights:
   layoutlmv3: Layout/LayoutLMv3/model_final.pth
   layoutlmv3: Layout/LayoutLMv3/model_final.pth
   doclayout_yolo: Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt
   doclayout_yolo: Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt
   yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
   yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
-  unimernet_small: MFR/unimernet_small_2501
+  unimernet_small: MFR/unimernet_hf_small_2503
   struct_eqtable: TabRec/StructEqTable
   struct_eqtable: TabRec/StructEqTable
   tablemaster: TabRec/TableMaster
   tablemaster: TabRec/TableMaster
   rapid_table: TabRec/RapidTable
   rapid_table: TabRec/RapidTable

+ 3 - 2
requirements.txt

@@ -7,7 +7,8 @@ numpy>=1.21.6,<2.0.0
 pydantic>=2.7.2
 pydantic>=2.7.2
 PyMuPDF>=1.24.9,<=1.24.14
 PyMuPDF>=1.24.9,<=1.24.14
 scikit-learn>=1.0.2
 scikit-learn>=1.0.2
-torch>=2.2.2
-transformers
+torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
+torchvision
+transformers>=4.49.0
 pdfminer.six==20231228
 pdfminer.six==20231228
 # The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.
 # The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.

+ 1 - 7
setup.py

@@ -36,25 +36,19 @@ if __name__ == '__main__':
                      "paddlepaddle==3.0.0b1;platform_system=='Linux'",
                      "paddlepaddle==3.0.0b1;platform_system=='Linux'",
                      "paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'",
                      "paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'",
                      ],
                      ],
-            "full": ["unimernet==0.2.3",  # unimernet升级0.2.3,移除torchtext/eva-decord的依赖
-                     "torch>=2.2.2,<=2.3.1",  # torch2.4.0及之后版本未测试,先卡住版本上限
-                     "torchvision>=0.17.2,<=0.18.1",  # torchvision 受torch版本约束
+            "full": [
                      "matplotlib<=3.9.0;platform_system=='Windows'",  # 3.9.1及之后不提供windows的预编译包,避免一些没有编译环境的windows设备安装失败
                      "matplotlib<=3.9.0;platform_system=='Windows'",  # 3.9.1及之后不提供windows的预编译包,避免一些没有编译环境的windows设备安装失败
                      "matplotlib;platform_system=='Linux' or platform_system=='Darwin'",  # linux 和 macos 不应限制matplotlib的最高版本,以避免无法更新导致的一些bug
                      "matplotlib;platform_system=='Linux' or platform_system=='Darwin'",  # linux 和 macos 不应限制matplotlib的最高版本,以避免无法更新导致的一些bug
                      "ultralytics>=8.3.48",  # yolov8,公式检测
                      "ultralytics>=8.3.48",  # yolov8,公式检测
                      "paddleocr==2.7.3",  # 2.8.0及2.8.1版本与detectron2有冲突,需锁定2.7.3
                      "paddleocr==2.7.3",  # 2.8.0及2.8.1版本与detectron2有冲突,需锁定2.7.3
                      "paddlepaddle==3.0.0rc1;platform_system=='Linux' or platform_system=='Darwin'",  # 解决linux的段异常问题
                      "paddlepaddle==3.0.0rc1;platform_system=='Linux' or platform_system=='Darwin'",  # 解决linux的段异常问题
                      "paddlepaddle==2.6.1;platform_system=='Windows'",  # windows版本3.0.0效率下降,需锁定2.6.1
                      "paddlepaddle==2.6.1;platform_system=='Windows'",  # windows版本3.0.0效率下降,需锁定2.6.1
-                     "struct-eqtable==0.3.2",  # 表格解析
-                     "einops",  # struct-eqtable依赖
-                     "accelerate",  # struct-eqtable依赖
                      "doclayout_yolo==0.0.2b1",  # doclayout_yolo
                      "doclayout_yolo==0.0.2b1",  # doclayout_yolo
                      "rapidocr-paddle>=1.4.5,<2.0.0",  # rapidocr-paddle
                      "rapidocr-paddle>=1.4.5,<2.0.0",  # rapidocr-paddle
                      "rapidocr_onnxruntime>=1.4.4,<2.0.0",
                      "rapidocr_onnxruntime>=1.4.4,<2.0.0",
                      "rapid_table>=1.0.3,<2.0.0",  # rapid_table
                      "rapid_table>=1.0.3,<2.0.0",  # rapid_table
                      "PyYAML",  # yaml
                      "PyYAML",  # yaml
                      "openai",  # openai SDK
                      "openai",  # openai SDK
-                     "detectron2"
                      ],
                      ],
             "old_linux":[
             "old_linux":[
                 "albumentations<=1.4.20", # 1.4.21引入的simsimd不支持2019年及更早的linux系统
                 "albumentations<=1.4.20", # 1.4.21引入的simsimd不支持2019年及更早的linux系统