|
|
@@ -17,7 +17,7 @@ import json
|
|
|
from pathlib import Path
|
|
|
import numpy as np
|
|
|
import PIL
|
|
|
-from PIL import ImageDraw, ImageFont
|
|
|
+from PIL import ImageDraw, ImageFont, Image
|
|
|
|
|
|
from .keys import ClsKeys as K
|
|
|
from ...base import BaseTransform
|
|
|
@@ -25,7 +25,13 @@ from ...base.predictor.io import ImageWriter, ImageReader
|
|
|
from ....utils.fonts import PINGFANG_FONT_FILE_PATH
|
|
|
from ....utils import logging
|
|
|
|
|
|
-__all__ = ["Topk", "NormalizeFeatures", "PrintResult", "SaveClsResults"]
|
|
|
+__all__ = [
|
|
|
+ "Topk",
|
|
|
+ "NormalizeFeatures",
|
|
|
+ "PrintResult",
|
|
|
+ "SaveClsResults",
|
|
|
+ "MultiLabelThreshOutput",
|
|
|
+]
|
|
|
|
|
|
|
|
|
def _parse_class_id_map(class_ids):
|
|
|
@@ -282,3 +288,101 @@ class SaveClsResults(BaseTransform):
|
|
|
def get_output_keys(cls):
    """Return the output keys of this transform; the list is empty."""
    output_keys = []
    return output_keys
|
|
|
+
|
|
|
+
|
|
|
class MultiLabelThreshOutput(BaseTransform):
    """Keep every class whose score reaches a threshold (multi-label output).

    Reads the scores stored under ``K.CLS_PRED`` and writes a one-element
    list under ``K.CLS_RESULT`` whose single dict carries ``class_ids``,
    ``scores`` and ``label_names``, ordered by descending score.
    """

    def __init__(self, threshold=0.5, class_ids=None, delimiter=None):
        """Initialize the transform.

        Args:
            threshold: Minimum score (inclusive) for a class to be kept.
                Must be an ``int`` or ``float``.
            class_ids: Class-id specification passed to
                ``_parse_class_id_map``; the resulting map is used to look up
                label names when it is not ``None``.
            delimiter: Stored on the instance (a single space when ``None``);
                it is not used by any method of this class.

        Raises:
            AssertionError: If ``threshold`` is not an ``int`` or ``float``.
        """
        super().__init__()
        # Integer thresholds (e.g. 0 or 1) are valid cut-offs too; a
        # float-only check would reject them.
        assert isinstance(threshold, (int, float))
        self.threshold = threshold
        self.delimiter = delimiter if delimiter is not None else " "
        self.class_id_map = _parse_class_id_map(class_ids)

    def apply(self, data):
        """Select the classes scoring at or above the threshold.

        Args:
            data: Mapping holding the scores under ``K.CLS_PRED``.
                NOTE(review): the indexing below assumes a 1-D NumPy score
                vector -- confirm against the caller.

        Returns:
            The same ``data`` object with ``K.CLS_RESULT`` set.
        """
        x = data[K.CLS_PRED]
        pred_index = np.where(x >= self.threshold)[0].astype("int32")
        # argsort is ascending; reverse it so the best score comes first.
        index = pred_index[np.argsort(x[pred_index])][::-1]

        class_id_list = [i.item() for i in index]
        score_list = [x[i].item() for i in index]
        if self.class_id_map is not None:
            label_name_list = [self.class_id_map[cid] for cid in class_id_list]
        else:
            # Without a class-id map no label names are reported.
            label_name_list = []

        result = {
            "class_ids": class_id_list,
            "scores": np.around(score_list, decimals=5).tolist(),
            "label_names": label_name_list,
        }
        data[K.CLS_RESULT] = [result]
        return data

    @classmethod
    def get_input_keys(cls):
        """Return the keys this transform requires in ``data``."""
        return [K.IM_PATH, K.CLS_PRED]

    @classmethod
    def get_output_keys(cls):
        """Return the keys this transform writes into ``data``."""
        return [K.CLS_RESULT]
|
|
|
+
|
|
|
+
|
|
|
class SaveMLClsResults(SaveClsResults, BaseTransform):
    """Save the input image with its multi-label results printed below it.

    The label/score pairs of the first entry in ``K.CLS_RESULT`` are wrapped
    into rows no wider than the image, drawn on a strip appended under the
    original picture, and the combined image is handed to
    ``self._write_image``.
    """

    def __init__(self, save_dir, class_ids=None):
        """Initialize the transform.

        Args:
            save_dir: Directory the annotated image path is built from.
            class_ids: Class-id specification passed to
                ``_parse_class_id_map``.
        """
        super().__init__(save_dir=save_dir)
        self.save_dir = save_dir
        self.class_id_map = _parse_class_id_map(class_ids)
        self._writer = ImageWriter(backend="pillow")

    @staticmethod
    def _measure_text(font, text):
        """Return ``(width, height)`` of ``text`` rendered with ``font``.

        ``ImageFont.getsize`` was removed in Pillow 10. The right/bottom
        edges reported by ``getbbox`` are the same values ``getsize`` used to
        return, so prefer ``getbbox`` and only fall back to ``getsize`` on
        Pillow releases that predate it.
        """
        if hasattr(font, "getbbox"):
            _, _, right, bottom = font.getbbox(text)
            return right, bottom
        return font.getsize(text)

    def apply(self, data):
        """Draw the predicted labels under the image and save the result.

        Args:
            data: Mapping holding the source image path under ``K.IM_PATH``
                and the classification results under ``K.CLS_RESULT``; only
                the first result entry is rendered.

        Returns:
            The unmodified ``data`` object.
        """
        ori_path = data[K.IM_PATH]
        first_result = data[K.CLS_RESULT][0]
        scores = first_result["scores"]
        label_names = first_result["label_names"]
        save_path = os.path.join(self.save_dir, os.path.basename(ori_path))

        image = ImageReader(backend="pil").read(ori_path).convert("RGB")
        image_width, image_height = image.size
        # Scale the font with the image width; clamp to 1 because a size of 0
        # (images narrower than 17 px) is rejected by ImageFont.truetype.
        font_size = max(1, int(image_width * 0.06))
        font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size)

        # Greedily pack "label(score)" items into rows that fit the width.
        text_lines = []
        row_text = "\t"
        row_width = 0
        row_height = 0
        for label_name, score in zip(label_names, scores):
            text = f"{label_name}({score})\t"
            text_width, text_height = self._measure_text(font, text)
            # Track the tallest item so no row is clipped.
            row_height = max(row_height, text_height)
            if row_width + text_width <= image_width:
                row_text += text
                row_width += text_width
            else:
                text_lines.append(row_text)
                row_text = "\t" + text
                row_width = text_width
        text_lines.append(row_text)

        # Extend the canvas downwards by one padded line per text row.
        line_height = int(row_height * 1.2)
        background = tuple(self._get_colormap(rgb=True)[0])
        new_image_height = image_height + len(text_lines) * line_height
        new_image = Image.new("RGB", (image_width, new_image_height), background)
        new_image.paste(image, (0, 0))

        draw = ImageDraw.Draw(new_image)
        font_color = tuple(self._get_font_colormap(3))
        for i, text in enumerate(text_lines):
            draw.text(
                (0, image_height + i * line_height),
                text,
                fill=font_color,
                font=font,
            )
        self._write_image(save_path, new_image)
        return data
|