fix solov2 bug

zhangyubo0722 committed 1 year ago
commit ddea2d9dac

+ 21 - 3
paddlex/inference/components/paddle_predictor/predictor.py

@@ -209,6 +209,7 @@ class ImageDetPredictor(BasePaddlePredictor):
                 np.stack(scale_factors, axis=0).astype(dtype=np.float32, copy=False),
             ]
         else:
+            # img_size = [img_size[::-1] for img_size in img_size]
             return [
                 np.stack(img_size, axis=0).astype(dtype=np.float32, copy=False),
                 np.stack(img, axis=0).astype(dtype=np.float32, copy=False),
@@ -218,7 +219,23 @@ class ImageDetPredictor(BasePaddlePredictor):
     def format_output(self, pred):
         box_idx_start = 0
         pred_box = []
+
+        if len(pred) == 4:
+            # Adapt to SOLOv2
+            pred_class_id = []
+            pred_mask = []
+            pred_class_id.append([pred[1], pred[2]])
+            pred_mask.append(pred[3])
+            return [
+                {
+                    "class_id": np.array(pred_class_id[i]),
+                    "masks": np.array(pred_mask[i]),
+                }
+                for i in range(len(pred_class_id))
+            ]
+
         if len(pred) == 3:
+            # Adapt to Instance Segmentation
             pred_mask = []
         for idx in range(len(pred[1])):
             np_boxes_num = pred[1][idx]
@@ -230,10 +247,11 @@ class ImageDetPredictor(BasePaddlePredictor):
                 pred_mask.append(np_masks)
             box_idx_start = box_idx_end
 
-        boxes = [{"boxes": np.array(res)} for res in pred_box]
         if len(pred) == 3:
-            masks = [{"masks": np.array(res)} for res in pred_mask]
-            return [{"boxes": boxes[0]["boxes"], "masks": masks[0]["masks"]}]
+            return [
+                {"boxes": np.array(pred_box[i]), "masks": np.array(pred_mask[i])}
+                for i in range(len(pred_box))
+            ]
         else:
             return [{"boxes": np.array(res)} for res in pred_box]
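
Note on the new four-output branch above: the tensor-order reading is inferred from the code itself (labels at index 1, scores at index 2, stacked masks at index 3; index 0 is unused in this branch). A minimal sketch of the dispatch:

```python
# Minimal sketch of the SOLOv2 branch in format_output (not the committed
# code verbatim; tensor order inferred from the diff above).
import numpy as np

def format_solov2(pred):
    labels, scores, masks = pred[1], pred[2], pred[3]  # pred[0] is unused here
    # class_id packs the parallel (labels, scores) arrays; the postprocess
    # later zips them back into per-instance (label, score) pairs.
    return [{"class_id": np.array([labels, scores]), "masks": np.array(masks)}]
```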
 

+ 1 - 1
paddlex/inference/components/task_related/__init__.py

@@ -22,7 +22,7 @@ from .text_det import (
 )
 from .text_rec import OCRReisizeNormImg, CTCLabelDecode
 from .table_rec import TableLabelDecode
-from .det import DetPostProcess, CropByBoxes, DetPad
+from .det import DetPostProcess, CropByBoxes, DetPad, WarpAffine
 from .instance_seg import InstanceSegPostProcess
 from .warp import DocTrPostProcess
 from .seg import Map_to_mask

+ 158 - 0
paddlex/inference/components/task_related/det.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
+import cv2
 
 import numpy as np
 from ...utils.io import ImageReader
@@ -42,6 +43,163 @@ def restructured_boxes(boxes, labels, img_size):
     return box_list
 
 
+def rotate_point(pt, angle_rad):
+    """Rotate a point by an angle.
+    Args:
+        pt (list[float]): 2-dimensional point to be rotated
+        angle_rad (float): rotation angle in radians
+    Returns:
+        list[float]: Rotated point.
+    """
+    assert len(pt) == 2
+    sn, cs = np.sin(angle_rad), np.cos(angle_rad)
+    new_x = pt[0] * cs - pt[1] * sn
+    new_y = pt[0] * sn + pt[1] * cs
+    rotated_pt = [new_x, new_y]
+
+    return rotated_pt
+
+
+def _get_3rd_point(a, b):
+    """To calculate the affine matrix, three pairs of points are required. This
+    function is used to get the 3rd point, given 2D points a & b.
+    The 3rd point is defined by rotating vector `a - b` by 90 degrees
+    anticlockwise, using b as the rotation center.
+    Args:
+        a (np.ndarray): point(x,y)
+        b (np.ndarray): point(x,y)
+    Returns:
+        np.ndarray: The 3rd point.
+    """
+    assert len(a) == 2
+    assert len(b) == 2
+    direction = a - b
+    third_pt = b + np.array([-direction[1], direction[0]], dtype=np.float32)
+
+    return third_pt
+
+
+def get_affine_transform(
+    center, input_size, rot, output_size, shift=(0.0, 0.0), inv=False
+):
+    """Get the affine transform matrix, given the center/scale/rot/output_size.
+    Args:
+        center (np.ndarray[2, ]): Center of the bounding box (x, y).
+        input_size (np.ndarray[2, ] or float): Size of the input region
+            wrt [width, height].
+        rot (float): Rotation angle (degree).
+        output_size (np.ndarray[2, ]): Size of the destination heatmaps.
+        shift (0-100%): Shift translation ratio wrt the width/height.
+            Default (0., 0.).
+        inv (bool): Option to inverse the affine transform direction.
+            (inv=False: src->dst or inv=True: dst->src)
+    Returns:
+        np.ndarray: The transform matrix.
+    """
+    assert len(center) == 2
+    assert len(output_size) == 2
+    assert len(shift) == 2
+    if not isinstance(input_size, (np.ndarray, list)):
+        input_size = np.array([input_size, input_size], dtype=np.float32)
+    scale_tmp = input_size
+
+    shift = np.array(shift)
+    src_w = scale_tmp[0]
+    dst_w = output_size[0]
+    dst_h = output_size[1]
+
+    rot_rad = np.pi * rot / 180
+    src_dir = rotate_point([0.0, src_w * -0.5], rot_rad)
+    dst_dir = np.array([0.0, dst_w * -0.5])
+
+    src = np.zeros((3, 2), dtype=np.float32)
+    src[0, :] = center + scale_tmp * shift
+    src[1, :] = center + src_dir + scale_tmp * shift
+    src[2, :] = _get_3rd_point(src[0, :], src[1, :])
+
+    dst = np.zeros((3, 2), dtype=np.float32)
+    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
+    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
+    dst[2, :] = _get_3rd_point(dst[0, :], dst[1, :])
+
+    if inv:
+        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
+    else:
+        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))
+
+    return trans
+
+
+class WarpAffine(BaseComponent):
+    """Warp affine the image"""
+
+    INPUT_KEYS = ["img"]
+    OUTPUT_KEYS = ["img", "img_size", "scale_factors"]
+    DEAULT_INPUTS = {"img": "img"}
+    DEAULT_OUTPUTS = {
+        "img": "img",
+        "img_size": "img_size",
+        "scale_factors": "scale_factors",
+    }
+
+    def __init__(
+        self,
+        keep_res=False,
+        pad=31,
+        input_h=512,
+        input_w=512,
+        scale=0.4,
+        shift=0.1,
+        down_ratio=4,
+    ):
+        super().__init__()
+        self.keep_res = keep_res
+        self.pad = pad
+        self.input_h = input_h
+        self.input_w = input_w
+        self.scale = scale
+        self.shift = shift
+        self.down_ratio = down_ratio
+
+    def apply(self, img):
+
+        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+        h, w = img.shape[:2]
+
+        if self.keep_res:
+            # True in detection eval/infer
+            input_h = (h | self.pad) + 1
+            input_w = (w | self.pad) + 1
+            s = np.array([input_w, input_h], dtype=np.float32)
+            c = np.array([w // 2, h // 2], dtype=np.float32)
+
+        else:
+            # False in centertrack eval_mot/infer_mot
+            s = max(h, w) * 1.0
+            input_h, input_w = self.input_h, self.input_w
+            c = np.array([w / 2.0, h / 2.0], dtype=np.float32)
+
+        trans_input = get_affine_transform(c, s, 0, [input_w, input_h])
+        img = cv2.resize(img, (w, h))
+        inp = cv2.warpAffine(
+            img, trans_input, (input_w, input_h), flags=cv2.INTER_LINEAR
+        )
+
+        if not self.keep_res:
+            out_h = input_h // self.down_ratio
+            out_w = input_w // self.down_ratio
+            trans_output = get_affine_transform(c, s, 0, [out_w, out_h])
+
+        im_scale_w, im_scale_h = [input_w / w, input_h / h]
+
+        return {
+            "img": inp,
+            "img_size": [inp.shape[1], inp.shape[0]],
+            "scale_factors": [im_scale_w, im_scale_h],
+        }
+
+
 class DetPostProcess(BaseComponent):
     """Save Result Transform"""
 
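Two notes on the added preprocessing code. First, in `WarpAffine.apply`, `(h | self.pad) + 1` with the default `pad=31` rounds each side up to a multiple of 32 (strictly up, so 480 becomes 512), which keeps `keep_res=True` inputs stride-aligned. Second, a quick numeric sanity check of `get_affine_transform` (a sketch, not part of the commit; assumes the function above is in scope): with `rot=0` and no shift, the matrix must map the image center to the center of the destination window.

```python
# Sanity-check sketch for get_affine_transform: the source triangle is
# built around the image center, the destination triangle around the
# window center, so the center must map to (dst_w/2, dst_h/2).
import numpy as np

c = np.array([320.0, 240.0], dtype=np.float32)  # center of a 640x480 image
s = 640.0                                       # max(h, w), keep_res=False path
trans = get_affine_transform(c, s, 0, [512, 512])
print(trans @ np.array([320.0, 240.0, 1.0]))    # ~[256. 256.]
```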

+ 31 - 8
paddlex/inference/components/task_related/instance_seg.py

@@ -20,6 +20,10 @@ from ..base import BaseComponent
 from .det import restructured_boxes
 
 
+import cv2
+import numpy as np
+
+
 def extract_masks_from_boxes(boxes, masks):
     """
     Extracts the portion of each mask that is within the corresponding box.
@@ -41,7 +45,7 @@ def extract_masks_from_boxes(boxes, masks):
 class InstanceSegPostProcess(BaseComponent):
     """Save Result Transform"""
 
-    INPUT_KEYS = ["boxes", "masks", "img_size"]
+    INPUT_KEYS = [["boxes", "masks", "img_size"], ["class_id", "masks", "img_size"]]
     OUTPUT_KEYS = ["img_path", "boxes", "masks"]
     DEAULT_INPUTS = {"boxes": "boxes", "masks": "masks", "img_size": "ori_img_size"}
     DEAULT_OUTPUTS = {
@@ -54,13 +58,32 @@ class InstanceSegPostProcess(BaseComponent):
         self.threshold = threshold
         self.labels = labels
 
-    def apply(self, boxes, masks, img_size):
+    def apply(self, masks, img_size, boxes=None, class_id=None):
         """apply"""
-        expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
-        boxes = boxes[expect_boxes, :]
-        boxes = restructured_boxes(boxes, self.labels, img_size)
-        masks = masks[expect_boxes, :, :]
-        masks = extract_masks_from_boxes(boxes, masks)
-        result = {"boxes": boxes, "masks": masks}
+        if boxes is not None:
+            expect_boxes = (boxes[:, 1] > self.threshold) & (boxes[:, 0] > -1)
+            boxes = boxes[expect_boxes, :]
+            boxes = restructured_boxes(boxes, self.labels, img_size)
+            masks = masks[expect_boxes, :, :]
+            masks = extract_masks_from_boxes(boxes, masks)
+            result = {"boxes": boxes, "masks": masks}
+        else:
+            mask_info = []
+            class_id = [list(item) for item in zip(class_id[0], class_id[1])]
+
+            selected_masks = []
+            for i, info in enumerate(class_id):
+                label_id = int(info[0])
+                if info[1] < self.threshold:
+                    continue
+                mask_info.append(
+                    {
+                        "label": self.labels[label_id],
+                        "score": info[1],
+                        "class_id": label_id,
+                    }
+                )
+                selected_masks.append(masks[i])
+            result = {"boxes": mask_info, "masks": selected_masks}
 
         return result
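
A toy run of the new SOLOv2 branch (a sketch; assumes the component can be invoked directly through `apply`): `class_id` arrives as a pair of parallel arrays (labels, scores), and instances under the threshold are dropped together with their masks.

```python
# Toy run of the SOLOv2 branch (sketch; direct apply() call assumed).
import numpy as np

labels = np.array([0, 2, 1])           # per-instance class indices
scores = np.array([0.9, 0.3, 0.8])     # per-instance confidences
masks = np.zeros((3, 4, 4), dtype=np.uint8)

post = InstanceSegPostProcess(threshold=0.5, labels=["cat", "dog", "bird"])
out = post.apply(masks=masks, img_size=(4, 4), class_id=[labels, scores])
print([b["label"] for b in out["boxes"]])  # ['cat', 'dog'] - 0.3 is filtered
```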

+ 1 - 1
paddlex/inference/components/transforms/image/common.py

@@ -306,7 +306,7 @@ class Resize(_BaseResize):
 
         if self.keep_ratio:
             h, w = img.shape[0:2]
-            target_size, _ = self._rescale_size((w, h), self.target_size)
+            target_size, _ = self._rescale_size((h, w), self.target_size)
 
         if self.size_divisor:
             target_size = [

+ 16 - 9
paddlex/inference/models/instance_segmentation.py

@@ -40,20 +40,27 @@ class InstanceSegPredictor(DetPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        if "RT-DETR" in self.model_name:
+        model_names = ["RT-DETR", "SOLOv2", "RCNN"]
+        if any(name in self.model_name for name in model_names):
             predictor.set_inputs(
                 {"img": "img", "scale_factors": "scale_factors", "img_size": "img_size"}
             )
-        self._add_component(
-            [
-                predictor,
-                InstanceSegPostProcess(
-                    threshold=self.config["draw_threshold"],
-                    labels=self.config["label_list"],
-                ),
-            ]
+
+        postprocess = InstanceSegPostProcess(
+            threshold=self.config["draw_threshold"],
+            labels=self.config["label_list"],
         )
 
+        if "SOLOv2" in self.model_name:
+            postprocess.set_inputs(
+                {
+                    "class_id": "class_id",
+                    "masks": "masks",
+                    "img_size": "img_size",
+                }
+            )
+        self._add_component([predictor, postprocess])
+
     def _pack_res(self, single):
         keys = ["img_path", "boxes", "masks"]
         return InstanceSegResult({key: single[key] for key in keys})
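
The rewiring in context (a sketch of the assumed `set_inputs` semantics: keys are the component's input names, values are the entries of the shared data dict it reads): SOLOv2 exports no box tensor, so the postprocess reads `class_id` instead of `boxes`.

```python
# Wiring sketch (data-dict key semantics assumed from the diff).
def wire_postprocess(postprocess, model_name):
    if "SOLOv2" in model_name:
        # SOLOv2 emits no boxes; read the (label, score) tensor instead.
        postprocess.set_inputs({"class_id": "class_id",
                                "masks": "masks",
                                "img_size": "img_size"})
    # other models keep the DEAULT_INPUTS mapping:
    # {"boxes": "boxes", "masks": "masks", "img_size": "ori_img_size"}
    return postprocess
```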

+ 6 - 1
paddlex/inference/models/object_detection.py

@@ -43,7 +43,8 @@ class DetPredictor(BasicPredictor):
             model_prefix=self.MODEL_FILE_PREFIX,
             option=self.pp_option,
         )
-        if "DETR" in self.model_name or "RCNN" in self.model_name:
+        model_names = ["DETR", "RCNN", "YOLOv3", "CenterNet"]
+        if any(name in self.model_name for name in model_names):
             predictor.set_inputs(
                 {
                     "img": "img",
@@ -111,6 +112,10 @@ class DetPredictor(BasicPredictor):
     def build_pad_stride(self, stride=32):
         return PadStride(stride=stride)
 
+    @register("WarpAffine")
+    def build_warp_affine(self, input_h=512, input_w=512, keep_res=True):
+        return WarpAffine(input_h=input_h, input_w=input_w, keep_res=keep_res)
+
     def _pack_res(self, single):
         keys = ["img_path", "boxes"]
         return DetResult({key: single[key] for key in keys})
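
With the op registered here and re-exported from `task_related`, the new preprocessing step can be exercised standalone. A usage sketch with the `keep_res` branch (shapes below follow from the `(side | 31) + 1` padding, not from a measured run):

```python
# Usage sketch of the newly registered WarpAffine (keep_res branch).
import numpy as np
from paddlex.inference.components.task_related import WarpAffine

warp = WarpAffine(keep_res=True, pad=31)
img = np.zeros((480, 640, 3), dtype=np.uint8)  # H, W, C image
out = warp.apply(img)
print(out["img"].shape)      # (512, 672, 3): (480|31)+1=512, (640|31)+1=672
print(out["img_size"])       # [672, 512] -- network input (w, h)
print(out["scale_factors"])  # [1.05, 1.0666...]
```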

+ 64 - 5
paddlex/inference/results/instance_seg.py

@@ -18,6 +18,7 @@ import numpy as np
 import math
 import copy
 import json
+import cv2
 import PIL
 from PIL import Image, ImageDraw, ImageFont
 
@@ -29,6 +30,62 @@ from .base import BaseResult
 from .det import draw_box
 
 
+def draw_segm(im, masks, mask_info, alpha=0.7):
+    """
+    Draw segmentation on image
+    """
+    mask_color_id = 0
+    w_ratio = 0.4
+    color_list = get_colormap(rgb=True)
+    im = np.array(im).astype("float32")
+    clsid2color = {}
+    masks = np.array(masks)
+    masks = masks.astype(np.uint8)
+    for i in range(masks.shape[0]):
+        mask, score, clsid = masks[i], mask_info[i]["score"], mask_info[i]["class_id"]
+
+        if clsid not in clsid2color:
+            color_index = i % len(color_list)
+            clsid2color[clsid] = color_list[color_index]
+        color_mask = clsid2color[clsid]
+        for c in range(3):
+            color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio * 255
+        idx = np.nonzero(mask)
+        color_mask = np.array(color_mask)
+        idx0 = np.minimum(idx[0], im.shape[0] - 1)
+        idx1 = np.minimum(idx[1], im.shape[1] - 1)
+        im[idx0, idx1, :] *= 1.0 - alpha
+        im[idx0, idx1, :] += alpha * color_mask
+        sum_x = np.sum(mask, axis=0)
+        x = np.where(sum_x > 0.5)[0]
+        sum_y = np.sum(mask, axis=1)
+        y = np.where(sum_y > 0.5)[0]
+        x0, x1, y0, y1 = x[0], x[-1], y[0], y[-1]
+        cv2.rectangle(
+            im, (x0, y0), (x1, y1), tuple(color_mask.astype("int32").tolist()), 1
+        )
+        bbox_text = "%s %.2f" % (mask_info[i]["label"], score)
+        t_size = cv2.getTextSize(bbox_text, 0, 0.3, thickness=1)[0]
+        cv2.rectangle(
+            im,
+            (x0, y0),
+            (x0 + t_size[0], y0 - t_size[1] - 3),
+            tuple(color_mask.astype("int32").tolist()),
+            -1,
+        )
+        cv2.putText(
+            im,
+            bbox_text,
+            (x0, y0 - 2),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.3,
+            (0, 0, 0),
+            1,
+            lineType=cv2.LINE_AA,
+        )
+    return Image.fromarray(im.astype("uint8"))
+
+
 def restore_to_draw_masks(img_size, boxes, masks):
     """
     Restores extracted masks to the original shape and draws them on a blank image.
@@ -90,15 +147,17 @@ class InstanceSegResult(BaseResult):
 
     def _get_res_img(self):
         """apply"""
-        boxes = np.array(self["boxes"])
-        masks = self["masks"]
         img_path = self["img_path"]
         file_name = os.path.basename(img_path)
-
         image = self._img_reader.read(img_path)
         ori_img_size = list(image.size)[::-1]
-        image = draw_mask(image, boxes, masks, ori_img_size)
-        image = draw_box(image, boxes)
+        boxes = self["boxes"]
+        masks = self["masks"]
+        if next((True for item in self["boxes"] if "coordinate" in item), False):
+            image = draw_mask(image, boxes, masks, ori_img_size)
+            image = draw_box(image, boxes)
+        else:
+            image = draw_segm(image, masks, boxes)
 
         return image
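
How `draw_segm` recovers a box from a binary mask (a sketch of the projection trick used above): sum the mask along each axis and take the first and last occupied index.

```python
# Sketch of the axis-projection trick in draw_segm.
import numpy as np

mask = np.zeros((6, 8), dtype=np.uint8)
mask[2:5, 3:7] = 1                        # a 3x4 blob
x = np.where(mask.sum(axis=0) > 0.5)[0]   # occupied columns
y = np.where(mask.sum(axis=1) > 0.5)[0]   # occupied rows
print(x[0], x[-1], y[0], y[-1])           # 3 6 2 4 -> rectangle corners
```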
 

+ 11 - 0
paddlex/inference/results/warp.py

@@ -13,6 +13,10 @@
 # limitations under the License.
 
 import numpy as np
+import copy
+import json
+
+from ...utils import logging
 from .base import BaseResult
 
 
@@ -25,3 +29,10 @@ class DocTrResult(BaseResult):
     def _get_res_img(self):
         doctr_img = np.array(self["doctr_img"])
         return doctr_img
+
+    def print(self, json_format=True, indent=4, ensure_ascii=False):
+        str_ = copy.deepcopy(self)
+        del str_["doctr_img"]
+        if json_format:
+            str_ = json.dumps(str_, indent=indent, ensure_ascii=ensure_ascii)
+        logging.info(str_)
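
Usage sketch for the new `print` (the constructor call is an assumption based on `BaseResult` behaving like a dict): the bulky rectified-image array is dropped from the log and the remaining keys are JSON-dumped.

```python
# Sketch (DocTrResult construction assumed; BaseResult treated as a dict).
import numpy as np

res = DocTrResult({"img_path": "demo.jpg",
                   "doctr_img": np.zeros((32, 32, 3), dtype=np.uint8)})
res.print()  # logs {"img_path": "demo.jpg"} -- doctr_img is removed first
```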