FlyingQianMM пре 4 година
родитељ
комит
428c82a687

+ 3 - 3
dygraph/examples/meter_reader/README.md

@@ -102,7 +102,7 @@ meter_test/
 
 PaddleX提供了丰富的视觉模型,在目标检测中提供了RCNN和YOLO系列模型,在语义分割中提供了DeepLabV3P和BiSeNetV2等模型。
 
-因最终部署场景是本地化的服务器GPU端,算力相对充足,因此在本项目中采用精度和预测性能的PPYOLOV2进行表计检测。
+因最终部署场景是本地化的服务器GPU端,算力相对充足,因此在本项目中采用精度和预测性能皆优的PPYOLOV2进行表计检测。
 
 考虑到指针和刻度均为细小区域,我们采用精度更优的DeepLabV3P进行指针和刻度的分割。
 
@@ -221,7 +221,7 @@ eval_transforms = T.Compose([
 * 定义数据集路径
 
 ```python
-# 下载和解压指针刻度分割数据集,如果已经预先下载,可注掉下面两行
+# 下载和解压指针刻度分割数据集,如果已经预先下载,可注掉下面两行
 meter_seg_dataset = 'https://bj.bcebos.com/paddlex/examples/meter_reader/datasets/meter_seg.tar.gz'
 pdx.utils.download_and_decompress(meter_seg_dataset, path='./')
 
@@ -295,7 +295,7 @@ def predict(self,
             erode_kernel=4,
             score_threshold=0.5,
             seg_batch_size=2):
-    """检测图像中的表盘,而后分割出各表盘中的指针和刻度,对分割结果进行读数后后得到各表盘的读数。
+    """检测图像中的表盘,而后分割出各表盘中的指针和刻度,对分割结果进行读数后处理后得到各表盘的读数。
 
 
         参数:

+ 67 - 0
dygraph/examples/meter_reader/deploy/cpp/meter_reader/include/meter_config.h

@@ -0,0 +1,67 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <vector>
+#include <string>
+#include <map>
+
+struct MeterConfig {
+  float scale_interval_value_;
+  float range_;
+  std::string unit_;
+
+  MeterConfig() {}
+
+  MeterConfig(const float &scale_interval_value,
+              const float &range,
+              const std::string &unit) :
+    scale_interval_value_(scale_interval_value),
+    range_(range), unit_(unit) {}
+};
+
+struct MeterResult {
+  // the number of scales
+  int num_scales_;
+  // the pointer location relative to the scales
+  float pointed_scale_;
+
+  MeterResult() {}
+
+  MeterResult(const int &num_scales, const float &pointed_scale) :
+    num_scales_(num_scales), pointed_scale_(pointed_scale) {}
+};
+
+// The size of inputting images of the segmenter.
+extern std::vector<int> METER_SHAPE;  // height x width
+// Center of a circular meter
+extern std::vector<int> CIRCLE_CENTER;  // height x width
+// Radius of a circular meter
+extern int CIRCLE_RADIUS;
+extern float PI;
+
+// During the postprocess phase, annulus formed by the radius from
+// 130 to 250 of a circular meter will be converted to a rectangle.
+// So the height of the rectangle is 120.
+extern int RECTANGLE_HEIGHT;
+// The width of the rectangle is 1570, that is to say the perimeter
+// of a circular meter.
+extern int RECTANGLE_WIDTH;
+
+// The configuration information of a meter,
+// composed of scale value, range, unit
+extern int TYPE_THRESHOLD;
+extern std::vector<MeterConfig> METER_CONFIG;
+extern std::map<std::string, uint8_t> SEG_CNAME2CLSID;

+ 61 - 0
dygraph/examples/meter_reader/deploy/cpp/meter_reader/include/reader_postprocess.h

@@ -0,0 +1,61 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#pragma once
+
+#include <vector>
+
+#include "pipeline/include/pipeline.h"
+#include "meter_reader/include/meter_config.h"
+
+bool Erode(const int32_t &kernel_size,
+           const std::vector<PaddleDeploy::Result> &seg_results,
+           std::vector<std::vector<uint8_t>> *seg_label_maps);
+
+bool CircleToRectangle(
+  const std::vector<uint8_t> &seg_label_map,
+  std::vector<uint8_t> *rectangle_meter);
+
+bool RectangleToLine(const std::vector<uint8_t> &rectangle_meter,
+                     std::vector<int> *line_scale,
+                     std::vector<int> *line_pointer);
+
+bool MeanBinarization(const std::vector<int> &data,
+                      std::vector<int> *binaried_data);
+
+bool LocateScale(const std::vector<int> &scale,
+                 std::vector<float> *scale_location);
+
+bool LocatePointer(const std::vector<int> &pointer,
+                   float *pointer_location);
+
+bool GetRelativeLocation(
+  const std::vector<float> &scale_location,
+  const float &pointer_location,
+  MeterResult *result);
+
+bool CalculateReading(const MeterResult &result,
+                      float *reading);
+
+bool PrintMeterReading(const std::vector<float> &readings);
+
+bool Visualize(const cv::Mat& img,
+               const PaddleDeploy::Result &det_result,
+               const std::vector<float> &reading,
+               cv::Mat* vis_img);
+
+bool GetMeterReading(
+  const std::vector<std::vector<uint8_t>> &seg_label_maps,
+  std::vector<float> *readings);

+ 78 - 0
dygraph/examples/meter_reader/deploy/cpp/meter_reader/meter_reader.cpp

@@ -0,0 +1,78 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <gflags/gflags.h>
+#include <string>
+#include <iostream>
+#include <vector>
+#include <utility>
+#include <limits>
+
+#include <opencv2/opencv.hpp>
+#include <opencv2/highgui.hpp>
+#include <opencv2/core/core.hpp>
+
+#include "pipeline/include/pipeline.h"
+#include "meter_reader/include/reader_postprocess.h"
+
+DEFINE_string(pipeline_cfg, "", "Path of pipeline config file");
+DEFINE_bool(use_erode, true, "Eroding predicted label map");
+DEFINE_int32(erode_kernel, 4, "Eroding kernel size");
+DEFINE_string(image, "", "Path of test image file");
+DEFINE_string(save_dir, "", "Path to save visualized results");
+
+int main(int argc, char **argv) {
+  google::ParseCommandLineFlags(&argc, &argv, true);
+  if (FLAGS_pipeline_cfg == "") {
+    std::cerr << "--pipeline_cfg need to be defined" << std::endl;
+    return -1;
+  }
+  if (FLAGS_image == "") {
+    std::cerr << "--image need to be defined "
+              << "when the camera is not been used" << std::endl;
+    return -1;
+  }
+
+  std::vector<std::string> image_paths = {FLAGS_image};
+  PaddleXPipeline::Pipeline pipeline;
+  if (pipeline.Init(FLAGS_pipeline_cfg)) {
+    pipeline.SetInput("src0", image_paths);
+    pipeline.Run();
+    std::vector<PaddleDeploy::Result> det_results;
+    std::vector<PaddleDeploy::Result> seg_results;
+    pipeline.GetOutput("sink0", &det_results);
+    pipeline.GetOutput("sink1", &seg_results);
+
+    // Do image erosion for the predicted label map of each meter
+    std::vector<std::vector<uint8_t>> seg_label_maps;
+    Erode(FLAGS_erode_kernel, seg_results, &seg_label_maps);
+
+    // The postprocess are done to get the reading or each meter
+    std::vector<float> readings;
+    GetMeterReading(seg_label_maps, &readings);
+    PrintMeterReading(readings);
+    if (FLAGS_save_dir != "") {
+      cv::Mat img = cv::imread(FLAGS_image);
+      cv::Mat vis_img;
+      Visualize(img, det_results[0], readings, &vis_img);
+      std::string save_path;
+      if (PaddleXPipeline::GenerateSavePath(
+          FLAGS_save_dir, FLAGS_image, &save_path)) {
+         cv::imwrite(save_path, vis_img);
+      }
+    }
+  }
+
+  return 0;
+}

+ 33 - 0
dygraph/examples/meter_reader/deploy/cpp/meter_reader/src/meter_config.cpp

@@ -0,0 +1,33 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include "meter_reader/include/meter_config.h"
+
+std::vector<int> METER_SHAPE = {512, 512};  // height x width
+std::vector<int> CIRCLE_CENTER = {256, 256};
+int CIRCLE_RADIUS = 250;
+float PI = 3.1415926536;
+int RECTANGLE_HEIGHT = 120;
+int RECTANGLE_WIDTH = 1570;
+
+int TYPE_THRESHOLD = 40;
+std::vector<MeterConfig> METER_CONFIG = {
+  MeterConfig(25.0f/50.0f, 25.0f, "(MPa)"),
+  MeterConfig(1.6f/32.0f,  1.6f,   "(MPa)")
+};
+
+std::map<std::string, uint8_t> SEG_CNAME2CLSID = {
+  {"background", 0}, {"pointer", 1}, {"scale", 2}
+};

+ 264 - 0
dygraph/examples/meter_reader/deploy/cpp/meter_reader/src/reader_postprocess.cpp

@@ -0,0 +1,264 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+
+#include <iostream>
+#include <vector>
+#include <utility>
+#include <limits>
+#include <cmath>
+#include <string>
+
+#include <opencv2/opencv.hpp>
+#include <opencv2/highgui.hpp>
+#include <opencv2/core/core.hpp>
+
+#include "meter_reader/include/reader_postprocess.h"
+
+bool Erode(const int32_t &kernel_size,
+           const std::vector<PaddleDeploy::Result> &seg_results,
+           std::vector<std::vector<uint8_t>> *seg_label_maps) {
+  cv::Mat kernel(kernel_size, kernel_size, CV_8U, cv::Scalar(1));
+  for (auto result : seg_results) {
+    std::vector<uint8_t> label_map(result.seg_result->label_map.data.begin(),
+                                   result.seg_result->label_map.data.end());
+    cv::Mat mask(result.seg_result->label_map.shape[0],
+                 result.seg_result->label_map.shape[1],
+                 CV_8UC1,
+                 label_map.data());
+    cv::erode(mask, mask, kernel);
+    std::vector<uint8_t> map;
+    if (mask.isContinuous()) {
+        map.assign(mask.data, mask.data + mask.total() * mask.channels());
+    } else {
+      for (int r = 0; r < mask.rows; r++) {
+        map.insert(map.end(),
+                   mask.ptr<int64_t>(r),
+                   mask.ptr<int64_t>(r) + mask.cols * mask.channels());
+      }
+    }
+    seg_label_maps->push_back(map);
+  }
+  return true;
+}
+
+
+bool CircleToRectangle(
+  const std::vector<uint8_t> &seg_label_map,
+  std::vector<uint8_t> *rectangle_meter) {
+  float theta;
+  int rho;
+  int image_x;
+  int image_y;
+
+  // The minimum scale value is at the bottom left, the maximum scale value
+  // is at the bottom right, so the vertical down axis is the starting axis and
+  // rotates around the meter ceneter counterclockwise.
+  *rectangle_meter =
+    std::vector<uint8_t> (RECTANGLE_WIDTH * RECTANGLE_HEIGHT, 0);
+  for (int row = 0; row < RECTANGLE_HEIGHT; row++) {
+    for (int col = 0; col < RECTANGLE_WIDTH; col++) {
+      theta = PI * 2 / RECTANGLE_WIDTH * (col + 1);
+      rho = CIRCLE_RADIUS - row - 1;
+      int y = static_cast<int>(CIRCLE_CENTER[0] + rho * cos(theta) + 0.5);
+      int x = static_cast<int>(CIRCLE_CENTER[1] - rho * sin(theta) + 0.5);
+      (*rectangle_meter)[row * RECTANGLE_WIDTH + col] =
+        seg_label_map[y * METER_SHAPE[1] + x];
+    }
+  }
+
+  return true;
+}
+
+bool RectangleToLine(const std::vector<uint8_t> &rectangle_meter,
+                     std::vector<int> *line_scale,
+                     std::vector<int> *line_pointer) {
+  // Accumulte the number of positions whose label is 1 along the height axis.
+  // Accumulte the number of positions whose label is 2 along the height axis.
+  (*line_scale) = std::vector<int> (RECTANGLE_WIDTH, 0);
+  (*line_pointer) = std::vector<int> (RECTANGLE_WIDTH, 0);
+  for (int col = 0; col < RECTANGLE_WIDTH; col++) {
+    for (int row = 0; row < RECTANGLE_HEIGHT; row++) {
+        if (rectangle_meter[row * RECTANGLE_WIDTH + col] ==
+          static_cast<uint8_t>(SEG_CNAME2CLSID["pointer"])) {
+            (*line_pointer)[col]++;
+        } else if (rectangle_meter[row * RECTANGLE_WIDTH + col] ==
+          static_cast<uint8_t>(SEG_CNAME2CLSID["scale"])) {
+            (*line_scale)[col]++;
+        }
+    }
+  }
+  return true;
+}
+
+bool MeanBinarization(const std::vector<int> &data,
+                      std::vector<int> *binaried_data) {
+  int sum = 0;
+  float mean = 0;
+  for (auto i = 0; i < data.size(); i++) {
+    sum = sum + data[i];
+  }
+  mean = static_cast<float>(sum) / static_cast<float>(data.size());
+
+  for (auto i = 0; i < data.size(); i++) {
+    if (static_cast<float>(data[i]) >= mean) {
+      binaried_data->push_back(1);
+    } else {
+      binaried_data->push_back(0);
+    }
+  }
+  return  true;
+}
+
+bool LocateScale(const std::vector<int> &scale,
+                 std::vector<float> *scale_location) {
+  float one_scale_location = 0;
+  bool find_start = false;
+  int one_scale_start = 0;
+  int one_scale_end = 0;
+
+  for (int i = 0; i < RECTANGLE_WIDTH; i++) {
+    // scale location
+    if (scale[i] > 0 && scale[i + 1] > 0) {
+      if (!find_start) {
+        one_scale_start = i;
+        find_start = true;
+      }
+    }
+    if (find_start) {
+      if (scale[i] == 0 && scale[i + 1] == 0) {
+          one_scale_end = i - 1;
+          one_scale_location = (one_scale_start + one_scale_end) / 2.;
+          scale_location->push_back(one_scale_location);
+          one_scale_start = 0;
+          one_scale_end = 0;
+          find_start = false;
+      }
+    }
+  }
+  return true;
+}
+
+bool LocatePointer(const std::vector<int> &pointer,
+                   float *pointer_location) {
+  bool find_start = false;
+  int one_pointer_start = 0;
+  int one_pointer_end = 0;
+
+  for (int i = 0; i < RECTANGLE_WIDTH; i++) {
+    // pointer location
+    if (pointer[i] > 0 && pointer[i + 1] > 0) {
+      if (!find_start) {
+        one_pointer_start = i;
+        find_start = true;
+      }
+    }
+    if (find_start) {
+      if ((pointer[i] == 0) && (pointer[i+1] == 0)) {
+        one_pointer_end = i - 1;
+        *pointer_location = (one_pointer_start + one_pointer_end) / 2.;
+        one_pointer_start = 0;
+        one_pointer_end = 0;
+        find_start = false;
+        break;
+      }
+    }
+  }
+  return true;
+}
+
+bool GetRelativeLocation(
+  const std::vector<float> &scale_location,
+  const float &pointer_location,
+  MeterResult *result) {
+  int num_scales = static_cast<int>(scale_location.size());
+  result->num_scales_ = num_scales;
+  result->pointed_scale_ = -1;
+  if (num_scales > 0) {
+    for (auto i = 0; i < num_scales - 1; i++) {
+      if (scale_location[i] <= pointer_location &&
+            pointer_location < scale_location[i + 1]) {
+        result->pointed_scale_ = i + 1 +
+          (pointer_location-scale_location[i]) /
+          (scale_location[i+1]-scale_location[i] + 1e-05);
+      }
+    }
+  }
+  return true;
+}
+
+bool CalculateReading(const MeterResult &result,
+                      float *reading) {
+  // Provide a digital readout according to point location relative
+  // to the scales
+  if (result.num_scales_ > TYPE_THRESHOLD) {
+    *reading = result.pointed_scale_ * METER_CONFIG[0].scale_interval_value_;
+  } else {
+    *reading = result.pointed_scale_ * METER_CONFIG[1].scale_interval_value_;
+  }
+  return true;
+}
+
+bool PrintMeterReading(const std::vector<float> &readings) {
+  for (auto i = 0; i < readings.size(); ++i) {
+    std::cout << "Meter " << i + 1 << ": " << readings[i] << std::endl;
+  }
+  return true;
+}
+
+bool Visualize(const cv::Mat& img,
+               const PaddleDeploy::Result &det_result,
+               const std::vector<float> &reading,
+               cv::Mat* vis_img) {
+  for (auto i = 0; i < det_result.det_result->boxes.size(); ++i) {
+     std::string category = std::to_string(reading[i]);
+     det_result.det_result->boxes[i].category = category;
+  }
+
+  PaddleDeploy::Visualize(img, *(det_result.det_result), vis_img);
+  return true;
+}
+
+bool GetMeterReading(
+  const std::vector<std::vector<uint8_t>> &seg_label_maps,
+  std::vector<float> *readings) {
+  for (auto i = 0; i < seg_label_maps.size(); i++) {
+    std::vector<uint8_t> rectangle_meter;
+    CircleToRectangle(seg_label_maps[i], &rectangle_meter);
+
+    std::vector<int> line_scale;
+    std::vector<int> line_pointer;
+    RectangleToLine(rectangle_meter, &line_scale, &line_pointer);
+
+    std::vector<int> binaried_scale;
+    MeanBinarization(line_scale, &binaried_scale);
+    std::vector<int> binaried_pointer;
+    MeanBinarization(line_pointer, &binaried_pointer);
+
+    std::vector<float> scale_location;
+    LocateScale(binaried_scale, &scale_location);
+
+    float pointer_location;
+    LocatePointer(binaried_pointer, &pointer_location);
+
+    MeterResult result;
+    GetRelativeLocation(
+      scale_location, pointer_location, &result);
+
+    float reading;
+    CalculateReading(result, &reading);
+    readings->push_back(reading);
+  }
+  return true;
+}

+ 1 - 15
dygraph/examples/meter_reader/reader_infer.py

@@ -528,21 +528,7 @@ class MeterReader:
                 erode_kernel=4,
                 score_threshold=0.5,
                 seg_batch_size=2):
-        """Detect meters in a image, segment scales and points in these meters, the postprocess are
-        done to provide a digital readout according to scale and point location.
-
-        Args:
-            im_file (str):  the path of a image to be predicted.
-            save_dir (str): the directory to save the visual prediction. Default: './'.
-            use_erode (bool, optional): whether to do image erosion by using a specific structuring element for
-                the label map output from the segmenter. Default: True.
-            erode_kernel (int, optional): structuring element used for erosion. Default: 4.
-            score_threshold (float, optional): detected meters whose scores are not lower than `score_threshold`
-                will be fed into the following segmenter. Default: 0.5.
-            seg_batch_size (int, optional): batch size of meters when do segmentation. Default: 2.
-
-        """
-        """检测图像中的表盘,而后分割出各表盘中的指针和刻度,对分割结果进行读数后厨后得到各表盘的读数。
+        """检测图像中的表盘,而后分割出各表盘中的指针和刻度,对分割结果进行读数后处理后得到各表盘的读数。
 
 
         参数:

+ 2 - 2
dygraph/examples/meter_reader/train_segmentation.py

@@ -16,7 +16,7 @@ eval_transforms = T.Compose([
         mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
 ])
 
-# 下载和解压指针刻度分割数据集,如果已经预先下载,可注掉下面两行
+# 下载和解压指针刻度分割数据集,如果已经预先下载,可注掉下面两行
 meter_seg_dataset = 'https://bj.bcebos.com/paddlex/examples/meter_reader/datasets/meter_seg.tar.gz'
 pdx.utils.download_and_decompress(meter_seg_dataset, path='./')
 
@@ -48,7 +48,7 @@ model.train(
     num_epochs=20,
     train_dataset=train_dataset,
     train_batch_size=4,
-    pretrain_weights='IMAGENET',
+    #pretrain_weights='IMAGENET',
     eval_dataset=eval_dataset,
     learning_rate=0.1,
     save_dir='output/deeplabv3p_r50vd')