Sfoglia il codice sorgente

feat: add batch prediction methods for YOLOv8 and Unimernet models

Suven 11 mesi fa
parent
commit
7ce9edc674

+ 7 - 0
magic_pdf/config/exceptions.py

@@ -30,3 +30,10 @@ class EmptyData(Exception):
 
     def __str__(self):
         return f'Empty data: {self.msg}'
+
+class CUDA_NOT_AVAILABLE(Exception):
+    def __init__(self, msg):
+        self.msg = msg
+
+    def __str__(self):
+        return f'CUDA not available: {self.msg}'

+ 228 - 0
magic_pdf/model/batch_analyze.py

@@ -0,0 +1,228 @@
+import time
+
+import cv2
+import numpy as np
+import torch
+from loguru import logger
+from PIL import Image
+
+from magic_pdf.config.constants import MODEL_NAME
+from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
+from magic_pdf.data.dataset import Dataset
+from magic_pdf.libs.clean_memory import clean_memory
+from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
+from magic_pdf.model.operators import InferenceResult
+from magic_pdf.model.pdf_extract_kit import CustomPEKModel
+from magic_pdf.model.sub_modules.model_utils import (
+    clean_vram,
+    crop_img,
+    get_res_list_from_layout_res,
+)
+from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
+    get_adjusted_mfdetrec_res,
+    get_ocr_result_list,
+)
+
+YOLO_LAYOUT_BASE_BATCH_SIZE = 4
+MFD_BASE_BATCH_SIZE = 1
+MFR_BASE_BATCH_SIZE = 16
+
+
+class BatchAnalyze:
+    def __init__(self, model: CustomPEKModel, batch_ratio: int):
+        self.model = model
+        self.batch_ratio = batch_ratio
+
+    def __call__(self, images: list) -> list:
+        if self.model.layout_model_name == MODEL_NAME.LAYOUTLMv3:
+            # layoutlmv3
+            images_layout_res = []
+            for image in images:
+                layout_res = self.model.layout_model(image, ignore_catids=[])
+                images_layout_res.append(layout_res)
+        elif self.model.layout_model_name == MODEL_NAME.DocLayout_YOLO:
+            # doclayout_yolo
+            images_layout_res = self.model.layout_model.batch_predict(
+                images, self.batch_ratio * YOLO_LAYOUT_BASE_BATCH_SIZE
+            )
+
+        if self.model.apply_formula:
+            # 公式检测
+            images_mfd_res = self.model.mfd_model.batch_predict(
+                images, self.batch_ratio * MFD_BASE_BATCH_SIZE
+            )
+
+            # 公式识别
+            images_formula_list = self.model.mfr_model.batch_predict(
+                images_mfd_res,
+                images,
+                batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
+            )
+            for image_index in range(len(images)):
+                images_layout_res[image_index] += images_formula_list[image_index]
+
+        # 清理显存
+        clean_vram(self.model.device, vram_threshold=8)
+
+        # reference: magic_pdf/model/doc_analyze_by_custom_model.py:doc_analyze
+        for index in range(len(images)):
+            layout_res = images_layout_res[index]
+            pil_img = Image.fromarray(images[index])
+
+            ocr_res_list, table_res_list, single_page_mfdetrec_res = (
+                get_res_list_from_layout_res(layout_res)
+            )
+            # ocr识别
+            ocr_start = time.time()
+            # Process each area that requires OCR processing
+            for res in ocr_res_list:
+                new_image, useful_list = crop_img(
+                    res, pil_img, crop_paste_x=50, crop_paste_y=50
+                )
+                adjusted_mfdetrec_res = get_adjusted_mfdetrec_res(
+                    single_page_mfdetrec_res, useful_list
+                )
+
+                # OCR recognition
+                new_image = cv2.cvtColor(np.asarray(new_image), cv2.COLOR_RGB2BGR)
+
+                if self.model.apply_ocr:
+                    ocr_res = self.model.ocr_model.ocr(
+                        new_image, mfd_res=adjusted_mfdetrec_res
+                    )[0]
+                else:
+                    ocr_res = self.model.ocr_model.ocr(
+                        new_image, mfd_res=adjusted_mfdetrec_res, rec=False
+                    )[0]
+
+                # Integration results
+                if ocr_res:
+                    ocr_result_list = get_ocr_result_list(ocr_res, useful_list)
+                    layout_res.extend(ocr_result_list)
+
+            ocr_cost = round(time.time() - ocr_start, 2)
+            if self.model.apply_ocr:
+                logger.info(f"ocr time: {ocr_cost}")
+            else:
+                logger.info(f"det time: {ocr_cost}")
+
+            # 表格识别 table recognition
+            if self.model.apply_table:
+                table_start = time.time()
+                for res in table_res_list:
+                    new_image, _ = crop_img(res, pil_img)
+                    single_table_start_time = time.time()
+                    html_code = None
+                    if self.model.table_model_name == MODEL_NAME.STRUCT_EQTABLE:
+                        with torch.no_grad():
+                            table_result = self.model.table_model.predict(
+                                new_image, "html"
+                            )
+                            if len(table_result) > 0:
+                                html_code = table_result[0]
+                    elif self.model.table_model_name == MODEL_NAME.TABLE_MASTER:
+                        html_code = self.model.table_model.img2html(new_image)
+                    elif self.model.table_model_name == MODEL_NAME.RAPID_TABLE:
+                        html_code, table_cell_bboxes, elapse = (
+                            self.model.table_model.predict(new_image)
+                        )
+                    run_time = time.time() - single_table_start_time
+                    if run_time > self.model.table_max_time:
+                        logger.warning(
+                            f"table recognition processing exceeds max time {self.model.table_max_time}s"
+                        )
+                    # 判断是否返回正常
+                    if html_code:
+                        expected_ending = html_code.strip().endswith(
+                            "</html>"
+                        ) or html_code.strip().endswith("</table>")
+                        if expected_ending:
+                            res["html"] = html_code
+                        else:
+                            logger.warning(
+                                "table recognition processing fails, not found expected HTML table end"
+                            )
+                    else:
+                        logger.warning(
+                            "table recognition processing fails, not get html return"
+                        )
+                logger.info(f"table time: {round(time.time() - table_start, 2)}")
+
+
+def doc_batch_analyze(
+    dataset: Dataset,
+    ocr: bool = False,
+    show_log: bool = False,
+    start_page_id=0,
+    end_page_id=None,
+    lang=None,
+    layout_model=None,
+    formula_enable=None,
+    table_enable=None,
+    batch_ratio: int | None = None,
+) -> InferenceResult:
+    """
+    Perform batch analysis on a document dataset.
+
+    Args:
+        dataset (Dataset): The dataset containing document pages to be analyzed.
+        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
+        show_log (bool, optional): Flag to enable logging. Defaults to False.
+        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
+        end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
+        lang (str, optional): Language for OCR. Defaults to None.
+        layout_model (optional): Layout model to be used for analysis. Defaults to None.
+        formula_enable (optional): Flag to enable formula detection. Defaults to None.
+        table_enable (optional): Flag to enable table detection. Defaults to None.
+        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
+
+    Raises:
+        CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
+
+    Returns:
+        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
+    """
+
+    if not torch.cuda.is_available():
+        raise CUDA_NOT_AVAILABLE("batch analyze not support in CPU mode")
+
+    lang = None if lang == "" else lang
+    # TODO: auto detect batch size
+    batch_ratio = 1 if batch_ratio is None else batch_ratio
+    end_page_id = end_page_id if end_page_id else len(dataset)
+
+    model_manager = ModelSingleton()
+    custom_model: CustomPEKModel = model_manager.get_model(
+        ocr, show_log, lang, layout_model, formula_enable, table_enable
+    )
+    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+
+    model_json = []
+
+    # batch analyze
+    images = []
+    for index in range(len(dataset)):
+        if start_page_id <= index <= end_page_id:
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            images.append(img_dict["img"])
+    analyze_result = batch_model(images)
+
+    for index in range(len(dataset)):
+        page_data = dataset.get_page(index)
+        img_dict = page_data.get_image()
+        page_width = img_dict["width"]
+        page_height = img_dict["height"]
+        if start_page_id <= index <= end_page_id:
+            result = analyze_result.pop(0)
+        else:
+            result = []
+
+        page_info = {"page_no": index, "height": page_height, "width": page_width}
+        page_dict = {"layout_dets": result, "page_info": page_info}
+        model_json.append(page_dict)
+
+    # TODO: clean memory when gpu memory is not enough
+    clean_memory()
+
+    return InferenceResult(model_json, dataset)

+ 41 - 7
magic_pdf/model/sub_modules/layout/doclayout_yolo/DocLayoutYOLO.py

@@ -8,14 +8,48 @@ class DocLayoutYOLOModel(object):
 
     def predict(self, image):
         layout_res = []
-        doclayout_yolo_res = self.model.predict(image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
-        for xyxy, conf, cla in zip(doclayout_yolo_res.boxes.xyxy.cpu(), doclayout_yolo_res.boxes.conf.cpu(),
-                                   doclayout_yolo_res.boxes.cls.cpu()):
+        doclayout_yolo_res = self.model.predict(
+            image, imgsz=1024, conf=0.25, iou=0.45, verbose=True, device=self.device
+        )[0]
+        for xyxy, conf, cla in zip(
+            doclayout_yolo_res.boxes.xyxy.cpu(),
+            doclayout_yolo_res.boxes.conf.cpu(),
+            doclayout_yolo_res.boxes.cls.cpu(),
+        ):
             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
             new_item = {
-                'category_id': int(cla.item()),
-                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                'score': round(float(conf.item()), 3),
+                "category_id": int(cla.item()),
+                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                "score": round(float(conf.item()), 3),
             }
             layout_res.append(new_item)
-        return layout_res
+        return layout_res
+
+    def batch_predict(self, images: list, batch_size: int) -> list:
+        images_layout_res = []
+        for index in range(0, len(images), batch_size):
+            doclayout_yolo_res = self.model.predict(
+                images[index : index + batch_size],
+                imgsz=1024,
+                conf=0.25,
+                iou=0.45,
+                verbose=True,
+                device=self.device,
+            ).cpu()
+            for image_res in doclayout_yolo_res:
+                layout_res = []
+                for xyxy, conf, cla in zip(
+                    image_res.boxes.xyxy,
+                    image_res.boxes.conf,
+                    image_res.boxes.cls,
+                ):
+                    xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                    new_item = {
+                        "category_id": int(cla.item()),
+                        "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                        "score": round(float(conf.item()), 3),
+                    }
+                    layout_res.append(new_item)
+                images_layout_res.append(layout_res)
+
+        return images_layout_res

+ 18 - 2
magic_pdf/model/sub_modules/mfd/yolov8/YOLOv8.py

@@ -2,11 +2,27 @@ from ultralytics import YOLO
 
 
 class YOLOv8MFDModel(object):
-    def __init__(self, weight, device='cpu'):
+    def __init__(self, weight, device="cpu"):
         self.mfd_model = YOLO(weight)
         self.device = device
 
     def predict(self, image):
-        mfd_res = self.mfd_model.predict(image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device)[0]
+        mfd_res = self.mfd_model.predict(
+            image, imgsz=1888, conf=0.25, iou=0.45, verbose=True, device=self.device
+        )[0]
         return mfd_res
 
+    def batch_predict(self, images: list, batch_size: int) -> list:
+        images_mfd_res = []
+        for index in range(0, len(images), batch_size):
+            mfd_res = self.mfd_model.predict(
+                images[index : index + batch_size],
+                imgsz=1888,
+                conf=0.25,
+                iou=0.45,
+                verbose=True,
+                device=self.device,
+            ).cpu()
+            for image_res in mfd_res:
+                images_mfd_res.append(image_res)
+        return images_mfd_res

+ 70 - 27
magic_pdf/model/sub_modules/mfr/unimernet/Unimernet.py

@@ -1,13 +1,13 @@
-import os
 import argparse
+import os
 import re
 
-from PIL import Image
 import torch
-from torch.utils.data import Dataset, DataLoader
+import unimernet.tasks as tasks
+from PIL import Image
+from torch.utils.data import DataLoader, Dataset
 from torchvision import transforms
 from unimernet.common.config import Config
-import unimernet.tasks as tasks
 from unimernet.processors import load_processor
 
 
@@ -31,27 +31,25 @@ class MathDataset(Dataset):
 
 
 def latex_rm_whitespace(s: str):
-    """Remove unnecessary whitespace from LaTeX code.
-    """
-    text_reg = r'(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})'
-    letter = '[a-zA-Z]'
-    noletter = '[\W_^\d]'
-    names = [x[0].replace(' ', '') for x in re.findall(text_reg, s)]
+    """Remove unnecessary whitespace from LaTeX code."""
+    text_reg = r"(\\(operatorname|mathrm|text|mathbf)\s?\*? {.*?})"
+    letter = "[a-zA-Z]"
+    noletter = "[\W_^\d]"
+    names = [x[0].replace(" ", "") for x in re.findall(text_reg, s)]
     s = re.sub(text_reg, lambda match: str(names.pop(0)), s)
     news = s
     while True:
         s = news
-        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, noletter), r'\1\2', s)
-        news = re.sub(r'(?!\\ )(%s)\s+?(%s)' % (noletter, letter), r'\1\2', news)
-        news = re.sub(r'(%s)\s+?(%s)' % (letter, noletter), r'\1\2', news)
+        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, noletter), r"\1\2", s)
+        news = re.sub(r"(?!\\ )(%s)\s+?(%s)" % (noletter, letter), r"\1\2", news)
+        news = re.sub(r"(%s)\s+?(%s)" % (letter, noletter), r"\1\2", news)
         if news == s:
             break
     return s
 
 
 class UnimernetModel(object):
-    def __init__(self, weight_dir, cfg_path, _device_='cpu'):
-
+    def __init__(self, weight_dir, cfg_path, _device_="cpu"):
         args = argparse.Namespace(cfg_path=cfg_path, options=None)
         cfg = Config(args)
         cfg.config.model.pretrained = os.path.join(weight_dir, "pytorch_model.pth")
@@ -62,20 +60,28 @@ class UnimernetModel(object):
         self.device = _device_
         self.model.to(_device_)
         self.model.eval()
-        vis_processor = load_processor('formula_image_eval', cfg.config.datasets.formula_rec_eval.vis_processor.eval)
-        self.mfr_transform = transforms.Compose([vis_processor, ])
+        vis_processor = load_processor(
+            "formula_image_eval",
+            cfg.config.datasets.formula_rec_eval.vis_processor.eval,
+        )
+        self.mfr_transform = transforms.Compose(
+            [
+                vis_processor,
+            ]
+        )
 
     def predict(self, mfd_res, image):
-
         formula_list = []
         mf_image_list = []
-        for xyxy, conf, cla in zip(mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()):
+        for xyxy, conf, cla in zip(
+            mfd_res.boxes.xyxy.cpu(), mfd_res.boxes.conf.cpu(), mfd_res.boxes.cls.cpu()
+        ):
             xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
             new_item = {
-                'category_id': 13 + int(cla.item()),
-                'poly': [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
-                'score': round(float(conf.item()), 2),
-                'latex': '',
+                "category_id": 13 + int(cla.item()),
+                "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                "score": round(float(conf.item()), 2),
+                "latex": "",
             }
             formula_list.append(new_item)
             pil_img = Image.fromarray(image)
@@ -88,11 +94,48 @@ class UnimernetModel(object):
         for mf_img in dataloader:
             mf_img = mf_img.to(self.device)
             with torch.no_grad():
-                output = self.model.generate({'image': mf_img})
-            mfr_res.extend(output['pred_str'])
+                output = self.model.generate({"image": mf_img})
+            mfr_res.extend(output["pred_str"])
         for res, latex in zip(formula_list, mfr_res):
-            res['latex'] = latex_rm_whitespace(latex)
+            res["latex"] = latex_rm_whitespace(latex)
         return formula_list
 
+    def batch_predict(
+        self, images_mfd_res: list, images: list, batch_size: int = 64
+    ) -> list:
+        images_formula_list = []
+        mf_image_list = []
+        backfill_list = []
+        for image_index in range(len(images_mfd_res)):
+            mfd_res = images_mfd_res[image_index]
+            pil_img = Image.fromarray(images[image_index])
+            formula_list = []
+
+            for xyxy, conf, cla in zip(
+                mfd_res.boxes.xyxy, mfd_res.boxes.conf, mfd_res.boxes.cls
+            ):
+                xmin, ymin, xmax, ymax = [int(p.item()) for p in xyxy]
+                new_item = {
+                    "category_id": 13 + int(cla.item()),
+                    "poly": [xmin, ymin, xmax, ymin, xmax, ymax, xmin, ymax],
+                    "score": round(float(conf.item()), 2),
+                    "latex": "",
+                }
+                formula_list.append(new_item)
+                bbox_img = pil_img.crop((xmin, ymin, xmax, ymax))
+                mf_image_list.append(bbox_img)
+
+            images_formula_list.append(formula_list)
+            backfill_list += formula_list
 
-
+        dataset = MathDataset(mf_image_list, transform=self.mfr_transform)
+        dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=0)
+        mfr_res = []
+        for mf_img in dataloader:
+            mf_img = mf_img.to(self.device)
+            with torch.no_grad():
+                output = self.model.generate({"image": mf_img})
+            mfr_res.extend(output["pred_str"])
+        for res, latex in zip(backfill_list, mfr_res):
+            res["latex"] = latex_rm_whitespace(latex)
+        return images_formula_list