Răsfoiți Sursa

Merge pull request #1453 from myhloli/dev

refactor(langdetect): simplify language detection model
Xiaomeng Zhao 10 luni în urmă
părinte
comite
aa53531670

+ 2 - 1
docs/README_Ascend_NPU_Acceleration_zh_CN.md

@@ -51,6 +51,7 @@ magic-pdf --help
 
 ## 已知问题
 
-- paddleocr使用内嵌onnx模型,仅支持中英文ocr,不支持其他语言ocr
+- paddleocr使用内嵌onnx模型,仅在默认语言配置下能以较快速度对中英文进行识别
+- 自定义lang参数时,paddleocr速度会存在明显下降情况
 - layout模型使用layoutlmv3时会发生间歇性崩溃,建议使用默认配置的doclayout_yolo模型
 - 表格解析仅适配了rapid_table模型,其他模型可能会无法使用

+ 1 - 0
magic_pdf/data/dataset.py

@@ -153,6 +153,7 @@ class PymuDocDataset(Dataset):
             logger.info(f"lang: {lang}, detect_lang: {self._lang}")
         else:
             self._lang = lang
+            logger.info(f"lang: {lang}")
     def __len__(self) -> int:
         """The page number of the pdf."""
         return len(self._records)

+ 1 - 0
magic_pdf/model/model_list.py

@@ -9,3 +9,4 @@ class AtomicModel:
     MFR = "mfr"
     OCR = "ocr"
     Table = "table"
+    LangDetect = "langdetect"

+ 24 - 11
magic_pdf/model/sub_modules/language_detection/utils.py

@@ -12,7 +12,6 @@ from magic_pdf.data.utils import load_images_from_pdf
 from magic_pdf.libs.config_reader import get_local_models_dir, get_device
 from magic_pdf.libs.pdf_check import extract_pages
 from magic_pdf.model.model_list import AtomicModel
-from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
 
 
@@ -59,15 +58,29 @@ def get_text_images(simple_images):
 def auto_detect_lang(pdf_bytes: bytes):
     sample_docs = extract_pages(pdf_bytes)
     sample_pdf_bytes = sample_docs.tobytes()
-    simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=96)
+    simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=200)
     text_images = get_text_images(simple_images)
-    local_models_dir, device, configs = get_model_config()
-    # 用yolo11做语言分类
-    langdetect_model_weights = str(
-        os.path.join(
-            local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
-        )
-    )
-    langdetect_model = YOLOv11LangDetModel(langdetect_model_weights, device)
+    langdetect_model = model_init(MODEL_NAME.YOLO_V11_LangDetect)
     lang = langdetect_model.do_detect(text_images)
-    return lang
+    return lang
+
+
+def model_init(model_name: str):
+    atom_model_manager = AtomModelSingleton()
+
+    if model_name == MODEL_NAME.YOLO_V11_LangDetect:
+        local_models_dir, device, configs = get_model_config()
+        model = atom_model_manager.get_atom_model(
+            atom_model_name=AtomicModel.LangDetect,
+            langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
+            langdetect_model_weight=str(
+                os.path.join(
+                    local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
+                )
+            ),
+            device=device,
+        )
+    else:
+        raise ValueError(f"model_name {model_name} not found")
+    return model
+

+ 10 - 5
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py

@@ -2,6 +2,7 @@
 from collections import Counter
 from uuid import uuid4
 
+import torch
 from PIL import Image
 from loguru import logger
 from ultralytics import YOLO
@@ -83,10 +84,14 @@ def resize_images_to_224(image):
 
 
 class YOLOv11LangDetModel(object):
-    def __init__(self, weight, device):
-        self.model = YOLO(weight)
-        self.device = device
+    def __init__(self, langdetect_model_weight, device):
 
+        self.model = YOLO(langdetect_model_weight)
+
+        if str(device).startswith("npu"):
+            self.device = torch.device(device)
+        else:
+            self.device = device
     def do_detect(self, images: list):
         all_images = []
         for image in images:
@@ -99,7 +104,7 @@ class YOLOv11LangDetModel(object):
                 all_images.append(resize_images_to_224(temp_image))
 
         images_lang_res = self.batch_predict(all_images, batch_size=8)
-        logger.info(f"images_lang_res: {images_lang_res}")
+        # logger.info(f"images_lang_res: {images_lang_res}")
         if len(images_lang_res) > 0:
             count_dict = Counter(images_lang_res)
             language = max(count_dict, key=count_dict.get)
@@ -107,7 +112,6 @@ class YOLOv11LangDetModel(object):
             language = None
         return language
 
-
     def predict(self, image):
         results = self.model.predict(image, verbose=False, device=self.device)
         predicted_class_id = int(results[0].probs.top1)
@@ -117,6 +121,7 @@ class YOLOv11LangDetModel(object):
 
     def batch_predict(self, images: list, batch_size: int) -> list:
         images_lang_res = []
+
         for index in range(0, len(images), batch_size):
             lang_res = [
                 image_res.cpu()

+ 20 - 1
magic_pdf/model/sub_modules/model_init.py

@@ -2,8 +2,8 @@ import torch
 from loguru import logger
 
 from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.libs.config_reader import get_device
 from magic_pdf.model.model_list import AtomicModel
+from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
 from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
     DocLayoutYOLOModel
 from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
@@ -63,6 +63,13 @@ def doclayout_yolo_model_init(weight, device='cpu'):
     return model
 
 
+def langdetect_model_init(langdetect_model_weight, device='cpu'):
+    if str(device).startswith("npu"):
+        device = torch.device(device)
+    model = YOLOv11LangDetModel(langdetect_model_weight, device)
+    return model
+
+
 def ocr_model_init(show_log: bool = False,
                    det_db_box_thresh=0.3,
                    lang=None,
@@ -130,6 +137,9 @@ def atom_model_init(model_name: str, **kwargs):
                 kwargs.get('doclayout_yolo_weights'),
                 kwargs.get('device')
             )
+        else:
+            logger.error('layout model name not allow')
+            exit(1)
     elif model_name == AtomicModel.MFD:
         atom_model = mfd_model_init(
             kwargs.get('mfd_weights'),
@@ -155,6 +165,15 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('device'),
             kwargs.get('ocr_engine')
         )
+    elif model_name == AtomicModel.LangDetect:
+        if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
+            atom_model = langdetect_model_init(
+                kwargs.get('langdetect_model_weight'),
+                kwargs.get('device')
+            )
+        else:
+            logger.error('langdetect model name not allow')
+            exit(1)
     else:
         logger.error('model name not allow')
         exit(1)

+ 5 - 5
magic_pdf/model/sub_modules/ocr/paddleocr/ppocr_273_mod.py

@@ -21,7 +21,7 @@ class ModifiedPaddleOCR(PaddleOCR):
     def __init__(self, *args, **kwargs):
 
         super().__init__(*args, **kwargs)
-
+        self.lang = kwargs.get('lang', 'ch')
         # 在cpu架构为arm且不支持cuda时调用onnx、
         if not torch.cuda.is_available() and platform.machine() in ['arm64', 'aarch64']:
             self.use_onnx = True
@@ -94,7 +94,7 @@ class ModifiedPaddleOCR(PaddleOCR):
             ocr_res = []
             for img in imgs:
                 img = preprocess_image(img)
-                if self.use_onnx:
+                if self.lang in ['ch'] and self.use_onnx:
                     dt_boxes, elapse = self.additional_ocr.text_detector(img)
                 else:
                     dt_boxes, elapse = self.text_detector(img)
@@ -124,7 +124,7 @@ class ModifiedPaddleOCR(PaddleOCR):
                     img, cls_res_tmp, elapse = self.text_classifier(img)
                     if not rec:
                         cls_res.append(cls_res_tmp)
-                if self.use_onnx:
+                if self.lang in ['ch'] and self.use_onnx:
                     rec_res, elapse = self.additional_ocr.text_recognizer(img)
                 else:
                     rec_res, elapse = self.text_recognizer(img)
@@ -142,7 +142,7 @@ class ModifiedPaddleOCR(PaddleOCR):
 
         start = time.time()
         ori_im = img.copy()
-        if self.use_onnx:
+        if self.lang in ['ch'] and self.use_onnx:
             dt_boxes, elapse = self.additional_ocr.text_detector(img)
         else:
             dt_boxes, elapse = self.text_detector(img)
@@ -183,7 +183,7 @@ class ModifiedPaddleOCR(PaddleOCR):
             time_dict['cls'] = elapse
             logger.debug("cls num  : {}, elapsed : {}".format(
                 len(img_crop_list), elapse))
-        if self.use_onnx:
+        if self.lang in ['ch'] and self.use_onnx:
             rec_res, elapse = self.additional_ocr.text_recognizer(img_crop_list)
         else:
             rec_res, elapse = self.text_recognizer(img_crop_list)

+ 1 - 1
magic_pdf/resources/model_config/model_configs.yaml

@@ -6,4 +6,4 @@ weights:
   struct_eqtable: TabRec/StructEqTable
   tablemaster: TabRec/TableMaster
   rapid_table: TabRec/RapidTable
-  yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_cls_ft.pt
+  yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_ft.pt