result.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import numpy as np
  16. import PIL
  17. from PIL import Image, ImageDraw, ImageFont
  18. from ....utils.fonts import PINGFANG_FONT
  19. from ...common.result import BaseCVResult, JsonMixin
  20. from ...utils.color_map import get_colormap
  21. class MLClassResult(BaseCVResult):
  22. def _to_str(self, *args, **kwargs):
  23. data = copy.deepcopy(self)
  24. data.pop("input_img")
  25. return JsonMixin._to_str(data, *args, **kwargs)
  26. def _to_json(self, *args, **kwargs):
  27. data = copy.deepcopy(self)
  28. data.pop("input_img")
  29. return JsonMixin._to_json(data, *args, **kwargs)
  30. def _to_img(self):
  31. """Draw label on image"""
  32. image = Image.fromarray(self["input_img"])
  33. label_names = self["label_names"]
  34. scores = self["scores"]
  35. image = image.convert("RGB")
  36. image_width, image_height = image.size
  37. font_size = int(image_width * 0.06)
  38. font = ImageFont.truetype(PINGFANG_FONT.path, font_size)
  39. text_lines = []
  40. row_width = 0
  41. row_height = 0
  42. row_text = "\t"
  43. for label_name, score in zip(label_names, scores):
  44. text = f"{label_name}({score})\t"
  45. if int(PIL.__version__.split(".")[0]) < 10:
  46. text_width, row_height = font.getsize(text)
  47. else:
  48. text_width, row_height = font.getbbox(text)[2:]
  49. if row_width + text_width <= image_width:
  50. row_text += text
  51. row_width += text_width
  52. else:
  53. text_lines.append(row_text)
  54. row_text = "\t" + text
  55. row_width = text_width
  56. text_lines.append(row_text)
  57. color_list = get_colormap(rgb=True)
  58. color = tuple(color_list[0])
  59. new_image_height = image_height + len(text_lines) * int(row_height * 1.2)
  60. new_image = Image.new("RGB", (image_width, new_image_height), color)
  61. new_image.paste(image, (0, 0))
  62. draw = ImageDraw.Draw(new_image)
  63. font_color = tuple(self._get_font_colormap(3))
  64. for i, text in enumerate(text_lines):
  65. if int(PIL.__version__.split(".")[0]) < 10:
  66. text_width, _ = font.getsize(text)
  67. else:
  68. text_width, _ = font.getbbox(text)[2:]
  69. draw.text(
  70. (0, image_height + i * int(row_height * 1.2)),
  71. text,
  72. fill=font_color,
  73. font=font,
  74. )
  75. return {"res": new_image}
  76. def _get_font_colormap(self, color_index):
  77. """
  78. Get font colormap
  79. """
  80. dark = np.array([0x14, 0x0E, 0x35])
  81. light = np.array([0xFF, 0xFF, 0xFF])
  82. light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
  83. if color_index in light_indexs:
  84. return light.astype("int32")
  85. else:
  86. return dark.astype("int32")