Просмотр исходного кода

refactor(magic_pdf): support mps device and optimize image processing

- Add support for Apple M1 chips (mps device)
- Refactor image processing for better performance and compatibility
- Update model loading and inference for various devices
- Adjust batch processing and memory management
myhloli 8 месяцев назад
Родитель
Коммит
af27c0cc81

+ 7 - 6
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -256,27 +256,28 @@ def may_batch_image_analyze(
     batch_ratio = 1
     device = get_device()
 
-    npu_support = False
     if str(device).startswith('npu'):
         import torch_npu
         if torch_npu.npu.is_available():
-            npu_support = True
             torch.npu.set_compile_mode(jit_compile=False)
 
-    if torch.cuda.is_available() and device != 'cpu' or npu_support:
+    if str(device).startswith('npu') or str(device).startswith('cuda'):
         gpu_memory = int(os.getenv('VIRTUAL_VRAM_SIZE', round(get_vram(device))))
-        if gpu_memory is not None and gpu_memory >= 8:
+        if gpu_memory is not None:
             if gpu_memory >= 20:
                 batch_ratio = 16
             elif gpu_memory >= 15:
                 batch_ratio = 8
             elif gpu_memory >= 10:
                 batch_ratio = 4
-            else:
+            elif gpu_memory >= 7:
                 batch_ratio = 2
-
+            else:
+                batch_ratio = 1
             logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
             batch_analyze = True
+    elif str(device).startswith('mps'):
+        batch_analyze = True
     doc_analyze_start = time.time()
 
     if batch_analyze:

+ 1 - 1
magic_pdf/model/pdf_extract_kit.py

@@ -118,7 +118,7 @@ class CustomPEKModel:
                 atom_model_name=AtomicModel.MFR,
                 mfr_weight_dir=mfr_weight_dir,
                 mfr_cfg_path=mfr_cfg_path,
-                device='cpu' if str(self.device).startswith("mps") else self.device,
+                device=self.device,
             )
 
         # 初始化layout模型

+ 0 - 2
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py

@@ -44,7 +44,6 @@ def split_images(image, result_images=None):
             # 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
             if x + new_long_side > width:
                 continue
-            box = (x, 0, x + new_long_side, height)
             sub_image = image[0:height, x:x + new_long_side]
             sub_images.append(sub_image)
     else:  # 如果高度是较长边
@@ -52,7 +51,6 @@ def split_images(image, result_images=None):
             # 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
             if y + new_long_side > height:
                 continue
-            box = (0, y, width, y + new_long_side)
             sub_image = image[y:y + new_long_side, 0:width]
             sub_images.append(sub_image)
 

+ 2 - 0
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py

@@ -4,6 +4,8 @@ from doclayout_yolo import YOLOv10
 class DocLayoutYOLOModel(object):
     def __init__(self, weight, device):
         self.model = YOLOv10(weight)
+        if not device.startswith("cpu"):
+            self.model.half()
         self.device = device
 
     def predict(self, image):

+ 2 - 0
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py

@@ -4,6 +4,8 @@ from ultralytics import YOLO
 class YOLOv8MFDModel(object):
     def __init__(self, weight, device="cpu"):
         self.mfd_model = YOLO(weight)
+        if not device.startswith("cpu"):
+            self.mfd_model.half()
         self.device = device
 
     def predict(self, image):

+ 21 - 47
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

@@ -1,13 +1,5 @@
-import argparse
-import os
-import re
-
 import torch
-import unimernet.tasks as tasks
 from torch.utils.data import DataLoader, Dataset
-from torchvision import transforms
-from unimernet.common.config import Config
-from unimernet.processors import load_processor
 
 
 class MathDataset(Dataset):
@@ -18,46 +10,26 @@ class MathDataset(Dataset):
     def __len__(self):
         return len(self.image_paths)
 
-
-def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code."""
-    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
-    letter = "[a-zA-Z]"
-    noletter = "[\W_^\d]"
-    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
-    s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
-    news = s
-    while True:
-        s = news
-        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
-        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
-        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
-        if news == s:
-            break
-    return s
+    def __getitem__(self, idx):
+        raw_image = self.image_paths[idx]
+        if self.transform:
+            image = self.transform(raw_image)
+            return image
 
 
 class UnimernetModel(object):
     def __init__(self, weight_dir, cfg_path, _device_="cpu"):
-        args = argparse.Namespace(cfg_path=cfg_path, options=None)
-        cfg = Config(args)
-        cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
-        cfg.config.model.model_config.model_name = weight_dir
-        cfg.config.model.tokenizer_config.path = weight_dir
-        task = tasks.setup_task(cfg)
-        self.model = task.build_model(cfg)
+        from .unimernet_hf import UnimernetModel
+        if _device_.startswith("mps"):
+            self.model = UnimernetModel.from_pretrained(weight_dir, attn_implementation="eager")
+        else:
+            self.model = UnimernetModel.from_pretrained(weight_dir)
         self.device = _device_
         self.model.to(_device_)
+        if not _device_.startswith("cpu"):
+            self.model = self.model.to(dtype=torch.float16)
         self.model.eval()
-        vis_processor = load_processor(
-            "formula_image_eval",
-            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
-        )
-        self.mfr_transform = transforms.Compose(
-            [
-                vis_processor,
-            ]
-        )
+
 
     def predict(self, mfd_res, image):
         formula_list = []
@@ -76,16 +48,17 @@ class UnimernetModel(object):
             bbox_img = image[ymin:ymax, xmin:xmax]
             mf_image_list.append(bbox_img)
 
-        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        dataset = MathDataset(mf_image_list, transform=self.model.transform)
         dataloader = DataLoader(dataset, batch_size=32, num_workers=0)
         mfr_res = []
         for mf_img in dataloader:
+            mf_img = mf_img.to(dtype=self.model.dtype)
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
                 output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["pred_str"])
+            mfr_res.extend(output["fixed_str"])
         for res, latex in zip(formula_list, mfr_res):
-            res["latex"] = latex_rm_whitespace(latex)
+            res["latex"] = latex
         return formula_list
 
     def batch_predict(self, images_mfd_res: list, images: list, batch_size: int = 64) -> list:
@@ -130,22 +103,23 @@ class UnimernetModel(object):
         index_mapping = {new_idx: old_idx for new_idx, old_idx in enumerate(sorted_indices)}
 
         # Create dataset with sorted images
-        dataset = MathDataset(sorted_images, transform=self.mfr_transform)
+        dataset = MathDataset(sorted_images, transform=self.model.transform)
         dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
 
         # Process batches and store results
         mfr_res = []
         for mf_img in dataloader:
+            mf_img = mf_img.to(dtype=self.model.dtype)
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
                 output = self.model.generate({"image": mf_img})
-            mfr_res.extend(output["pred_str"])
+            mfr_res.extend(output["fixed_str"])
 
         # Restore original order
         unsorted_results = [""] * len(mfr_res)
         for new_idx, latex in enumerate(mfr_res):
             original_idx = index_mapping[new_idx]
-            unsorted_results[original_idx] = latex_rm_whitespace(latex)
+            unsorted_results[original_idx] = latex
 
         # Fill results back
         for res, latex in zip(backfill_list, unsorted_results):

+ 70 - 38
magic_pdf/model/sub_modules/mfr/unimernet/unimernet_hf/unimer_swin/image_processing_unimer_swin.py

@@ -1,17 +1,15 @@
 from transformers.image_processing_utils import BaseImageProcessor
-from PIL import Image, ImageOps
 import numpy as np
 import cv2
 import albumentations as alb
 from albumentations.pytorch import ToTensorV2
-from torchvision.transforms.functional import resize
 
 
 # TODO: dereference cv2 if possible
 class UnimerSwinImageProcessor(BaseImageProcessor):
     def __init__(
             self,
-            image_size = [192, 672],
+            image_size = (192, 672),
         ):
         self.input_size = [int(_) for _ in image_size]
         assert len(self.input_size) == 2
@@ -27,56 +25,90 @@ class UnimerSwinImageProcessor(BaseImageProcessor):
 
     def __call__(self, item):
         image = self.prepare_input(item)
-        return self.transform(image=np.array(image))['image'][:1]
+        return self.transform(image=image)['image'][:1]
 
     @staticmethod
-    def crop_margin(img: Image.Image) -> Image.Image:
-        data = np.array(img.convert("L"))
-        data = data.astype(np.uint8)
-        max_val = data.max()
-        min_val = data.min()
-        if max_val == min_val:
+    def crop_margin_numpy(img: np.ndarray) -> np.ndarray:
+        """Crop margins of image using NumPy operations"""
+        # Convert to grayscale if it's a color image
+        if len(img.shape) == 3 and img.shape[2] == 3:
+            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+        else:
+            gray = img.copy()
+
+        # Normalize and threshold
+        if gray.max() == gray.min():
             return img
-        data = (data - min_val) / (max_val - min_val) * 255
-        gray = 255 * (data < 200).astype(np.uint8)
 
-        coords = cv2.findNonZero(gray)  # Find all non-zero points (text)
-        a, b, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
-        return img.crop((a, b, w + a, h + b))
+        normalized = (((gray - gray.min()) / (gray.max() - gray.min())) * 255).astype(np.uint8)
+        binary = 255 * (normalized < 200).astype(np.uint8)
+
+        # Find bounding box
+        coords = cv2.findNonZero(binary)  # Find all non-zero points (text)
+        x, y, w, h = cv2.boundingRect(coords)  # Find minimum spanning bounding box
 
-    def prepare_input(self, img: Image.Image, random_padding: bool = False):
+        # Return cropped image
+        return img[y:y + h, x:x + w]
+
+    def prepare_input(self, img, random_padding: bool = False):
         """
-        Convert PIL Image to tensor according to specified input_size after following steps below:
-            - resize
-            - rotate (if align_long_axis is True and image is not aligned longer axis with canvas)
-            - pad
+        Convert PIL Image or numpy array to properly sized and padded image after:
+            - crop margins
+            - resize while maintaining aspect ratio
+            - pad to target size
         """
         if img is None:
-            return
-        # crop margins
+            return None
+
         try:
-            img = self.crop_margin(img.convert("RGB"))
-        except OSError:
+            img = self.crop_margin_numpy(img)
+        except Exception:
             # might throw an error for broken files
-            return
+            return None
+
+        if img.shape[0] == 0 or img.shape[1] == 0:
+            return None
+
+        # Resize while preserving aspect ratio
+        h, w = img.shape[:2]
+        scale = min(self.input_size[0] / h, self.input_size[1] / w)
+        new_h, new_w = int(h * scale), int(w * scale)
+        resized_img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
+
+        # Calculate padding
+        pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
+
+        # Create and apply padding
+        channels = 3 if len(img.shape) == 3 else 1
+        padded_img = np.full((self.input_size[0], self.input_size[1], channels), 255, dtype=np.uint8)
+        padded_img[pad_height:pad_height + new_h, pad_width:pad_width + new_w] = resized_img
+
+        return padded_img
+
+    def _calculate_padding(self, new_w, new_h, random_padding):
+        """Calculate padding values for PIL images"""
+        delta_width = self.input_size[1] - new_w
+        delta_height = self.input_size[0] - new_h
+
+        pad_width, pad_height = self._get_padding_values(new_w, new_h, random_padding)
 
-        if img.height == 0 or img.width == 0:
-            return
+        return (
+            pad_width,
+            pad_height,
+            delta_width - pad_width,
+            delta_height - pad_height,
+        )
+
+    def _get_padding_values(self, new_w, new_h, random_padding):
+        """Get padding values based on image dimensions and padding strategy"""
+        delta_width = self.input_size[1] - new_w
+        delta_height = self.input_size[0] - new_h
 
-        img = resize(img, min(self.input_size))
-        img.thumbnail((self.input_size[1], self.input_size[0]))
-        delta_width = self.input_size[1] - img.width
-        delta_height = self.input_size[0] - img.height
         if random_padding:
             pad_width = np.random.randint(low=0, high=delta_width + 1)
             pad_height = np.random.randint(low=0, high=delta_height + 1)
         else:
             pad_width = delta_width // 2
             pad_height = delta_height // 2
-        padding = (
-            pad_width,
-            pad_height,
-            delta_width - pad_width,
-            delta_height - pad_height,
-        )
-        return ImageOps.expand(img, padding)
+
+        return pad_width, pad_height

+ 1 - 1
magic_pdf/pdf_parse_union_core_v2.py

@@ -492,7 +492,7 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
     else:
         return [[x0, y0, x1, y1]]
 
-# @measure_time
+
 def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
     page_line_list = []
 

+ 1 - 1
magic_pdf/resources/model_config/model_configs.yaml

@@ -2,7 +2,7 @@ weights:
   layoutlmv3: Layout/LayoutLMv3/model_final.pth
   doclayout_yolo: Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt
   yolo_v8_mfd: MFD/YOLO/yolo_v8_ft.pt
-  unimernet_small: MFR/unimernet_small_2501
+  unimernet_small: MFR/unimernet_hf_small_2503
   struct_eqtable: TabRec/StructEqTable
   tablemaster: TabRec/TableMaster
   rapid_table: TabRec/RapidTable

+ 3 - 2
requirements.txt

@@ -7,7 +7,8 @@ numpy>=1.21.6,<2.0.0
 pydantic>=2.7.2
 PyMuPDF>=1.24.9,<=1.24.14
 scikit-learn>=1.0.2
-torch>=2.2.2
-transformers
+torch>=2.2.2,!=2.5.0,!=2.5.1,<=2.6.0
+torchvision
+transformers>=4.49.0
 pdfminer.six==20231228
 # The requirements.txt must ensure that only necessary external dependencies are introduced. If there are new dependencies to add, please contact the project administrator.

+ 1 - 7
setup.py

@@ -36,25 +36,19 @@ if __name__ == '__main__':
                      "paddlepaddle==3.0.0b1;platform_system=='Linux'",
                      "paddlepaddle==2.6.1;platform_system=='Windows' or platform_system=='Darwin'",
                      ],
-            "full": ["unimernet==0.2.3",  # unimernet升级0.2.3,移除torchtext/eva-decord的依赖
-                     "torch>=2.2.2,<=2.3.1",  # torch2.4.0及之后版本未测试,先卡住版本上限
-                     "torchvision>=0.17.2,<=0.18.1",  # torchvision 受torch版本约束
+            "full": [
                      "matplotlib<=3.9.0;platform_system=='Windows'",  # 3.9.1及之后不提供windows的预编译包,避免一些没有编译环境的windows设备安装失败
                      "matplotlib;platform_system=='Linux' or platform_system=='Darwin'",  # linux 和 macos 不应限制matplotlib的最高版本,以避免无法更新导致的一些bug
                      "ultralytics>=8.3.48",  # yolov8,公式检测
                      "paddleocr==2.7.3",  # 2.8.0及2.8.1版本与detectron2有冲突,需锁定2.7.3
                      "paddlepaddle==3.0.0rc1;platform_system=='Linux' or platform_system=='Darwin'",  # 解决linux的段异常问题
                      "paddlepaddle==2.6.1;platform_system=='Windows'",  # windows版本3.0.0效率下降,需锁定2.6.1
-                     "struct-eqtable==0.3.2",  # 表格解析
-                     "einops",  # struct-eqtable依赖
-                     "accelerate",  # struct-eqtable依赖
                      "doclayout_yolo==0.0.2b1",  # doclayout_yolo
                      "rapidocr-paddle>=1.4.5,<2.0.0",  # rapidocr-paddle
                      "rapidocr_onnxruntime>=1.4.4,<2.0.0",
                      "rapid_table>=1.0.3,<2.0.0",  # rapid_table
                      "PyYAML",  # yaml
                      "openai",  # openai SDK
-                     "detectron2"
                      ],
             "old_linux":[
                 "albumentations<=1.4.20", # 1.4.21引入的simsimd不支持2019年及更早的linux系统