瀏覽代碼

support high version PIL (#1846)

zhangyubo0722 1 年之前
父節點
當前提交
4b352c0a9e

+ 15 - 4
paddlex/modules/image_classification/dataset_checker/dataset_src/utils/visualizer.py

@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 import numpy as np
 import json
 from pathlib import Path
+import PIL
 from PIL import Image, ImageDraw, ImageFont
 
 from ......utils.fonts import PINGFANG_FONT_FILE_PATH
@@ -62,8 +62,13 @@ def draw_label(image, label, label_map_dict):
     for font_size in range(max_font_size, min_font_size - 1, -1):
         font = ImageFont.truetype(
             PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
-        text_width_tmp, text_height_tmp = draw.textsize(
-            label_map_dict[int(label)], font)
+        if tuple(map(int, PIL.__version__.split('.'))) <= (10, 0, 0):
+            text_width_tmp, text_height_tmp = draw.textsize(
+                label_map_dict[int(label)], font)
+        else:
+            left, top, right, bottom = draw.textbbox(
+                (0, 0), label_map_dict[int(label)], font)
+            text_width_tmp, text_height_tmp = right - left, bottom - top
         if text_width_tmp <= image_size[0]:
             break
         else:
@@ -71,7 +76,13 @@ def draw_label(image, label, label_map_dict):
     color_list = colormap(rgb=True)
     color = tuple(color_list[0])
     font_color = tuple(font_colormap(3))
-    text_width, text_height = draw.textsize(label_map_dict[int(label)], font)
+    if tuple(map(int, PIL.__version__.split('.'))) <= (10, 0, 0):
+        text_width, text_height = draw.textsize(label_map_dict[int(label)],
+                                                font)
+    else:
+        left, top, right, bottom = draw.textbbox(
+            (0, 0), label_map_dict[int(label)], font)
+        text_width, text_height = right - left, bottom - top
 
     rect_left = 3
     rect_top = 3

+ 12 - 3
paddlex/modules/image_classification/predictor/transforms.py

@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 import json
 from pathlib import Path
 import numpy as np
+import PIL
 from PIL import ImageDraw, ImageFont
 
 from .keys import ClsKeys as K
@@ -179,7 +179,12 @@ class SaveClsResults(BaseTransform):
         for font_size in range(max_font_size, min_font_size - 1, -1):
             font = ImageFont.truetype(
                 PINGFANG_FONT_FILE_PATH, font_size, encoding="utf-8")
-            text_width_tmp, text_height_tmp = draw.textsize(label_str, font)
+            if tuple(map(int, PIL.__version__.split('.'))) <= (10, 0, 0):
+                text_width_tmp, text_height_tmp = draw.textsize(label_str, font)
+            else:
+                left, top, right, bottom = draw.textbbox((0, 0), label_str,
+                                                         font)
+                text_width_tmp, text_height_tmp = right - left, bottom - top
             if text_width_tmp <= image_size[0]:
                 break
             else:
@@ -188,7 +193,11 @@ class SaveClsResults(BaseTransform):
         color_list = self._get_colormap(rgb=True)
         color = tuple(color_list[0])
         font_color = tuple(self._get_font_colormap(3))
-        text_width, text_height = draw.textsize(label_str, font)
+        if tuple(map(int, PIL.__version__.split('.'))) <= (10, 0, 0):
+            text_width, text_height = draw.textsize(label_str, font)
+        else:
+            left, top, right, bottom = draw.textbbox((0, 0), label_str, font)
+            text_width, text_height = right - left, bottom - top
 
         rect_left = 3
         rect_top = 3

+ 6 - 2
paddlex/modules/instance_segmentation/dataset_checker/dataset_src/utils/visualizer.py

@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 import numpy as np
 import json
 from pathlib import Path
+import PIL
 from PIL import Image, ImageDraw, ImageFont
 from pycocotools.coco import COCO
 
@@ -119,7 +119,11 @@ def draw_bbox(image, coco_info: COCO, img_id):
         # draw label
         label = coco_info.loadCats(catid)[0]['name']
         text = "{}".format(label)
-        tw, th = draw.textsize(text, font=font)
+        if tuple(map(int, PIL.__version__.split('.'))) <= (10, 0, 0):
+            tw, th = draw.textsize(text, font=font)
+        else:
+            left, top, right, bottom = draw.textbbox((0, 0), text, font)
+            tw, th = right - left, bottom - top
         if ymin < th:
             draw.rectangle(
                 [(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)

+ 6 - 2
paddlex/modules/object_detection/predictor/transforms.py

@@ -12,11 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-
 import os
 
 import numpy as np
 import math
+import PIL
 from PIL import Image, ImageDraw, ImageFont
 
 from .keys import DetKeys as K
@@ -126,7 +126,11 @@ def draw_box(img, np_boxes, labels, threshold=0.5):
 
         # draw label
         text = "{} {:.2f}".format(labels[clsid], score)
-        tw, th = draw.textsize(text, font=font)
+        if tuple(map(int, PIL.__version__.split('.'))) <= (10, 0, 0):
+            tw, th = draw.textsize(text, font=font)
+        else:
+            left, top, right, bottom = draw.textbbox((0, 0), text, font)
+            tw, th = right - left, bottom - top
         if ymin < th:
             draw.rectangle(
                 [(xmin, ymin), (xmin + tw + 4, ymin + th + 1)], fill=color)