# Copyright (c) Opendatalab. All rights reserved.
import os
from collections import Counter

import torch
from PIL import Image
from loguru import logger
from ultralytics import YOLO

# Map detector language codes to display names.
language_dict = {
    "ch": "中文简体",  # Simplified Chinese
    "en": "英语",  # English
    "japan": "日语",  # Japanese
    "korean": "韩语",  # Korean
    "fr": "法语",  # French
    "german": "德语",  # German
    "ar": "阿拉伯语",  # Arabic
    "ru": "俄语"  # Russian
}


class LangDetectMode:
    BASE = "base"
    CH_JP = "ch_jp"
    EN_FR_GE = "en_fr_ge"


def split_images(image, result_images=None):
    """
    Recursively split an image until every piece's longer side is at most
    400 px, halving along the longer side on each pass, and collect the
    resulting sub-images. Crop boxes that would extend past the image
    bounds are skipped, which avoids producing sub-images padded with
    invalid black regions.
    """
    if result_images is None:
        result_images = []

    width, height = image.size
    long_side = max(width, height)  # length of the longer side
    if long_side <= 400:
        result_images.append(image)
        return result_images

    new_long_side = long_side // 2
    sub_images = []
    if width >= height:  # width is the longer side: split along x
        for x in range(0, width, new_long_side):
            # Skip crop boxes that would extend past the right edge.
            if x + new_long_side > width:
                continue
            box = (x, 0, x + new_long_side, height)
            sub_images.append(image.crop(box))
    else:  # height is the longer side: split along y
        for y in range(0, height, new_long_side):
            # Skip crop boxes that would extend past the bottom edge.
            if y + new_long_side > height:
                continue
            box = (0, y, width, y + new_long_side)
            sub_images.append(image.crop(box))

    for sub_image in sub_images:
        split_images(sub_image, result_images)
    return result_images
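
# A minimal usage sketch of split_images, assuming a synthetic input
# (Image.new stands in for a real page crop): a 900 x 300 image is split
# until every piece's longer side is at most 400 px.
#
#   demo = Image.new('RGB', (900, 300), (255, 255, 255))
#   pieces = split_images(demo)
#   assert all(max(piece.size) <= 400 for piece in pieces)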


def resize_images_to_224(image):
    """
    Normalize an image to 224 x 224: if either side is smaller than 224,
    center the image on a black 224 x 224 canvas; otherwise resize it to
    224 x 224 with Lanczos resampling.
    """
    try:
        width, height = image.size
        if width < 224 or height < 224:
            new_image = Image.new('RGB', (224, 224), (0, 0, 0))
            # Note: if one side exceeds 224 here, its paste offset is
            # negative and PIL clips that dimension to the canvas.
            paste_x = (224 - width) // 2
            paste_y = (224 - height) // 2
            new_image.paste(image, (paste_x, paste_y))
            image = new_image
        else:
            image = image.resize((224, 224), Image.Resampling.LANCZOS)
        return image
    except Exception as e:
        logger.exception(e)
        return image  # fall back to the unmodified image on failure
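
# A minimal sketch of resize_images_to_224 on hypothetical inputs: a
# 100 x 50 image is centered on a black canvas, while a 300 x 300 image
# is resampled down; both come out 224 x 224.
#
#   small = resize_images_to_224(Image.new('RGB', (100, 50)))
#   large = resize_images_to_224(Image.new('RGB', (300, 300)))
#   assert small.size == large.size == (224, 224)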


class YOLOv11LangDetModel(object):
    def __init__(self, langdetect_model_weights_dir, device):
        # One general classifier plus two specialists for pairs of
        # easily confused languages.
        self.model = YOLO(
            os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_ft.pt')
        )
        self.ch_jp_model = YOLO(
            os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_ch_jp.pt')
        )
        self.en_fr_ge_model = YOLO(
            os.path.join(langdetect_model_weights_dir, 'yolo_v11_cls_en_fr_ge.pt')
        )
        # Ascend NPU device strings must be wrapped in torch.device;
        # others ("cpu", "cuda:0", ...) are passed through unchanged.
        if str(device).startswith("npu"):
            self.device = torch.device(device)
        else:
            self.device = device

    def _get_model(self, mode):
        """Return the classifier matching the detection mode."""
        if mode == LangDetectMode.CH_JP:
            return self.ch_jp_model
        if mode == LangDetectMode.EN_FR_GE:
            return self.en_fr_ge_model
        return self.model

    def do_detect(self, images: list, mode=LangDetectMode.BASE):
        all_images = []
        for image in images:
            width, height = image.size
            # Skip images too small to carry useful text.
            if width < 100 and height < 100:
                continue
            for temp_image in split_images(image):
                all_images.append(resize_images_to_224(temp_image))

        images_lang_res = self.batch_predict(all_images, batch_size=8, mode=mode)
        logger.info(f"images_lang_res: {images_lang_res}")
        # Majority vote across all sub-image predictions.
        if len(images_lang_res) > 0:
            language = Counter(images_lang_res).most_common(1)[0][0]
        else:
            language = None
        return language

    def predict(self, image, mode=LangDetectMode.BASE):
        model = self._get_model(mode)
        results = model.predict(image, verbose=False, device=self.device)
        predicted_class_id = int(results[0].probs.top1)
        return model.names[predicted_class_id]

    def batch_predict(self, images: list, batch_size: int, mode=LangDetectMode.BASE) -> list:
        images_lang_res = []
        model = self._get_model(mode)
        for index in range(0, len(images), batch_size):
            # Move results to CPU promptly to free accelerator memory.
            lang_res = [
                image_res.cpu()
                for image_res in model.predict(
                    images[index: index + batch_size],
                    verbose=False,
                    device=self.device,
                )
            ]
            for res in lang_res:
                predicted_class_id = int(res.probs.top1)
                images_lang_res.append(model.names[predicted_class_id])
        return images_lang_res
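

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the library API. The weights
    # directory (holding the three .pt files loaded above) and the image
    # path are hypothetical placeholders; substitute real locations.
    model = YOLOv11LangDetModel("/path/to/langdetect_weights", device="cpu")
    page = Image.open("/path/to/page.png").convert("RGB")
    lang = model.do_detect([page], mode=LangDetectMode.BASE)
    logger.info(f"detected language: {lang} ({language_dict.get(lang, 'unknown')})")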