소스 검색

feat(model): add language detection model and update related modules

- Add language detection model initialization and integration
- Update model list to include language detection
- Refactor language detection utils for better model management
myhloli 10 달 전
부모
커밋
735f3a7059
3개의 변경된 파일, 45개의 추가 그리고 12개의 삭제
  1. 1 0
      magic_pdf/model/model_list.py
  2. 24 11
      magic_pdf/model/sub_modules/language_detection/utils.py
  3. 20 1
      magic_pdf/model/sub_modules/model_init.py

+ 1 - 0
magic_pdf/model/model_list.py

@@ -9,3 +9,4 @@ class AtomicModel:
     MFR = "mfr"
     OCR = "ocr"
     Table = "table"
+    LangDetect = "langdetect"

+ 24 - 11
magic_pdf/model/sub_modules/language_detection/utils.py

@@ -12,7 +12,7 @@ from magic_pdf.data.utils import load_images_from_pdf
 from magic_pdf.libs.config_reader import get_local_models_dir, get_device
 from magic_pdf.libs.pdf_check import extract_pages
 from magic_pdf.model.model_list import AtomicModel
-from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel, LangDetectMode
+from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import LangDetectMode
 from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
 
 
@@ -61,19 +61,32 @@ def auto_detect_lang(pdf_bytes: bytes):
     sample_pdf_bytes = sample_docs.tobytes()
     simple_images = load_images_from_pdf(sample_pdf_bytes, dpi=200)
     text_images = get_text_images(simple_images)
-    local_models_dir, device, configs = get_model_config()
-    # 用yolo11做语言分类
-    langdetect_model_weights_dir = str(
-        os.path.join(
-            local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
-        )
-    )
-    langdetect_model = YOLOv11LangDetModel(langdetect_model_weights_dir, device)
+    langdetect_model = model_init(MODEL_NAME.YOLO_V11_LangDetect)
     lang = langdetect_model.do_detect(text_images)
-
     if lang in ["ch", "japan"]:
         lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.CH_JP)
     elif lang in ["en", "fr", "german"]:
         lang = langdetect_model.do_detect(text_images, mode=LangDetectMode.EN_FR_GE)
 
-    return lang
+    return lang
+
+
+def model_init(model_name: str):  # Resolve `model_name` to a model obtained through AtomModelSingleton; only YOLO_V11_LangDetect is handled here.
+    atom_model_manager = AtomModelSingleton()  # NOTE(review): presumably reuses already-built models — confirm in model_init.py
+
+    if model_name == MODEL_NAME.YOLO_V11_LangDetect:
+        local_models_dir, device, configs = get_model_config()  # models root dir, target device, and config mapping read below via configs['weights']
+        model = atom_model_manager.get_atom_model(
+            atom_model_name=AtomicModel.LangDetect,
+            langdetect_model_name=MODEL_NAME.YOLO_V11_LangDetect,
+            langdetect_model_weights_dir=str(  # weights path = models root joined with the configured entry for this model
+                os.path.join(
+                    local_models_dir, configs['weights'][MODEL_NAME.YOLO_V11_LangDetect]
+                )
+            ),
+            device=device,
+        )
+    else:
+        raise ValueError(f"model_name {model_name} not found")  # any other model name is rejected with ValueError
+    return model
+

+ 20 - 1
magic_pdf/model/sub_modules/model_init.py

@@ -2,8 +2,8 @@ import torch
 from loguru import logger
 
 from magic_pdf.config.constants import MODEL_NAME
-from magic_pdf.libs.config_reader import get_device
 from magic_pdf.model.model_list import AtomicModel
+from magic_pdf.model.sub_modules.language_detection.yolov11.YOLOv11 import YOLOv11LangDetModel
 from magic_pdf.model.sub_modules.layout.doclayout_yolo.DocLayoutYOLO import \
     DocLayoutYOLOModel
 from magic_pdf.model.sub_modules.layout.layoutlmv3.model_init import \
@@ -63,6 +63,13 @@ def doclayout_yolo_model_init(weight, device='cpu'):
     return model
 
 
+def langdetect_model_init(langdetect_model_weights_dir, device='cpu'):  # Build and return a YOLOv11LangDetModel from the given weights directory and device.
+    if str(device).startswith("npu"):  # only devices whose string form starts with "npu" are wrapped; others are passed through unchanged
+        device = torch.device(device)
+    model = YOLOv11LangDetModel(langdetect_model_weights_dir, device)
+    return model
+
+
 def ocr_model_init(show_log: bool = False,
                    det_db_box_thresh=0.3,
                    lang=None,
@@ -130,6 +137,9 @@ def atom_model_init(model_name: str, **kwargs):
                 kwargs.get('doclayout_yolo_weights'),
                 kwargs.get('device')
             )
+        else:
+            logger.error('layout model name not allow')
+            exit(1)
     elif model_name == AtomicModel.MFD:
         atom_model = mfd_model_init(
             kwargs.get('mfd_weights'),
@@ -155,6 +165,15 @@ def atom_model_init(model_name: str, **kwargs):
             kwargs.get('device'),
             kwargs.get('ocr_engine')
         )
+    elif model_name == AtomicModel.LangDetect:
+        if kwargs.get('langdetect_model_name') == MODEL_NAME.YOLO_V11_LangDetect:
+            atom_model = langdetect_model_init(
+                kwargs.get('langdetect_model_weights_dir'),
+                kwargs.get('device')
+            )
+        else:
+            logger.error('langdetect model name not allow')
+            exit(1)
     else:
         logger.error('model name not allow')
         exit(1)