Browse Source

rewrite postprocess for rcnn predict

FlyingQianMM 5 năm trước cách đây
mục cha
commit
84e25e9fb8
3 tập tin đã thay đổi với 78 bổ sung38 xóa
  1. 15 12
      paddlex/cv/models/faster_rcnn.py
  2. 19 15
      paddlex/cv/models/mask_rcnn.py
  3. 44 11
      paddlex/deploy.py

+ 15 - 12
paddlex/cv/models/faster_rcnn.py

@@ -409,14 +409,7 @@ class FasterRCNN(BaseAPI):
         return im, im_resize_info, im_shape
 
     @staticmethod
-    def _postprocess(results, test_outputs_keys, batch_size, num_classes,
-                     labels):
-        res = {
-            k: (np.array(v), v.recursive_sequence_lengths())
-            for k, v in zip(list(test_outputs_keys), results)
-        }
-        res['im_id'] = (np.array(
-            [[i] for i in range(batch_size)]).astype('int32'), [])
+    def _postprocess(res, batch_size, num_classes, labels):
         clsid2catid = dict({i: i for i in range(num_classes)})
         xywh_results = bbox2out([res], clsid2catid)
         preds = [[] for i in range(batch_size)]
@@ -463,8 +456,13 @@ class FasterRCNN(BaseAPI):
                                   return_numpy=False,
                                   use_program_cache=True)
 
-        preds = FasterRCNN._postprocess(result,
-                                        list(self.test_outputs.keys()),
+        res = {
+            k: (np.array(v), v.recursive_sequence_lengths())
+            for k, v in zip(list(test_outputs_keys), results)
+        }
+        res['im_id'] = (np.array(
+            [[i] for i in range(batch_size)]).astype('int32'), [])
+        preds = FasterRCNN._postprocess(res,
                                         len(images), self.num_classes,
                                         self.labels)
 
@@ -507,8 +505,13 @@ class FasterRCNN(BaseAPI):
                                   return_numpy=False,
                                   use_program_cache=True)
 
-        preds = FasterRCNN._postprocess(result,
-                                        list(self.test_outputs.keys()),
+        res = {
+            k: (np.array(v), v.recursive_sequence_lengths())
+            for k, v in zip(list(test_outputs_keys), results)
+        }
+        res['im_id'] = (np.array(
+            [[i] for i in range(batch_size)]).astype('int32'), [])
+        preds = FasterRCNN._postprocess(res,
                                         len(img_file_list), self.num_classes,
                                         self.labels)
 

+ 19 - 15
paddlex/cv/models/mask_rcnn.py

@@ -338,15 +338,8 @@ class MaskRCNN(FasterRCNN):
         return metrics
 
     @staticmethod
-    def _postprocess(results, im_shape, test_outputs_keys, batch_size,
-                     num_classes, mask_head_resolution, labels):
-        res = {
-            k: (np.array(v), v.recursive_sequence_lengths())
-            for k, v in zip(list(test_outputs_keys), results)
-        }
-        res['im_id'] = (np.array(
-            [[i] for i in range(batch_size)]).astype('int32'), [])
-        res['im_shape'] = (np.array(im_shape), [])
+    def _postprocess(res, batch_size, num_classes, mask_head_resolution,
+                     labels):
         clsid2catid = dict({i: i for i in range(num_classes)})
         xywh_results = bbox2out([res], clsid2catid)
         segm_results = mask2out([res], clsid2catid, mask_head_resolution)
@@ -398,8 +391,14 @@ class MaskRCNN(FasterRCNN):
                                   return_numpy=False,
                                   use_program_cache=True)
 
-        preds = MaskRCNN._postprocess(result, im_shape,
-                                      list(self.test_outputs.keys()),
+        res = {
+            k: (np.array(v), v.recursive_sequence_lengths())
+            for k, v in zip(list(test_outputs_keys), results)
+        }
+        res['im_id'] = (np.array(
+            [[i] for i in range(batch_size)]).astype('int32'), [])
+        res['im_shape'] = (np.array(im_shape), [])
+        preds = MaskRCNN._postprocess(res,
                                       len(images), self.num_classes,
                                       self.mask_head_resolution, self.labels)
 
@@ -442,9 +441,14 @@ class MaskRCNN(FasterRCNN):
                                   return_numpy=False,
                                   use_program_cache=True)
 
-        preds = MaskRCNN._postprocess(result, im_shape,
-                                      list(self.test_outputs.keys()),
-                                      len(img_file_list), self.num_classes,
+        res = {
+            k: (np.array(v), v.recursive_sequence_lengths())
+            for k, v in zip(list(test_outputs_keys), results)
+        }
+        res['im_id'] = (np.array(
+            [[i] for i in range(batch_size)]).astype('int32'), [])
+        res['im_shape'] = (np.array(im_shape), [])
+        preds = MaskRCNN._postprocess(res,
+                                      len(images), self.num_classes,
                                       self.mask_head_resolution, self.labels)
-
         return preds

+ 44 - 11
paddlex/deploy.py

@@ -155,23 +155,42 @@ class Predictor:
             res['im_info'] = im_info
         return res
 
-    def postprocess(self, results, topk=1, batch_size=1, im_shape=None):
+    def postprocess(self,
+                    results,
+                    topk=1,
+                    batch_size=1,
+                    im_shape=None,
+                    im_info=None):
+        def offset_to_lengths(lod):
+            offset = lod[0]
+            lengths = [
+                offset[i + 1] - offset[i] for i in range(len(offset) - 1)
+            ]
+            return [lengths]
+
         if self.model_type == "classifier":
             true_topk = min(self.num_classes, topk)
-            preds = BaseClassifier._postprocess(results, true_topk,
+            preds = BaseClassifier._postprocess([results[0][0]], true_topk,
                                                 self.labels)
         elif self.model_type == "detector":
+            res = {'bbox': (results[0][0], offset_to_lengths(results[0][1])), }
+            res['im_id'] = (np.array(
+                [[i] for i in range(batch_size)]).astype('int32'), [[]])
             if self.model_name == "YOLOv3":
-                preds = YOLOv3._postprocess(results, ['bbox'], batch_size,
-                                            self.num_classes, self.labels)
+                preds = YOLOv3._postprocess(res, batch_size, self.num_classes,
+                                            self.labels)
             elif self.model_name == "FasterRCNN":
-                preds = FasterRCNN._postprocess(results, ['bbox'], batch_size,
+                preds = FasterRCNN._postprocess(res, batch_size,
                                                 self.num_classes, self.labels)
             elif self.model_name == "MaskRCNN":
+                res['mask'] = (results[1][0], offset_to_lengths(results[1][1]))
+                res['im_shape'] = (im_shape, [])
                 preds = MaskRCNN._postprocess(
-                    results, ['bbox', 'mask'], batch_size, self.num_classes,
+                    res, batch_size, self.num_classes,
                     self.mask_head_resolution, self.labels)
-
+        elif self.model_type == "segmenter":
+            res = [results[0][0], results[1][0]]
+            preds = DeepLabv3p._postprocess(res, im_info)
         return preds
 
     def raw_predict(self, inputs):
@@ -191,7 +210,9 @@ class Predictor:
         output_results = list()
         for name in output_names:
             output_tensor = self.predictor.get_output_tensor(name)
-            output_results.append(output_tensor.copy_to_cpu())
+            output_tensor_lod = output_tensor.lod()
+            output_results.append(
+                [output_tensor.copy_to_cpu(), output_tensor_lod])
         return output_results
 
     def predict(self, image, topk=1):
@@ -207,8 +228,14 @@ class Predictor:
         model_pred = self.raw_predict(preprocessed_input)
         im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[
             'im_shape']
+        im_info = None if 'im_info' not in preprocessed_input else preprocessed_input[
+            'im_info']
         results = self.postprocess(
-            model_pred, topk=topk, batch_size=1, im_shape=im_shape)
+            model_pred,
+            topk=topk,
+            batch_size=1,
+            im_shape=im_shape,
+            im_info=im_info)
 
         return results[0]
 
@@ -223,9 +250,15 @@ class Predictor:
         """
         preprocessed_input = self.preprocess(image_list)
         model_pred = self.raw_predict(preprocessed_input)
-        im_shape = None if 'im_shape' in preprocessed_input else preprocessed_input[
+        im_shape = None if 'im_shape' not in preprocessed_input else preprocessed_input[
             'im_shape']
+        im_info = None if 'im_info' not in preprocessed_input else preprocessed_input[
+            'im_info']
         results = self.postprocess(
-            model_pred, topk=topk, batch_size=1, im_shape=im_shape)
+            model_pred,
+            topk=topk,
+            batch_size=len(image_list),
+            im_shape=im_shape,
+            im_info=im_info)
 
         return results