Forráskód Böngészése

refactor(langdetect): simplify language detection model and improve logging

- Remove LangDetectMode and related conditional logic
- Use a single model weight for language detection
- Log the selected language in PymuDocDataset's else branch, and disable the per-image `images_lang_res` log in `do_detect`
- Update model initialization and prediction methods
myhloli 10 hónapja
szülő
commit
3271cf75d3

+ 1 - 0
magic_pdf/data/dataset.py

@@ -153,6 +153,7 @@ class PymuDocDataset(Dataset):
             logger.info(f"lang: {lang}, detect_lang: {self._lang}")
         else:
             self._lang = lang
+            logger.info(f"lang: {lang}")
     def __len__(self) -> int:
         """The page number of the pdf."""
         return len(self._records)

+ 1 - 7
magic_pdf/model/sub_modules/language_detection/utils.py

@@ -12,7 +12,6 @@ from magic_pdf.data.utils import load_images_from_pdf
 from magic_pdf.libs.config_reader import get_local_models_dir, get_device
 from magic_pdf.libs.pdf_check import extract_pages
 from magic_pdf.model.model_list import AtomicModel
-from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import LangDetectMode
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
 
 
@@ -63,11 +62,6 @@ def auto_detect_lang(pdf_bytes: bytes):
     text_images = get_text_images(simple_images)
     langdetect_model = model_init(MODEL_NAME.YOLO_V11_LangDetect)
     lang = langdetect_model.do_detect(text_images)
-    if lang in ["ch", "japan"]:
-        lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.CH_JP)
-    elif lang in ["en", "fr", "german"]:
-        lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.EN_FR_GE)
-
     return lang
 
 
@@ -79,7 +73,7 @@ def model_init(model_name: str):
         model = atom_model_manager.get_atom_model(
             atom_model_name=AtomicModel.LangDetect,
             langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
-            langdetect_model_weights_dir=str(
+            langdetect_model_weight=str(
                 os.path.join(
                     local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
                 )

+ 12 - 47
magic_pdf/model/sub_modules/language_detection/yolov11/YOLOv11.py

@@ -1,5 +1,4 @@
 # Copyright (c) Opendatalab. All rights reserved.
-import os
 from collections import Counter
 from uuid import uuid4
 
@@ -19,11 +18,6 @@ language_dict = {
     "ru": "俄语"
 }
 
-class LangDetectMode:
-    BASE = "base"
-    CH_JP = "ch_jp"
-    EN_FR_GE = "en_fr_ge"
-
 
 def split_images(image, result_images=None):
     """
@@ -90,25 +84,15 @@ def resize_images_to_224(image):
 
 
 class YOLOv11LangDetModel(object):
-    def __init__(self, langdetect_model_weights_dir, device):
-        langdetect_model_base_weight = str(
-            os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_ft.pt')
-        )
-        langdetect_model_ch_jp_weight = str(
-            os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_ch_jp.pt')
-        )
-        langdetect_model_en_fr_ge_weight = str(
-            os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_en_fr_ge.pt')
-        )
-        self.model = YOLO(langdetect_model_base_weight)
-        self.ch_jp_model = YOLO(langdetect_model_ch_jp_weight)
-        self.en_fr_ge_model = YOLO(langdetect_model_en_fr_ge_weight)
+    def __init__(self, langdetect_model_weight, device):
+
+        self.model = YOLO(langdetect_model_weight)
 
         if str(device).startswith("npu"):
             self.device = torch.device(device)
         else:
             self.device = device
-    def do_detect(self, images: list, mode=LangDetectMode.BASE):
+    def do_detect(self, images: list):
         all_images = []
         for image in images:
             width, height = image.size
@@ -119,8 +103,8 @@ class YOLOv11LangDetModel(object):
             for temp_image in temp_images:
                 all_images.append(resize_images_to_224(temp_image))
 
-        images_lang_res = self.batch_predict(all_images, batch_size=8, mode=mode)
-        logger.info(f"images_lang_res: {images_lang_res}")
+        images_lang_res = self.batch_predict(all_images, batch_size=8)
+        # logger.info(f"images_lang_res: {images_lang_res}")
         if len(images_lang_res) > 0:
             count_dict = Counter(images_lang_res)
             language = max(count_dict, key=count_dict.get)
@@ -128,39 +112,20 @@ class YOLOv11LangDetModel(object):
             language = None
         return language
 
-    def predict(self, image, mode=LangDetectMode.BASE):
-
-        if mode == LangDetectMode.BASE:
-            model = self.model
-        elif mode == LangDetectMode.CH_JP:
-            model = self.ch_jp_model
-        elif mode == LangDetectMode.EN_FR_GE:
-            model = self.en_fr_ge_model
-        else:
-            model = self.model
-
-        results = model.predict(image, verbose=False, device=self.device)
+    def predict(self, image):
+        results = self.model.predict(image, verbose=False, device=self.device)
         predicted_class_id = int(results[0].probs.top1)
-        predicted_class_name = model.names[predicted_class_id]
+        predicted_class_name = self.model.names[predicted_class_id]
         return predicted_class_name
 
 
-    def batch_predict(self, images: list, batch_size: int, mode=LangDetectMode.BASE) -> list:
+    def batch_predict(self, images: list, batch_size: int) -> list:
         images_lang_res = []
 
-        if mode == LangDetectMode.BASE:
-            model = self.model
-        elif mode == LangDetectMode.CH_JP:
-            model = self.ch_jp_model
-        elif mode == LangDetectMode.EN_FR_GE:
-            model = self.en_fr_ge_model
-        else:
-            model = self.model
-
         for index in range(0, len(images), batch_size):
             lang_res = [
                 image_res.cpu()
-                for image_res in model.predict(
+                for image_res in self.model.predict(
                     images[index: index + batch_size],
                     verbose = False,
                     device=self.device,
@@ -168,7 +133,7 @@ class YOLOv11LangDetModel(object):
             ]
             for res in lang_res:
                 predicted_class_id = int(res.probs.top1)
-                predicted_class_name = model.names[predicted_class_id]
+                predicted_class_name = self.model.names[predicted_class_id]
                 images_lang_res.append(predicted_class_name)
 
         return images_lang_res

+ 3 - 3
magic_pdf/model/sub_modules/model_init.py

@@ -63,10 +63,10 @@ def doclayout_yolo_model_init(weight, device='cpu'):
     return model
 
 
-def langdetect_model_init(langdetect_model_weights_dir, device='cpu'):
+def langdetect_model_init(langdetect_model_weight, device='cpu'):
     if str(device).startswith("npu"):
         device = torch.device(device)
-    model = YOLOv11LangDetModel(langdetect_model_weights_dir, device)
+    model = YOLOv11LangDetModel(langdetect_model_weight, device)
     return model
 
 
@@ -168,7 +168,7 @@ def atom_model_init(model_name: str, **kwargs):
     elif model_name == AtomicModel.LangDetect:
         if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
             atom_model = langdetect_model_init(
-                kwargs.get('langdetect_model_weights_dir'),
+                kwargs.get('langdetect_model_weight'),
                 kwargs.get('device')
             )
         else:

+ 1 - 1
magic_pdf/resources/model_config/model_configs.yaml

@@ -6,4 +6,4 @@ weights:
   struct_eqtable: TabRec/StructEqTable
   tablemaster: TabRec/TableMaster
   rapid_table: TabRec/RapidTable
-  yolo_v11n_langdetect: LangDetect/YOLO
+  yolo_v11n_langdetect: LangDetect/YOLO/yolo_v11_ft.pt