|
|
@@ -1,7 +1,9 @@
|
|
|
# Copyright (c) Opendatalab. All rights reserved.
|
|
|
+import os
|
|
|
from collections import Counter
|
|
|
from uuid import uuid4
|
|
|
|
|
|
+import torch
|
|
|
from PIL import Image
|
|
|
from loguru import logger
|
|
|
from ultralytics import YOLO
|
|
|
@@ -17,6 +19,11 @@ language_dict = {
|
|
|
"ru": "俄语"
|
|
|
}
|
|
|
|
|
|
+class LangDetectMode:
|
|
|
+ BASE = "base"
|
|
|
+ CH_JP = "ch_jp"
|
|
|
+ EN_FR_GE = "en_fr_ge"
|
|
|
+
|
|
|
|
|
|
def split_images(image, result_images=None):
|
|
|
"""
|
|
|
@@ -83,11 +90,25 @@ def resize_images_to_224(image):
|
|
|
|
|
|
|
|
|
class YOLOv11LangDetModel(object):
|
|
|
- def __init__(self, weight, device):
|
|
|
- self.model = YOLO(weight)
|
|
|
- self.device = device
|
|
|
-
|
|
|
- def do_detect(self, images: list):
|
|
|
+ def __init__(self, langdetect_model_weights_dir, device):
|
|
|
+ langdetect_model_base_weight = str(
|
|
|
+ os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_ft.pt')
|
|
|
+ )
|
|
|
+ langdetect_model_ch_jp_weight = str(
|
|
|
+ os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_ch_jp.pt')
|
|
|
+ )
|
|
|
+ langdetect_model_en_fr_ge_weight = str(
|
|
|
+ os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_en_fr_ge.pt')
|
|
|
+ )
|
|
|
+ self.model = YOLO(langdetect_model_base_weight)
|
|
|
+ self.ch_jp_model = YOLO(langdetect_model_ch_jp_weight)
|
|
|
+ self.en_fr_ge_model = YOLO(langdetect_model_en_fr_ge_weight)
|
|
|
+
|
|
|
+ if str(device).startswith("npu"):
|
|
|
+ self.device = torch.device(device)
|
|
|
+ else:
|
|
|
+ self.device = device
|
|
|
+ def do_detect(self, images: list, mode=LangDetectMode.BASE):
|
|
|
all_images = []
|
|
|
for image in images:
|
|
|
width, height = image.size
|
|
|
@@ -98,7 +119,7 @@ class YOLOv11LangDetModel(object):
|
|
|
for temp_image in temp_images:
|
|
|
all_images.append(resize_images_to_224(temp_image))
|
|
|
|
|
|
- images_lang_res = self.batch_predict(all_images, batch_size=8)
|
|
|
+ images_lang_res = self.batch_predict(all_images, batch_size=8, mode=mode)
|
|
|
logger.info(f"images_lang_res: {images_lang_res}")
|
|
|
if len(images_lang_res) > 0:
|
|
|
count_dict = Counter(images_lang_res)
|
|
|
@@ -107,20 +128,39 @@ class YOLOv11LangDetModel(object):
|
|
|
language = None
|
|
|
return language
|
|
|
|
|
|
+ def predict(self, image, mode=LangDetectMode.BASE):
|
|
|
+
|
|
|
+ if mode == LangDetectMode.BASE:
|
|
|
+ model = self.model
|
|
|
+ elif mode == LangDetectMode.CH_JP:
|
|
|
+ model = self.ch_jp_model
|
|
|
+ elif mode == LangDetectMode.EN_FR_GE:
|
|
|
+ model = self.en_fr_ge_model
|
|
|
+ else:
|
|
|
+ model = self.model
|
|
|
|
|
|
- def predict(self, image):
|
|
|
- results = self.model.predict(image, verbose=False, device=self.device)
|
|
|
+ results = model.predict(image, verbose=False, device=self.device)
|
|
|
predicted_class_id = int(results[0].probs.top1)
|
|
|
- predicted_class_name = self.model.names[predicted_class_id]
|
|
|
+ predicted_class_name = model.names[predicted_class_id]
|
|
|
return predicted_class_name
|
|
|
|
|
|
|
|
|
- def batch_predict(self, images: list, batch_size: int) -> list:
|
|
|
+ def batch_predict(self, images: list, batch_size: int, mode=LangDetectMode.BASE) -> list:
|
|
|
images_lang_res = []
|
|
|
+
|
|
|
+ if mode == LangDetectMode.BASE:
|
|
|
+ model = self.model
|
|
|
+ elif mode == LangDetectMode.CH_JP:
|
|
|
+ model = self.ch_jp_model
|
|
|
+ elif mode == LangDetectMode.EN_FR_GE:
|
|
|
+ model = self.en_fr_ge_model
|
|
|
+ else:
|
|
|
+ model = self.model
|
|
|
+
|
|
|
for index in range(0, len(images), batch_size):
|
|
|
lang_res = [
|
|
|
image_res.cpu()
|
|
|
- for image_res in self.model.predict(
|
|
|
+ for image_res in model.predict(
|
|
|
images[index: index + batch_size],
|
|
|
verbose = False,
|
|
|
device=self.device,
|
|
|
@@ -128,7 +168,7 @@ class YOLOv11LangDetModel(object):
|
|
|
]
|
|
|
for res in lang_res:
|
|
|
predicted_class_id = int(res.probs.top1)
|
|
|
- predicted_class_name = self.model.names[predicted_class_id]
|
|
|
+ predicted_class_name = model.names[predicted_class_id]
|
|
|
images_lang_res.append(predicted_class_name)
|
|
|
|
|
|
return images_lang_res
|