YOLOv11.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # Copyright (c) Opendatalab. All rights reserved.
  2. import time
  3. from collections import Counter
  4. from uuid import uuid4
  5. import cv2
  6. import numpy as np
  7. import torch
  8. from loguru import logger
  9. from ultralytics import YOLO
  10. language_dict = {
  11. "ch": "中文简体",
  12. "en": "英语",
  13. "japan": "日语",
  14. "korean": "韩语",
  15. "fr": "法语",
  16. "german": "德语",
  17. "ar": "阿拉伯语",
  18. "ru": "俄语"
  19. }
  20. def split_images(image, result_images=None):
  21. """
  22. 对输入文件夹内的图片进行处理,若图片竖向(y方向)分辨率超过400,则进行拆分,
  23. 每次平分图片,直至拆分出的图片竖向分辨率都满足400以下,将处理后的图片(拆分后的子图片)保存到输出文件夹。
  24. 避免保存因裁剪区域超出图片范围导致出现的无效黑色图片部分。
  25. """
  26. if result_images is None:
  27. result_images = []
  28. height, width = image.shape[:2]
  29. long_side = max(width, height) # 获取较长边长度
  30. if long_side <= 400:
  31. result_images.append(image)
  32. return result_images
  33. new_long_side = long_side // 2
  34. sub_images = []
  35. if width >= height: # 如果宽度是较长边
  36. for x in range(0, width, new_long_side):
  37. # 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
  38. if x + new_long_side > width:
  39. continue
  40. box = (x, 0, x + new_long_side, height)
  41. sub_image = image[0:height, x:x + new_long_side]
  42. sub_images.append(sub_image)
  43. else: # 如果高度是较长边
  44. for y in range(0, height, new_long_side):
  45. # 判断裁剪区域是否超出图片范围,如果超出则不进行裁剪保存操作
  46. if y + new_long_side > height:
  47. continue
  48. box = (0, y, width, y + new_long_side)
  49. sub_image = image[y:y + new_long_side, 0:width]
  50. sub_images.append(sub_image)
  51. for sub_image in sub_images:
  52. split_images(sub_image, result_images)
  53. return result_images
  54. def resize_images_to_224(image):
  55. """
  56. 若分辨率小于224则用黑色背景补齐到224*224大小,若大于等于224则调整为224*224大小。
  57. Works directly with NumPy arrays.
  58. """
  59. try:
  60. height, width = image.shape[:2]
  61. if width < 224 or height < 224:
  62. # Create black background
  63. new_image = np.zeros((224, 224, 3), dtype=np.uint8)
  64. # Calculate paste position (ensure they're not negative)
  65. paste_x = max(0, (224 - width) // 2)
  66. paste_y = max(0, (224 - height) // 2)
  67. # Make sure we don't exceed the boundaries of new_image
  68. paste_width = min(width, 224)
  69. paste_height = min(height, 224)
  70. # Paste original image onto black background
  71. new_image[paste_y:paste_y + paste_height, paste_x:paste_x + paste_width] = image[:paste_height, :paste_width]
  72. image = new_image
  73. else:
  74. # Resize using cv2
  75. image = cv2.resize(image, (224, 224), interpolation=cv2.INTER_LANCZOS4)
  76. return image
  77. except Exception as e:
  78. logger.exception(f"Error in resize_images_to_224: {e}")
  79. return None
  80. class YOLOv11LangDetModel(object):
  81. def __init__(self, langdetect_model_weight, device):
  82. self.model = YOLO(langdetect_model_weight)
  83. if str(device).startswith("npu"):
  84. self.device = torch.device(device)
  85. else:
  86. self.device = device
  87. def do_detect(self, images: list):
  88. all_images = []
  89. for image in images:
  90. height, width = image.shape[:2]
  91. if width < 100 and height < 100:
  92. continue
  93. temp_images = split_images(image)
  94. for temp_image in temp_images:
  95. all_images.append(resize_images_to_224(temp_image))
  96. # langdetect_start = time.time()
  97. images_lang_res = self.batch_predict(all_images, batch_size=256)
  98. # logger.info(f"image number of langdetect: {len(images_lang_res)}, langdetect time: {round(time.time() - langdetect_start, 2)}")
  99. if len(images_lang_res) > 0:
  100. count_dict = Counter(images_lang_res)
  101. language = max(count_dict, key=count_dict.get)
  102. else:
  103. language = None
  104. return language
  105. def predict(self, image):
  106. results = self.model.predict(image, verbose=False, device=self.device)
  107. predicted_class_id = int(results[0].probs.top1)
  108. predicted_class_name = self.model.names[predicted_class_id]
  109. return predicted_class_name
  110. def batch_predict(self, images: list, batch_size: int) -> list:
  111. images_lang_res = []
  112. for index in range(0, len(images), batch_size):
  113. lang_res = [
  114. image_res.cpu()
  115. for image_res in self.model.predict(
  116. images[index: index + batch_size],
  117. verbose = False,
  118. device=self.device,
  119. )
  120. ]
  121. for res in lang_res:
  122. predicted_class_id = int(res.probs.top1)
  123. predicted_class_name = self.model.names[predicted_class_id]
  124. images_lang_res.append(predicted_class_name)
  125. return images_lang_res