浏览代码

feat(model): improve batch analysis logic and support npu

- Add support for NPU (Neural Processing Unit) when available
- Implement batch analysis for GPU and NPU devices
- Optimize memory usage and improve performance
- Update logging and error handling
myhloli 10 月之前
父节点
当前提交
f350222614
共有 3 个文件被更改,包括 154 次插入104 次删除
  1. 87 85
      magic_pdf/model/batch_analyze.py
  2. 66 18
      magic_pdf/model/doc_analyze_by_custom_model.py
  3. 1 1
      magic_pdf/model/pdf_extract_kit.py

+ 87 - 85
magic_pdf/model/batch_analyze.py

@@ -7,17 +7,17 @@ from loguru import logger
 from PIL import Image
 
 from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
-from magic_pdf.data.dataset import Dataset
-from magic_pdf.libs.clean_memory import clean_memory
-from magic_pdf.libs.config_reader import get_device
-from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
+# from magic_pdf.config.exceptions import CUDA_NOT_AVAILABLE
+# from magic_pdf.data.dataset import Dataset
+# from magic_pdf.libs.clean_memory import clean_memory
+# from magic_pdf.libs.config_reader import get_device
+# from magic_pdf.model.doc_analyze_by_custom_model import ModelSingleton
 from magic_pdf.model.pdf_extract_kit import CustomPEKModel
 from magic_pdf.model.sub_modules.model_utils import (
     clean_vram, crop_img, get_res_list_from_layout_res)
 from magic_pdf.model.sub_modules.ocr.paddleocr.ocr_utils import (
     get_adjusted_mfdetrec_res, get_ocr_result_list)
-from magic_pdf.operators.models import InferenceResult
+# from magic_pdf.operators.models import InferenceResult
 
 YOLO_LAYOUT_BASE_BATCH_SIZE = 4
 MFD_BASE_BATCH_SIZE = 1
@@ -91,10 +91,12 @@ class BatchAnalyze:
                 images,
                 batch_size=self.batch_ratio * MFR_BASE_BATCH_SIZE,
             )
+            mfr_count = 0
             for image_index in range(len(images)):
                 images_layout_res[image_index] += images_formula_list[image_index]
+                mfr_count += len(images_formula_list[image_index])
             logger.info(
-                f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {len(images)}'
+                f'mfr time: {round(time.time() - mfr_start_time, 2)}, image num: {mfr_count}'
             )
 
         # 清理显存
@@ -195,81 +197,81 @@ class BatchAnalyze:
         return images_layout_res
 
 
-def doc_batch_analyze(
-    dataset: Dataset,
-    ocr: bool = False,
-    show_log: bool = False,
-    start_page_id=0,
-    end_page_id=None,
-    lang=None,
-    layout_model=None,
-    formula_enable=None,
-    table_enable=None,
-    batch_ratio: int | None = None,
-) -> InferenceResult:
-    """Perform batch analysis on a document dataset.
-
-    Args:
-        dataset (Dataset): The dataset containing document pages to be analyzed.
-        ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
-        show_log (bool, optional): Flag to enable logging. Defaults to False.
-        start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
-        end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
-        lang (str, optional): Language for OCR. Defaults to None.
-        layout_model (optional): Layout model to be used for analysis. Defaults to None.
-        formula_enable (optional): Flag to enable formula detection. Defaults to None.
-        table_enable (optional): Flag to enable table detection. Defaults to None.
-        batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
-
-    Raises:
-        CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
-
-    Returns:
-        InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
-    """
-
-    if not torch.cuda.is_available():
-        raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
-
-    lang = None if lang == '' else lang
-    # TODO: auto detect batch size
-    batch_ratio = 1 if batch_ratio is None else batch_ratio
-    end_page_id = end_page_id if end_page_id else len(dataset)
-
-    model_manager = ModelSingleton()
-    custom_model: CustomPEKModel = model_manager.get_model(
-        ocr, show_log, lang, layout_model, formula_enable, table_enable
-    )
-    batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
-
-    model_json = []
-
-    # batch analyze
-    images = []
-    for index in range(len(dataset)):
-        if start_page_id <= index <= end_page_id:
-            page_data = dataset.get_page(index)
-            img_dict = page_data.get_image()
-            images.append(img_dict['img'])
-    analyze_result = batch_model(images)
-
-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            result = analyze_result.pop(0)
-        else:
-            result = []
-
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
-
-    # TODO: clean memory when gpu memory is not enough
-    clean_memory_start_time = time.time()
-    clean_memory(get_device())
-    logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
-
-    return InferenceResult(model_json, dataset)
+# def doc_batch_analyze(
+#     dataset: Dataset,
+#     ocr: bool = False,
+#     show_log: bool = False,
+#     start_page_id=0,
+#     end_page_id=None,
+#     lang=None,
+#     layout_model=None,
+#     formula_enable=None,
+#     table_enable=None,
+#     batch_ratio: int | None = None,
+# ) -> InferenceResult:
+#     """Perform batch analysis on a document dataset.
+#
+#     Args:
+#         dataset (Dataset): The dataset containing document pages to be analyzed.
+#         ocr (bool, optional): Flag to enable OCR (Optical Character Recognition). Defaults to False.
+#         show_log (bool, optional): Flag to enable logging. Defaults to False.
+#         start_page_id (int, optional): The starting page ID for analysis. Defaults to 0.
+#         end_page_id (int, optional): The ending page ID for analysis. Defaults to None, which means analyze till the last page.
+#         lang (str, optional): Language for OCR. Defaults to None.
+#         layout_model (optional): Layout model to be used for analysis. Defaults to None.
+#         formula_enable (optional): Flag to enable formula detection. Defaults to None.
+#         table_enable (optional): Flag to enable table detection. Defaults to None.
+#         batch_ratio (int | None, optional): Ratio for batch processing. Defaults to None, which sets it to 1.
+#
+#     Raises:
+#         CUDA_NOT_AVAILABLE: If CUDA is not available, raises an exception as batch analysis is not supported in CPU mode.
+#
+#     Returns:
+#         InferenceResult: The result of the batch analysis containing the analyzed data and the dataset.
+#     """
+#
+#     if not torch.cuda.is_available():
+#         raise CUDA_NOT_AVAILABLE('batch analyze not support in CPU mode')
+#
+#     lang = None if lang == '' else lang
+#     # TODO: auto detect batch size
+#     batch_ratio = 1 if batch_ratio is None else batch_ratio
+#     end_page_id = end_page_id if end_page_id else len(dataset)
+#
+#     model_manager = ModelSingleton()
+#     custom_model: CustomPEKModel = model_manager.get_model(
+#         ocr, show_log, lang, layout_model, formula_enable, table_enable
+#     )
+#     batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+#
+#     model_json = []
+#
+#     # batch analyze
+#     images = []
+#     for index in range(len(dataset)):
+#         if start_page_id <= index <= end_page_id:
+#             page_data = dataset.get_page(index)
+#             img_dict = page_data.get_image()
+#             images.append(img_dict['img'])
+#     analyze_result = batch_model(images)
+#
+#     for index in range(len(dataset)):
+#         page_data = dataset.get_page(index)
+#         img_dict = page_data.get_image()
+#         page_width = img_dict['width']
+#         page_height = img_dict['height']
+#         if start_page_id <= index <= end_page_id:
+#             result = analyze_result.pop(0)
+#         else:
+#             result = []
+#
+#         page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+#         page_dict = {'layout_dets': result, 'page_info': page_info}
+#         model_json.append(page_dict)
+#
+#     # TODO: clean memory when gpu memory is not enough
+#     clean_memory_start_time = time.time()
+#     clean_memory(get_device())
+#     logger.info(f'clean memory time: {round(time.time() - clean_memory_start_time, 2)}')
+#
+#     return InferenceResult(model_json, dataset)

+ 66 - 18
magic_pdf/model/doc_analyze_by_custom_model.py

@@ -3,8 +3,12 @@ import time
 
 # 关闭paddle的信号处理
 import paddle
+import torch
 from loguru import logger
 
+from magic_pdf.model.batch_analyze import BatchAnalyze
+from magic_pdf.model.sub_modules.model_utils import get_vram
+
 paddle.disable_signal_handler()
 
 os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'  # 禁止albumentations检查更新
@@ -154,33 +158,77 @@ def doc_analyze(
     table_enable=None,
 ) -> InferenceResult:
 
+    end_page_id = end_page_id if end_page_id else len(dataset)
+
     model_manager = ModelSingleton()
     custom_model = model_manager.get_model(
         ocr, show_log, lang, layout_model, formula_enable, table_enable
     )
 
+    batch_analyze = False
+    device = get_device()
+
+    npu_support = False
+    if str(device).startswith("npu"):
+        import torch_npu
+        if torch_npu.npu.is_available():
+            npu_support = True
+
+    if torch.cuda.is_available() and device != 'cpu' or npu_support:
+        gpu_memory = get_vram(device)
+        if gpu_memory is not None and gpu_memory >= 7:
+            batch_ratio = int((gpu_memory-3) // 1.5)
+            if batch_ratio >= 1:
+                logger.info(f'gpu_memory: {gpu_memory} GB, batch_ratio: {batch_ratio}')
+                batch_model = BatchAnalyze(model=custom_model, batch_ratio=batch_ratio)
+                batch_analyze = True
+
     model_json = []
     doc_analyze_start = time.time()
 
-    if end_page_id is None:
-        end_page_id = len(dataset)
-
-    for index in range(len(dataset)):
-        page_data = dataset.get_page(index)
-        img_dict = page_data.get_image()
-        img = img_dict['img']
-        page_width = img_dict['width']
-        page_height = img_dict['height']
-        if start_page_id <= index <= end_page_id:
-            page_start = time.time()
-            result = custom_model(img)
-            logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
-        else:
-            result = []
+    if batch_analyze:
+        # batch analyze
+        images = []
+        for index in range(len(dataset)):
+            if start_page_id <= index <= end_page_id:
+                page_data = dataset.get_page(index)
+                img_dict = page_data.get_image()
+                images.append(img_dict['img'])
+        analyze_result = batch_model(images)
+
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            page_width = img_dict['width']
+            page_height = img_dict['height']
+            if start_page_id <= index <= end_page_id:
+                result = analyze_result.pop(0)
+            else:
+                result = []
+
+            page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
 
-        page_info = {'page_no': index, 'height': page_height, 'width': page_width}
-        page_dict = {'layout_dets': result, 'page_info': page_info}
-        model_json.append(page_dict)
+    else:
+        # single analyze
+
+        for index in range(len(dataset)):
+            page_data = dataset.get_page(index)
+            img_dict = page_data.get_image()
+            img = img_dict['img']
+            page_width = img_dict['width']
+            page_height = img_dict['height']
+            if start_page_id <= index <= end_page_id:
+                page_start = time.time()
+                result = custom_model(img)
+                logger.info(f'-----page_id : {index}, page total time: {round(time.time() - page_start, 2)}-----')
+            else:
+                result = []
+
+            page_info = {'page_no': index, 'height': page_height, 'width': page_width}
+            page_dict = {'layout_dets': result, 'page_info': page_info}
+            model_json.append(page_dict)
 
     gc_start = time.time()
     clean_memory(get_device())

+ 1 - 1
magic_pdf/model/pdf_extract_kit.py

@@ -228,7 +228,7 @@ class CustomPEKModel:
             logger.info(f'formula nums: {len(formula_list)}, mfr time: {mfr_cost}')
 
         # 清理显存
-        clean_vram(self.device, vram_threshold=8)
+        clean_vram(self.device, vram_threshold=6)
 
         # 从layout_res中获取ocr区域、表格区域、公式区域
         ocr_res_list, table_res_list, single_page_mfdetrec_res = (