Преглед изворни кода

optimize visualize&preprocess speed

Channingss пре 5 година
родитељ
комит
d4986632dd

+ 21 - 11
deploy/lite/android/sdk/src/main/java/com/baidu/paddlex/preprocess/Transforms.java

@@ -23,6 +23,7 @@ import org.opencv.core.Scalar;
 import org.opencv.core.Size;
 import org.opencv.imgproc.Imgproc;
 import java.util.ArrayList;
+import java.util.Date;
 import java.util.HashMap;
 import java.util.List;
 
@@ -101,6 +102,15 @@ public class Transforms {
                 if (info.containsKey("coarsest_stride")) {
                     padding.coarsest_stride = (int) info.get("coarsest_stride");
                 }
+                if (info.containsKey("im_padding_value")) {
+                    List<Double> im_padding_value = (List<Double>) info.get("im_padding_value");
+                    if (im_padding_value.size()!=3){
+                        Log.e(TAG, "len of im_padding_value in padding must == 3.");
+                    }
+                    for (int k =0; i<im_padding_value.size(); i++){
+                        padding.paddding_value[k] = im_padding_value.get(k);
+                    }
+                }
                 if (info.containsKey("target_size")) {
                     if (info.get("target_size") instanceof Integer) {
                         padding.width = (int) info.get("target_size");
@@ -124,7 +134,7 @@ public class Transforms {
         if(transformsMode.equalsIgnoreCase("RGB")){
             Imgproc.cvtColor(inputMat, inputMat, Imgproc.COLOR_BGR2RGB);
         }else if(!transformsMode.equalsIgnoreCase("BGR")){
-            Log.e(TAG, "transformsMode only support RGB or BGR");
+            Log.e(TAG, "transformsMode only support RGB or BGR.");
         }
         inputMat.convertTo(inputMat, CvType.CV_32FC(3));
 
@@ -136,16 +146,15 @@ public class Transforms {
         int h = inputMat.height();
         int c = inputMat.channels();
         imageBlob.setImageData(new float[w * h * c]);
-        int[] channelStride = new int[]{w * h, w * h * 2};
-        for (int y = 0; y < h; y++) {
-            for (int x = 0;
-                 x < w; x++) {
-                double[] color = inputMat.get(y, x);
-                imageBlob.getImageData()[y * w + x]  =  (float) (color[0]);
-                imageBlob.getImageData()[y * w + x +  channelStride[0]] = (float) (color[1]);
-                imageBlob.getImageData()[y * w + x +  channelStride[1]] = (float) (color[2]);
-            }
+
+        Mat singleChannelMat = new Mat(h, w, CvType.CV_32FC(1));
+        float[] singleChannelImageData = new float[w * h];
+        for (int i = 0; i < c; i++) {
+            Core.extractChannel(inputMat, singleChannelMat, i);
+            singleChannelMat.get(0, 0, singleChannelImageData);
+            System.arraycopy(singleChannelImageData ,0, imageBlob.getImageData(),i*w*h, w*h);
         }
+
         return imageBlob;
     }
 
@@ -248,6 +257,7 @@ public class Transforms {
         private double width;
         private double height;
         private double coarsest_stride;
+        private double[] paddding_value = {0.0, 0.0, 0.0};
 
         public Mat run(Mat inputMat, ImageBlob imageBlob) {
             int origin_w = inputMat.width();
@@ -264,7 +274,7 @@ public class Transforms {
             }
             imageBlob.setNewImageSize(inputMat.height(),2);
             imageBlob.setNewImageSize(inputMat.width(),3);
-            Core.copyMakeBorder(inputMat, inputMat, 0, (int)padding_h, 0, (int)padding_w, Core.BORDER_CONSTANT, new Scalar(0));
+            Core.copyMakeBorder(inputMat, inputMat, 0, (int)padding_h, 0, (int)padding_w, Core.BORDER_CONSTANT, new Scalar(paddding_value));
             return inputMat;
         }
     }

+ 6 - 5
deploy/lite/android/sdk/src/main/java/com/baidu/paddlex/visual/Visualize.java

@@ -31,8 +31,11 @@ import org.opencv.core.Scalar;
 import org.opencv.core.Size;
 import org.opencv.imgproc.Imgproc;
 
+import java.nio.ByteBuffer;
+import java.nio.FloatBuffer;
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.Date;
 import java.util.List;
 import java.util.ListIterator;
 import java.util.Map;
@@ -120,13 +123,11 @@ public class Visualize {
         int new_w = (int)imageBlob.getNewImageSize()[3];
         Mat mask = new Mat(new_h, new_w, CvType.CV_32FC(1));
         float[] scoreData = new float[new_h*new_w];
-        for  (int h = 0; h < new_h; h++) {
-            for  (int w = 0; w < new_w; w++){
-                scoreData[new_h * h + w] =  (1-result.getMask().getScoreData()[cutoutClass + h * new_h + w]) * 255;
-            }
-        }
+        System.arraycopy(result.getMask().getScoreData() ,cutoutClass*new_h*new_w, scoreData ,0, new_h*new_w);
         mask.put(0,0, scoreData);
+        Core.multiply(mask, new Scalar(255), mask);
         mask.convertTo(mask,CvType.CV_8UC(1));
+
         ListIterator<Map.Entry<String, int[]>> reverseReshapeInfo = new ArrayList<Map.Entry<String, int[]>>(imageBlob.getReshapeInfo().entrySet()).listIterator(imageBlob.getReshapeInfo().size());
         while (reverseReshapeInfo.hasPrevious()) {
             Map.Entry<String, int[]> entry = reverseReshapeInfo.previous();