Pārlūkot izejas kodu

revise as reviews

FlyingQianMM 5 gadi atpakaļ
vecāks
revīzija
bf4091bbcb
33 mainītis faili ar 698 papildinājumiem un 327 dzēšanām
  1. 6 2
      deploy/cpp/demo/classifier.cpp
  2. 6 2
      deploy/cpp/demo/detector.cpp
  3. 6 2
      deploy/cpp/demo/segmenter.cpp
  4. 3 1
      deploy/cpp/demo/video_classifier.cpp
  5. 3 1
      deploy/cpp/demo/video_detector.cpp
  6. 3 1
      deploy/cpp/demo/video_segmenter.cpp
  7. 6 0
      deploy/cpp/include/paddlex/transforms.h
  8. 4 1
      deploy/cpp/src/paddlex.cpp
  9. 46 34
      deploy/cpp/src/transforms.cpp
  10. 6 2
      deploy/openvino/demo/classifier.cpp
  11. 6 2
      deploy/openvino/demo/detector.cpp
  12. 6 2
      deploy/openvino/demo/segmenter.cpp
  13. 6 0
      deploy/openvino/include/paddlex/transforms.h
  14. 45 34
      deploy/openvino/src/transforms.cpp
  15. 77 0
      docs/apis/tools.md
  16. 1 2
      docs/apis/visualize.md
  17. 3 3
      docs/examples/industrial_quality_inspection/README.md
  18. 1 1
      docs/examples/industrial_quality_inspection/accuracy_improvement.md
  19. 10 10
      docs/examples/industrial_quality_inspection/dataset.md
  20. 1 0
      docs/index.rst
  21. 3 3
      examples/industrial_quality_inspection/README.md
  22. 1 1
      examples/industrial_quality_inspection/accuracy_improvement.md
  23. 183 95
      examples/industrial_quality_inspection/cal_tp_fp.py
  24. 128 100
      examples/industrial_quality_inspection/compare.py
  25. 10 10
      examples/industrial_quality_inspection/dataset.md
  26. 9 4
      examples/industrial_quality_inspection/params_analysis.py
  27. 1 2
      examples/industrial_quality_inspection/predict.py
  28. 58 0
      examples/industrial_quality_inspection/train_pruned_yolov3.py
  29. 4 4
      examples/industrial_quality_inspection/train_rcnn.py
  30. 4 3
      examples/industrial_quality_inspection/train_yolov3.py
  31. 0 2
      paddlex/cv/datasets/voc.py
  32. 1 1
      paddlex/cv/models/ppyolo.py
  33. 51 2
      paddlex/cv/models/utils/detection_eval.py

+ 6 - 2
deploy/cpp/demo/classifier.cpp

@@ -92,7 +92,9 @@ int main(int argc, char** argv) {
       for (int j = i; j < im_vec_size; ++j) {
         im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
       }
-      model.predict(im_vec, &results, thread_num);
+      if (!model.predict(im_vec, &results, thread_num)) {
+        return -1;
+      }
       for (int j = i; j < im_vec_size; ++j) {
         std::cout << "Path:" << image_paths[j]
                   << ", predict label: " << results[j - i].category
@@ -103,7 +105,9 @@ int main(int argc, char** argv) {
   } else {
     PaddleX::ClsResult result;
     cv::Mat im = cv::imread(FLAGS_image, 1);
-    model.predict(im, &result);
+    if (!model.predict(im, &result)) {
+      return -1;
+    }
     std::cout << "Predict label: " << result.category
               << ", label_id:" << result.category_id
               << ", score: " << result.score << std::endl;

+ 6 - 2
deploy/cpp/demo/detector.cpp

@@ -95,7 +95,9 @@ int main(int argc, char** argv) {
       for (int j = i; j < im_vec_size; ++j) {
         im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
       }
-      model.predict(im_vec, &results, thread_num);
+      if (!model.predict(im_vec, &results, thread_num)) {
+        return -1;
+      }
       // Output predicted bounding boxes
       for (int j = 0; j < im_vec_size - i; ++j) {
         for (int k = 0; k < results[j].boxes.size(); ++k) {
@@ -123,7 +125,9 @@ int main(int argc, char** argv) {
   } else {
     PaddleX::DetResult result;
     cv::Mat im = cv::imread(FLAGS_image, 1);
-    model.predict(im, &result);
+    if (!model.predict(im, &result)) {
+      return -1;
+    }
     // Output predicted bounding boxes
     for (int i = 0; i < result.boxes.size(); ++i) {
       std::cout << "image file: " << FLAGS_image << std::endl;

+ 6 - 2
deploy/cpp/demo/segmenter.cpp

@@ -91,7 +91,9 @@ int main(int argc, char** argv) {
       for (int j = i; j < im_vec_size; ++j) {
         im_vec[j - i] = std::move(cv::imread(image_paths[j], 1));
       }
-      model.predict(im_vec, &results, thread_num);
+      if (!model.predict(im_vec, &results, thread_num)) {
+        return -1;
+      }
       // Visualize results
       for (int j = 0; j < im_vec_size - i; ++j) {
         cv::Mat vis_img =
@@ -105,7 +107,9 @@ int main(int argc, char** argv) {
   } else {
     PaddleX::SegResult result;
     cv::Mat im = cv::imread(FLAGS_image, 1);
-    model.predict(im, &result);
+    if (!model.predict(im, &result)) {
+      return -1;
+    }
     // Visualize results
     cv::Mat vis_img = PaddleX::Visualize(im, result, model.labels);
     std::string save_path =

+ 3 - 1
deploy/cpp/demo/video_classifier.cpp

@@ -140,7 +140,9 @@ int main(int argc, char** argv) {
       break;
     }
     // Begin to predict
-    model.predict(frame, &result);
+    if (!model.predict(frame, &result)) {
+      return -1;
+    }
     // Visualize results
     cv::Mat vis_img = frame.clone();
     auto colormap = PaddleX::GenerateColorMap(model.labels.size());

+ 3 - 1
deploy/cpp/demo/video_detector.cpp

@@ -141,7 +141,9 @@ int main(int argc, char** argv) {
       break;
     }
     // Begin to predict
-    model.predict(frame, &result);
+    if (!model.predict(frame, &result)) {
+      return -1;
+    }
     // Visualize results
     cv::Mat vis_img =
         PaddleX::Visualize(frame, result, model.labels, FLAGS_threshold);

+ 3 - 1
deploy/cpp/demo/video_segmenter.cpp

@@ -140,7 +140,9 @@ int main(int argc, char** argv) {
       break;
     }
     // Begin to predict
-    model.predict(frame, &result);
+    if (!model.predict(frame, &result)) {
+      return -1;
+    }
     // Visualize results
     cv::Mat vis_img = PaddleX::Visualize(frame, result, model.labels);
     if (FLAGS_show_result || FLAGS_use_camera) {

+ 6 - 0
deploy/cpp/include/paddlex/transforms.h

@@ -234,6 +234,12 @@ class Padding : public Transform {
     }
   }
   virtual bool Run(cv::Mat* im, ImageBlob* data);
+  virtual void GeneralPadding(cv::Mat* im,
+                              const std::vector<float> &padding_val,
+                              int padding_w, int padding_h);
+  virtual void MultichannelPadding(cv::Mat* im,
+                                   const std::vector<float> &padding_val,
+                                   int padding_w, int padding_h);
 
  private:
   int coarsest_stride_ = -1;

+ 4 - 1
deploy/cpp/src/paddlex.cpp

@@ -171,7 +171,10 @@ bool Model::predict(const cv::Mat& im, ClsResult* result) {
   inputs_.clear();
   if (type == "detector") {
     std::cerr << "Loading model is a 'detector', DetResult should be passed to "
-                 "function predict()!"
+                 "function predict()!" << std::endl;
+    return false;
+  } else if (type == "segmenter") {
+    std::cerr << "Loading model is a 'segmenter', SegResult should be passed "
                  "to function predict()!" << std::endl;
     return false;
   }

+ 46 - 34
deploy/cpp/src/transforms.cpp

@@ -94,6 +94,50 @@ bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) {
   return true;
 }
 
+void Padding::GeneralPadding(cv::Mat* im,
+                             const std::vector<float> &padding_val,
+                             int padding_w, int padding_h) {
+  cv::Scalar value;
+  if (im->channels() == 1) {
+    value = cv::Scalar(padding_val[0]);
+  } else if (im->channels() == 2) {
+    value = cv::Scalar(padding_val[0], padding_val[1]);
+  } else if (im->channels() == 3) {
+    value = cv::Scalar(padding_val[0], padding_val[1], padding_val[2]);
+  } else if (im->channels() == 4) {
+    value = cv::Scalar(padding_val[0], padding_val[1], padding_val[2],
+                                  padding_val[3]);
+  }
+  cv::copyMakeBorder(
+  *im,
+  *im,
+  0,
+  padding_h,
+  0,
+  padding_w,
+  cv::BORDER_CONSTANT,
+  value);
+}
+
+void Padding::MultichannelPadding(cv::Mat* im,
+                                  const std::vector<float> &padding_val,
+                                  int padding_w, int padding_h) {
+  std::vector<cv::Mat> padded_im_per_channel(im->channels());
+  #pragma omp parallel for num_threads(im->channels())
+  for (size_t i = 0; i < im->channels(); i++) {
+    const cv::Mat per_channel = cv::Mat(im->rows + padding_h,
+                                        im->cols + padding_w,
+                                        CV_32FC1,
+                                        cv::Scalar(padding_val[i]));
+    padded_im_per_channel[i] = per_channel;
+  }
+  cv::Mat padded_im;
+  cv::merge(padded_im_per_channel, padded_im);
+  cv::Rect im_roi = cv::Rect(0, 0, im->cols, im->rows);
+  im->copyTo(padded_im(im_roi));
+  *im = padded_im;
+}
+
 bool Padding::Run(cv::Mat* im, ImageBlob* data) {
   data->im_size_before_resize_.push_back({im->rows, im->cols});
   data->reshape_order_.push_back("padding");
@@ -119,41 +163,9 @@ bool Padding::Run(cv::Mat* im, ImageBlob* data) {
     return false;
   }
   if (im->channels() < 5) {
-    cv::Scalar value;
-    if (im->channels() == 1) {
-      value = cv::Scalar(im_value_[0]);
-    } else if (im->channels() == 2) {
-      value = cv::Scalar(im_value_[0], im_value_[1]);
-    } else if (im->channels() == 3) {
-      value = cv::Scalar(im_value_[0], im_value_[1], im_value_[2]);
-    } else if (im->channels() == 4) {
-      value = cv::Scalar(im_value_[0], im_value_[1], im_value_[2],
-                                    im_value_[3]);
-    }
-    cv::copyMakeBorder(
-    *im,
-    *im,
-    0,
-    padding_h,
-    0,
-    padding_w,
-    cv::BORDER_CONSTANT,
-    value);
+    Padding::GeneralPadding(im, im_value_, padding_w, padding_h);
   } else {
-    std::vector<cv::Mat> padded_im_per_channel(im->channels());
-    #pragma omp parallel for num_threads(im->channels())
-    for (size_t i = 0; i < im->channels(); i++) {
-      const cv::Mat per_channel = cv::Mat(im->rows + padding_h,
-                                          im->cols + padding_w,
-                                          CV_32FC1,
-                                          cv::Scalar(im_value_[i]));
-      padded_im_per_channel[i] = per_channel;
-    }
-    cv::Mat padded_im;
-    cv::merge(padded_im_per_channel, padded_im);
-    cv::Rect im_roi = cv::Rect(0, 0, im->cols, im->rows);
-    im->copyTo(padded_im(im_roi));
-    *im = padded_im;
+    Padding::MultichannelPadding(im, im_value_, padding_w, padding_h);
   }
   data->new_im_size_[0] = im->rows;
   data->new_im_size_[1] = im->cols;

+ 6 - 2
deploy/openvino/demo/classifier.cpp

@@ -59,7 +59,9 @@ int main(int argc, char** argv) {
     while (getline(inf, image_path)) {
       PaddleX::ClsResult result;
       cv::Mat im = cv::imread(image_path, 1);
-      model.predict(im, &result);
+      if (!model.predict(im, &result)) {
+        return -1;
+      }
       std::cout << "Predict label: " << result.category
                 << ", label_id:" << result.category_id
                 << ", score: " << result.score << std::endl;
@@ -67,7 +69,9 @@ int main(int argc, char** argv) {
   } else {
     PaddleX::ClsResult result;
     cv::Mat im = cv::imread(FLAGS_image, 1);
-    model.predict(im, &result);
+    if (!model.predict(im, &result)) {
+      return -1;
+    }
     std::cout << "Predict label: " << result.category
               << ", label_id:" << result.category_id
               << ", score: " << result.score << std::endl;

+ 6 - 2
deploy/openvino/demo/detector.cpp

@@ -70,7 +70,9 @@ int main(int argc, char** argv) {
     while (getline(inf, image_path)) {
       PaddleX::DetResult result;
       cv::Mat im = cv::imread(image_path, 1);
-      model.predict(im, &result);
+      if (!model.predict(im, &result)) {
+        return -1;
+      }
       if (FLAGS_save_dir != "") {
         cv::Mat vis_img = PaddleX::Visualize(
           im, result, model.labels, colormap, FLAGS_threshold);
@@ -83,7 +85,9 @@ int main(int argc, char** argv) {
   } else {
   PaddleX::DetResult result;
   cv::Mat im = cv::imread(FLAGS_image, 1);
-  model.predict(im, &result);
+  if (!model.predict(im, &result)) {
+    return -1;
+  }
   for (int i = 0; i < result.boxes.size(); ++i) {
       std::cout << "image file: " << FLAGS_image << std::endl;
       std::cout << ", predict label: " << result.boxes[i].category

+ 6 - 2
deploy/openvino/demo/segmenter.cpp

@@ -64,7 +64,9 @@ int main(int argc, char** argv) {
     while (getline(inf, image_path)) {
       PaddleX::SegResult result;
       cv::Mat im = cv::imread(image_path, 1);
-      model.predict(im, &result);
+      if (!model.predict(im, &result)) {
+        return -1;
+      }
       if (FLAGS_save_dir != "") {
       cv::Mat vis_img = PaddleX::Visualize(im, result, model.labels, colormap);
         std::string save_path =
@@ -76,7 +78,9 @@ int main(int argc, char** argv) {
   } else {
     PaddleX::SegResult result;
     cv::Mat im = cv::imread(FLAGS_image, 1);
-    model.predict(im, &result);
+    if (!model.predict(im, &result)) {
+      return -1;
+    }
     if (FLAGS_save_dir != "") {
       cv::Mat vis_img = PaddleX::Visualize(im, result, model.labels, colormap);
       std::string save_path =

+ 6 - 0
deploy/openvino/include/paddlex/transforms.h

@@ -213,6 +213,12 @@ class Padding : public Transform {
   }
 
   virtual bool Run(cv::Mat* im, ImageBlob* data);
+  virtual void GeneralPadding(cv::Mat* im,
+                              const std::vector<float> &padding_val,
+                              int padding_w, int padding_h);
+  virtual void MultichannelPadding(cv::Mat* im,
+                                   const std::vector<float> &padding_val,
+                                   int padding_w, int padding_h);
 
  private:
   int coarsest_stride_ = -1;

+ 45 - 34
deploy/openvino/src/transforms.cpp

@@ -97,6 +97,49 @@ bool CenterCrop::Run(cv::Mat* im, ImageBlob* data) {
   return true;
 }
 
+void Padding::GeneralPadding(cv::Mat* im,
+                             const std::vector<float> &padding_val,
+                             int padding_w, int padding_h) {
+  cv::Scalar value;
+  if (im->channels() == 1) {
+    value = cv::Scalar(padding_val[0]);
+  } else if (im->channels() == 2) {
+    value = cv::Scalar(padding_val[0], padding_val[1]);
+  } else if (im->channels() == 3) {
+    value = cv::Scalar(padding_val[0], padding_val[1], padding_val[2]);
+  } else if (im->channels() == 4) {
+    value = cv::Scalar(padding_val[0], padding_val[1], padding_val[2],
+                                  padding_val[3]);
+  }
+  cv::copyMakeBorder(
+  *im,
+  *im,
+  0,
+  padding_h,
+  0,
+  padding_w,
+  cv::BORDER_CONSTANT,
+  value);
+}
+
+void Padding::MultichannelPadding(cv::Mat* im,
+                                  const std::vector<float> &padding_val,
+                                  int padding_w, int padding_h) {
+  std::vector<cv::Mat> padded_im_per_channel(im->channels());
+  #pragma omp parallel for num_threads(im->channels())
+  for (size_t i = 0; i < im->channels(); i++) {
+    const cv::Mat per_channel = cv::Mat(im->rows + padding_h,
+                                        im->cols + padding_w,
+                                        CV_32FC1,
+                                        cv::Scalar(padding_val[i]));
+    padded_im_per_channel[i] = per_channel;
+  }
+  cv::Mat padded_im;
+  cv::merge(padded_im_per_channel, padded_im);
+  cv::Rect im_roi = cv::Rect(0, 0, im->cols, im->rows);
+  im->copyTo(padded_im(im_roi));
+  *im = padded_im;
+}
 
 bool Padding::Run(cv::Mat* im, ImageBlob* data) {
   data->im_size_before_resize_.push_back({im->rows, im->cols});
@@ -123,41 +166,9 @@ bool Padding::Run(cv::Mat* im, ImageBlob* data) {
     return false;
   }
   if (im->channels() < 5) {
-    cv::Scalar value;
-    if (im->channels() == 1) {
-      value = cv::Scalar(im_value_[0]);
-    } else if (im->channels() == 2) {
-      value = cv::Scalar(im_value_[0], im_value_[1]);
-    } else if (im->channels() == 3) {
-      value = cv::Scalar(im_value_[0], im_value_[1], im_value_[2]);
-    } else if (im->channels() == 4) {
-      value = cv::Scalar(im_value_[0], im_value_[1], im_value_[2],
-                                    im_value_[3]);
-    }
-    cv::copyMakeBorder(
-    *im,
-    *im,
-    0,
-    padding_h,
-    0,
-    padding_w,
-    cv::BORDER_CONSTANT,
-    value);
+    Padding::GeneralPadding(im, im_value_, padding_w, padding_h);
   } else {
-    std::vector<cv::Mat> padded_im_per_channel(im->channels());
-    #pragma omp parallel for num_threads(im->channels())
-    for (size_t i = 0; i < im->channels(); i++) {
-      const cv::Mat per_channel = cv::Mat(im->rows + padding_h,
-                                          im->cols + padding_w,
-                                          CV_32FC1,
-                                          cv::Scalar(im_value_[i]));
-      padded_im_per_channel[i] = per_channel;
-    }
-    cv::Mat padded_im;
-    cv::merge(padded_im_per_channel, padded_im);
-    cv::Rect im_roi = cv::Rect(0, 0, im->cols, im->rows);
-    im->copyTo(padded_im(im_roi));
-    *im = padded_im;
+    Padding::MultichannelPadding(im, im_value_, padding_w, padding_h);
   }
   data->new_im_size_[0] = im->rows;
   data->new_im_size_[1] = im->cols;

+ 77 - 0
docs/apis/tools.md

@@ -0,0 +1,77 @@
+# 数据集工具
+
+## 数据集分析
+
+### paddlex.datasets.analysis.Seg
+```python
+paddlex.datasets.analysis.Seg(data_dir, file_list, label_list)
+```
+
+构建统计分析语义分类数据集的分析器。
+
+> **参数**
+> > * **data_dir** (str): 数据集所在的目录路径。  
+> > * **file_list** (str): 描述数据集图片文件和类别id的文件路径(文本内每行路径为相对`data_dir`的相对路径)。  
+> > * **label_list** (str): 描述数据集包含的类别信息文件路径。  
+
+#### analysis
+```python
+analysis(self)
+```
+
+Seg分析器的分析接口,完成以下信息的分析统计:
+
+> * 图像数量
+> * 图像最大和最小的尺寸
+> * 图像通道数量
+> * 图像各通道的最小值和最大值
+> * 图像各通道的像素值分布
+> * 图像各通道归一化后的均值和方差
+> * 标注图中各类别的数量及比重
+
+[代码示例](https://github.com/PaddlePaddle/PaddleX/blob/develop/examples/multi-channel_remote_sensing/tools/analysis.py)
+
+[统计信息示例](../../examples/multi-channel_remote_sensing/analysis.html#id2)
+
+#### cal_clipped_mean_std
+```python
+cal_clipped_mean_std(self, clip_min_value, clip_max_value, data_info_file)
+```
+
+Seg分析器用于计算图像截断后的均值和方差的接口。
+
+> **参数**
+> > * **clip_min_value** (list):  截断的下限,小于min_val的数值均设为min_val。
+> > * **clip_max_value** (list): 截断的上限,大于max_val的数值均设为max_val。
+> > * **data_info_file** (str): 在analysis()接口中保存的分析结果文件(名为`train_information.pkl`)的路径。
+
+[代码示例](https://github.com/PaddlePaddle/PaddleX/blob/develop/examples/multi-channel_remote_sensing/tools/cal_clipped_mean_std.py)
+
+[计算结果示例](../../examples/multi-channel_remote_sensing/analysis.html#id4)
+
+## 数据集生成
+
+### paddlex.det.paste_objects
+```python
+paddlex.det.paste_objects(templates, background, save_dir='dataset_clone')
+```
+
+将目标物体粘贴在背景图片上生成新的图片和标注文件
+
+> **参数**
+> > * **templates** (list|tuple):可以将多张图像上的目标物体同时粘贴在同一个背景图片上,因此templates是一个列表,其中每个元素是一个dict,表示一张图片的目标物体。一张图片的目标物体有`image`和`annos`两个关键字,`image`的键值是图像的路径,或者是解码后的排列格式为(H, W, C)且类型为uint8且为BGR格式的数组。图像上可以有多个目标物体,因此`annos`的键值是一个列表,列表中每个元素是一个dict,表示一个目标物体的信息。该dict包含`polygon`和`category`两个关键字,其中`polygon`表示目标物体的边缘坐标,例如[[0, 0], [0, 1], [1, 1], [1, 0]],`category`表示目标物体的类别,例如'dog'。
+> > * **background** (dict): 背景图片可以有真值,因此background是一个dict,包含`image`和`annos`两个关键字,`image`的键值是背景图像的路径,或者是解码后的排列格式为(H, W, C)且类型为uint8且为BGR格式的数组。若背景图片上没有真值,则`annos`的键值是空列表[],若有,则`annos`的键值是由多个dict组成的列表,每个dict表示一个物体的信息,包含`bbox`和`category`两个关键字,`bbox`的键值是物体框左上角和右下角的坐标,即[x1, y1, x2, y2],`category`表示目标物体的类别,例如'dog'。
+> > * **save_dir** (str):新图片及其标注文件的存储目录。默认值为`dataset_clone`。
+
+> **代码示例**
+
+```python
+import paddlex as pdx
+templates = [{'image': 'dataset/JPEGImages/budaodian-10.jpg',
+              'annos': [{'polygon': [[146, 169], [909, 169], [909, 489], [146, 489]],
+                        'category': 'lou_di'},
+                        {'polygon': [[146, 169], [909, 169], [909, 489], [146, 489]],
+                        'category': 'lou_di'}]}]
+background = {'image': 'dataset/JPEGImages/budaodian-12.jpg', 'annos': []}
+pdx.det.paste_objects(templates, background, save_dir='dataset_clone')
+```

+ 1 - 2
docs/apis/visualize.md

@@ -169,8 +169,7 @@ paddlex.det.coco_error_analysis(eval_details_file=None, gt=None, pred_bbox=None,
 > * **gt** (list): 数据集的真值信息。默认值为None。
 > * **pred_bbox** (list): 模型在数据集上的预测框。默认值为None。
 > * **pred_mask** (list): 模型在数据集上的预测mask。默认值为None。
-> * **iou_thresh** (float): 判断预测框或预测mask为真阳时的IoU阈值。默认值为0.5。
-> * **save_dir** (str): 可视化结果保存路径。默认值为'./'。
+> * **save_dir** (str): 可视化结果保存路径。默认值为'./output'。
 
 **注意:**`eval_details_file`的优先级更高,只要`eval_details_file`不为None,就会从`eval_details_file`提取真值信息和预测结果做分析。当`eval_details_file`为None时,则用`gt`、`pred_mask`、`pred_mask`做分析。
 

+ 3 - 3
docs/examples/industrial_quality_inspection/README.md

@@ -6,7 +6,7 @@
 
 ### 1.1 数据集介绍
 
-本案例使用天池铝材表面缺陷检测初赛数据集,共有3005张图片,分别检测擦花、杂色、漏底、不导电、桔皮、喷流、漆泡、起坑、脏点和角位漏底10种缺陷,这10种缺陷的定义和示例可点击文档[天池铝材表面缺陷检测初赛数据集示例](./dataset.md)查看。
+本案例使用[天池铝材表面缺陷检测初赛](https://tianchi.aliyun.com/competition/entrance/231682/introduction)数据集,共有3005张图片,分别检测擦花、杂色、漏底、不导电、桔皮、喷流、漆泡、起坑、脏点和角位漏底10种缺陷,这10种缺陷的定义和示例可点击文档[天池铝材表面缺陷检测初赛数据集示例](./dataset.md)查看。
 
 将这3005张图片按9:1随机切分成2713张图片的训练集和292张图片的验证集。
 
@@ -80,10 +80,10 @@ python train_yolov3.py
 
 ### 模型剪裁
 
-运行以下代码,计算在不同的精度损失下,模型各层的剪裁比例:
+运行以下代码,分析在不同的精度损失下模型各层的剪裁比例:
 
 ```
-python cal_sensitivities_file.py
+python params_analysis.py
 ```
 
 设置可允许的精度损失为0.05,对模型进行剪裁,剪裁后需要重新训练模型:

+ 1 - 1
docs/examples/industrial_quality_inspection/accuracy_improvement.md

@@ -12,7 +12,7 @@
 
 | all classes| 擦花 | 杂色 | 漏底 | 不导电 | 桔皮 | 喷流 | 漆泡 | 起坑 | 脏点 | 角位漏底 |
 | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
-|![](https://agroup-bos-bj.cdn.bcebos.com/bj-972007ed33acba896af4aee11cda6abd00ce9ba3)|![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-91790922f4b137880a143f79134391657830c7d2)|![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-46e25934ea0bdc5a7f819bac853883d19e0edc5f)| ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-236eb142601c01ea3f239c771534813ad3fae439) | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-a8872384daca0e499fbf0f281980d4a28a89bb97) | ![](https://agroup-bos-bj.cdn.bcebos.com/bj-068e9d36fa6a172215c93bafe00c56d55fb9890d) | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-da89bff70e7100002993b4ca7000ba6028b7abf4) | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-f1540c077a30b012da41077941e49235e0f844ed) | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-44fbd0c3af5c833f80e19b8fee576c1c49464385) | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-72dce9276ee20349a09a29aa3967dd79e31b9174) | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-c173f25931c901ade26a2ceab99d8eae5310e0ec) |
+| ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/allclasses_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/cahua_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/zase_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/loudi_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/budaodian_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/jupi_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/penliu_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/qipao_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/qikeng_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/zangdian_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/jiaoweiloudi_analysis_example.png) |
 
 分析图表展示了7条Precision-Recall(PR)曲线,每一条曲线表示的Average Precision (AP)比它左边那条高,原因是逐步放宽了评估要求。以擦花类为例,各条PR曲线的评估要求解释如下:
 

+ 10 - 10
docs/examples/industrial_quality_inspection/dataset.md

@@ -2,13 +2,13 @@
 
 | 序号 | 瑕疵名称 | 瑕疵成因 | 瑕疵图片示例 | 图片数量 |
 | -- | -- | -- | -- | -- |
-| 1 | 擦花(擦伤)| 表面处理(喷涂)后有轻微擦到其它的东西,造成痕迹 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-be236181808cd68e3ad941efdc16f11d4a04dd35) | 128 |
-| 2 | 杂色 | 喷涂换颜料的时候,装颜料的容器未清洗干净,造成喷涂时有少量其它颜色掺入 |![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-58d7fe6d1a0b72cfd735fa9e192f4f04e58c0901) |365 |
-| 3 | 漏底 | 喷粉效果不好,铝材大量底色露出 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-16d7005e897900aa916ea639cc49810fe77fa982) | 538 |
-| 4 | 不导电 | 直接喷不到铝材表面上去 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-3c886ff349280c22796a46f296fedd3296ae4120) | 390 |
-|5 | 桔皮 | 表面处理后涂层表面粗糙,大颗粒 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-726ce1ab1ac7aa47dd4c8bff89df85b4e3b7ae4d) | 173 |
-| 6 | 喷流| 喷涂时油漆稀从上流下来,有流动痕迹 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-1b245e25169d618a083669acde2c293bb548bea6) | 86 |
-| 7 |漆泡 | 喷涂后表面起泡,小而多| ![](https://agroup-bos-bj.cdn.bcebos.com/bj-00ed3f730ce41f7f18a4d6a5402fd3e61bfa0db9) | 82 |
-| 8 | 起坑 | 型材模具问题,做出来的型材一整条都有一条凹下去的部分 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-739abb8a6d9144560290cd8f2eaa05b3aa183e17) | 407 |
-| 9 | 脏点 | 表面处理时,有灰尘或一些脏东西未能擦掉,导致涂层有颗粒比较突出 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-487a749bfbb149294f540cc3710aacee30e154f2) | 261 |
-| 10 | 角位漏底 | 在型材角落出现的露底 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-91a9fe0f3a69f1b4ab6006bae870e5f687fab111) | 346 |
+| 1 | 擦花(擦伤)| 表面处理(喷涂)后有轻微擦到其它的东西,造成痕迹 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/ca_hua_example.png) | 128 |
+| 2 | 杂色 | 喷涂换颜料的时候,装颜料的容器未清洗干净,造成喷涂时有少量其它颜色掺入 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/za_se_example.png) |365 |
+| 3 | 漏底 | 喷粉效果不好,铝材大量底色露出 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/lou_di_example.png) | 538 |
+| 4 | 不导电 | 直接喷不到铝材表面上去 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/bu_dao_dian_example.png) | 390 |
+|5 | 桔皮 | 表面处理后涂层表面粗糙,大颗粒 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/ju_pi_example.png) | 173 |
+| 6 | 喷流| 喷涂时油漆稀从上流下来,有流动痕迹 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/pen_liu_example.png) | 86 |
+| 7 |漆泡 | 喷涂后表面起泡,小而多| ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/qi_pao_example.png) | 82 |
+| 8 | 起坑 | 型材模具问题,做出来的型材一整条都有一条凹下去的部分 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/qi_keng_example.png.png) | 407 |
+| 9 | 脏点 | 表面处理时,有灰尘或一些脏东西未能擦掉,导致涂层有颗粒比较突出 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/zang_dian_example.png) | 261 |
+| 10 | 角位漏底 | 在型材角落出现的露底 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/jiao_wei_lou_di_example.png) | 346 |

+ 1 - 0
docs/index.rst

@@ -64,6 +64,7 @@ PaddleX是基于飞桨核心框架、开发套件和工具组件的深度学习
    examples/remote_sensing.md
    examples/multi-channel_remote_sensing/README.md
    examples/change_detection.md
+   examples/industrial_quality_inspection/README.md
 
 .. toctree::
    :maxdepth: 1

+ 3 - 3
examples/industrial_quality_inspection/README.md

@@ -6,7 +6,7 @@
 
 ### 1.1 数据集介绍
 
-本案例使用天池铝材表面缺陷检测初赛数据集,共有3005张图片,分别检测擦花、杂色、漏底、不导电、桔皮、喷流、漆泡、起坑、脏点和角位漏底10种缺陷,这10种缺陷的定义和示例可点击文档[天池铝材表面缺陷检测初赛数据集示例](./dataset.md)查看。
+本案例使用[天池铝材表面缺陷检测初赛](https://tianchi.aliyun.com/competition/entrance/231682/introduction)数据集,共有3005张图片,分别检测擦花、杂色、漏底、不导电、桔皮、喷流、漆泡、起坑、脏点和角位漏底10种缺陷,这10种缺陷的定义和示例可点击文档[天池铝材表面缺陷检测初赛数据集示例](./dataset.md)查看。
 
 将这3005张图片按9:1随机切分成2713张图片的训练集和292张图片的验证集。
 
@@ -80,10 +80,10 @@ python train_yolov3.py
 
 ### 模型剪裁
 
-运行以下代码,计算在不同的精度损失下,模型各层的剪裁比例:
+运行以下代码,分析在不同的精度损失下模型各层的剪裁比例:
 
 ```
-python cal_sensitivities_file.py
+python params_analysis.py
 ```
 
 设置可允许的精度损失为0.05,对模型进行剪裁,剪裁后需要重新训练模型:

+ 1 - 1
examples/industrial_quality_inspection/accuracy_improvement.md

@@ -12,7 +12,7 @@
 
 | all classes| 擦花 | 杂色 | 漏底 | 不导电 | 桔皮 | 喷流 | 漆泡 | 起坑 | 脏点 | 角位漏底 |
 | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- | -- |
-|![](https://agroup-bos-bj.cdn.bcebos.com/bj-972007ed33acba896af4aee11cda6abd00ce9ba3)|![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-91790922f4b137880a143f79134391657830c7d2)|![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-46e25934ea0bdc5a7f819bac853883d19e0edc5f)| ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-236eb142601c01ea3f239c771534813ad3fae439) | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-a8872384daca0e499fbf0f281980d4a28a89bb97) | ![](https://agroup-bos-bj.cdn.bcebos.com/bj-068e9d36fa6a172215c93bafe00c56d55fb9890d) | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-da89bff70e7100002993b4ca7000ba6028b7abf4) | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-f1540c077a30b012da41077941e49235e0f844ed) | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-44fbd0c3af5c833f80e19b8fee576c1c49464385) | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-72dce9276ee20349a09a29aa3967dd79e31b9174) | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-c173f25931c901ade26a2ceab99d8eae5310e0ec) |
+| ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/allclasses_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/cahua_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/zase_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/loudi_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/budaodian_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/jupi_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/penliu_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/qipao_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/qikeng_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/zangdian_analysis_example.png) | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/jiaoweiloudi_analysis_example.png) |
 
 分析图表展示了7条Precision-Recall(PR)曲线,每一条曲线表示的Average Precision (AP)比它左边那条高,原因是逐步放宽了评估要求。以擦花类为例,各条PR曲线的评估要求解释如下:
 

+ 183 - 95
examples/industrial_quality_inspection/cal_tp_fp.py

@@ -15,6 +15,7 @@
 
 # 环境变量配置,用于控制是否使用GPU
 # 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
+import argparse
 import os
 os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 
@@ -25,102 +26,189 @@ matplotlib.use('Agg')
 import matplotlib.pyplot as plt
 import paddlex as pdx
 
-data_dir = 'aluminum_inspection/'
-positive_file_list = 'aluminum_inspection/val_list.txt'
-negative_dir = 'aluminum_inspection/val_wu_xia_ci'
-model_dir = 'output/faster_rcnn_r50_vd_dcn/best_model/'
-save_dir = 'visualize/faster_rcnn_r50_vd_dcn'
-if not osp.exists(save_dir):
-    os.makedirs(save_dir)
-
-tp = np.zeros((101, 1))
-fp = np.zeros((101, 1))
-
-# 导入模型
-model = pdx.load_model(model_dir)
-
-# 计算图片级召回率
-print(
-    "Begin to calculate image-level recall rate of positive images. Please wait for a moment..."
-)
-positive_num = 0
-with open(positive_file_list, 'r') as fr:
-    while True:
-        line = fr.readline()
-        if not line:
-            break
-        img_file, xml_file = [osp.join(data_dir, x) \
-                for x in line.strip().split()[:2]]
-        if not osp.exists(img_file):
-            continue
-        if not osp.exists(xml_file):
-            continue
-
-        positive_num += 1
-        results = model.predict(img_file)
+
+def cal_image_level_recall_rate(model, dataset_dir):
+    """计算置信度(Score)在[0, 1]内以间隔0.01递增取值时,模型在有目标的图像上的图片级召回率。
+
+    图片级召回率:只要在有目标的图片上检测出目标(不论框的个数),该图片被认为召回,
+       批量有目标图片中被召回的图片所占的比例,即为图片级别的召回率。
+
+    Args:
+        model (PaddleX model object): 已加载的PaddleX模型。
+        dataset_dir (str):数据集路径。
+
+    Returns:
+        numpy.array: 形状为101x1的数组,对应置信度从0到1按0.01递增取值时,计算所得图片级别的召回率。
+    """
+
+    print(
+        "Begin to calculate image-level recall rate of positive images. Please wait for a moment..."
+    )
+    file_list = osp.join(dataset_dir, 'val_list.txt')
+    tp = np.zeros((101, 1))
+    positive_num = 0
+    with open(file_list, 'r') as fr:
+        while True:
+            line = fr.readline()
+            if not line:
+                break
+            img_file, xml_file = [osp.join(dataset_dir, x) \
+                    for x in line.strip().split()[:2]]
+            if not osp.exists(img_file):
+                continue
+            if not osp.exists(xml_file):
+                continue
+
+            positive_num += 1
+            results = model.predict(img_file)
+            scores = list()
+            for res in results:
+                scores.append(res['score'])
+            if len(scores) > 0:
+                tp[0:int(np.round(max(scores) / 0.01)), 0] += 1
+    tp = tp / positive_num
+    return tp
+
+
+def cal_image_level_false_positive_rate(model, negative_dir):
+    """计算置信度(Score)在[0, 1]内以间隔0.01递增取值时,模型在无目标的图像上的图片级误检率。
+
+    图片级误检率:只要在无目标的图片上检测出目标(不论框的个数),该图片被认为误检,
+       批量无目标图片中被误检的图片所占的比例,即为图片级别的误检率。
+
+    Args:
+        model (PaddleX model object): 已加载的PaddleX模型。
+        negative_dir (str):无目标图片的文件夹路径。
+
+    Returns:
+        numpy.array: 形状为101x1的数组,对应置信度从0到1按0.01递增取值时,计算所得图片级别的误检率。
+    """
+
+    print(
+        "Begin to calculate image-level false positive rate of negative(background) images. Please wait for a moment..."
+    )
+    fp = np.zeros((101, 1))
+    negative_num = 0
+    for file in os.listdir(negative_dir):
+        file = osp.join(negative_dir, file)
+        results = model.predict(file)
+        negative_num += 1
         scores = list()
         for res in results:
             scores.append(res['score'])
         if len(scores) > 0:
-            tp[0:int(np.round(max(scores) / 0.01)), 0] += 1
-tp = tp / positive_num
-
-# 计算图片级误检率
-print(
-    "Begin to calculate image-level false-positive rate of background images. Please wait for a moment..."
-)
-negative_num = 0
-for file in os.listdir(negative_dir):
-    file = osp.join(negative_dir, file)
-    results = model.predict(file)
-    negative_num += 1
-    scores = list()
-    for res in results:
-        scores.append(res['score'])
-    if len(scores) > 0:
-        fp[0:int(np.round(max(scores) / 0.01)), 0] += 1
-fp = fp / negative_num
-
-# 保存结果
-tp_fp_list_file = osp.join(save_dir, 'tp_fp_list.txt')
-with open(tp_fp_list_file, 'w') as f:
-    f.write("| score | recall rate | false-positive rate |\n")
-    f.write("| -- | -- | -- |\n")
-    for i in range(100):
-        f.write("| {:2f} | {:2f} | {:2f} |\n".format(0.01 * i, tp[i, 0], fp[
-            i, 0]))
-print("The numerical score-recall_rate-false_positive_rate is saved as {}".
-      format(tp_fp_list_file))
-
-plt.subplot(1, 2, 1)
-plt.title("image-level false_positive-recall")
-plt.xlabel("recall")
-plt.ylabel("false_positive")
-plt.xlim(0, 1)
-plt.ylim(0, 1)
-plt.grid(linestyle='--', linewidth=1)
-plt.plot([0, 1], [0, 1], 'r--', linewidth=1)
-my_x_ticks = np.arange(0, 1, 0.1)
-my_y_ticks = np.arange(0, 1, 0.1)
-plt.xticks(my_x_ticks, fontsize=5)
-plt.yticks(my_y_ticks, fontsize=5)
-plt.plot(tp, fp, color='b', label="image level", linewidth=1)
-plt.legend(loc="lower left", fontsize=5)
-
-plt.subplot(1, 2, 2)
-plt.title("score-recall")
-plt.xlabel('recall')
-plt.ylabel('score')
-plt.xlim(0, 1)
-plt.ylim(0, 1)
-plt.grid(linestyle='--', linewidth=1)
-plt.xticks(my_x_ticks, fontsize=5)
-plt.yticks(my_y_ticks, fontsize=5)
-plt.plot(
-    tp, np.arange(0, 1.01, 0.01), color='b', label="image level", linewidth=1)
-plt.legend(loc="lower left", fontsize=5)
-tp_fp_chart_file = os.path.join(save_dir, "image-level_tp_fp.png")
-plt.savefig(tp_fp_chart_file, dpi=800)
-plt.close()
-print("The diagrammatic score-recall_rate-false_positive_rate is saved as {}".
-      format(tp_fp_chart_file))
+            fp[0:int(np.round(max(scores) / 0.01)), 0] += 1
+    fp = fp / negative_num
+    return fp
+
+
+def result2textfile(tp_list, fp_list, save_dir):
+    """将不同置信度阈值下的图片级召回率和图片级误检率保存为文本文件。
+
+    文本文件中内容按照| 置信度阈值 | 图片级召回率 | 图片级误检率 |的格式保存。
+
+    Args:
+        tp_list (numpy.array): 形状为101x1的数组,对应置信度从0到1按0.01递增取值时,计算所得图片级别的召回率。
+        fp_list (numpy.array): 形状为101x1的数组,对应置信度从0到1按0.01递增取值时,计算所得图片级别的误检率。
+        save_dir (str): 文本文件的保存路径。
+
+    """
+
+    tp_fp_list_file = osp.join(save_dir, 'tp_fp_list.txt')
+    with open(tp_fp_list_file, 'w') as f:
+        f.write("| score | recall rate | false-positive rate |\n")
+        f.write("| -- | -- | -- |\n")
+        for i in range(100):
+            f.write("| {:2f} | {:2f} | {:2f} |\n".format(0.01 * i, tp_list[
+                i, 0], fp_list[i, 0]))
+    print("The numerical score-recall_rate-false_positive_rate is saved as {}".
+          format(tp_fp_list_file))
+
+
+def result2imagefile(tp_list, fp_list, save_dir):
+    """将不同置信度阈值下的图片级召回率和图片级误检率保存为图片。
+
+    图片中左子图横坐标表示不同置信度阈值下计算得到的图片级召回率,纵坐标表示各图片级召回率对应的图片级误检率。
+        右边子图横坐标表示图片级召回率,纵坐标表示各图片级召回率对应的置信度阈值。
+
+    Args:
+        tp_list (numpy.array): 形状为101x1的数组,对应置信度从0到1按0.01递增取值时,计算所得图片级别的召回率。
+        fp_list (numpy.array): 形状为101x1的数组,对应置信度从0到1按0.01递增取值时,计算所得图片级别的误检率。
+        save_dir (str): 文本文件的保存路径。
+
+    """
+
+    plt.subplot(1, 2, 1)
+    plt.title("image-level false_positive-recall")
+    plt.xlabel("recall")
+    plt.ylabel("false_positive")
+    plt.xlim(0, 1)
+    plt.ylim(0, 1)
+    plt.grid(linestyle='--', linewidth=1)
+    plt.plot([0, 1], [0, 1], 'r--', linewidth=1)
+    my_x_ticks = np.arange(0, 1, 0.1)
+    my_y_ticks = np.arange(0, 1, 0.1)
+    plt.xticks(my_x_ticks, fontsize=5)
+    plt.yticks(my_y_ticks, fontsize=5)
+    plt.plot(tp_list, fp_list, color='b', label="image level", linewidth=1)
+    plt.legend(loc="lower left", fontsize=5)
+
+    plt.subplot(1, 2, 2)
+    plt.title("score-recall")
+    plt.xlabel('recall')
+    plt.ylabel('score')
+    plt.xlim(0, 1)
+    plt.ylim(0, 1)
+    plt.grid(linestyle='--', linewidth=1)
+    plt.xticks(my_x_ticks, fontsize=5)
+    plt.yticks(my_y_ticks, fontsize=5)
+    plt.plot(
+        tp_list,
+        np.arange(0, 1.01, 0.01),
+        color='b',
+        label="image level",
+        linewidth=1)
+    plt.legend(loc="lower left", fontsize=5)
+    tp_fp_chart_file = os.path.join(save_dir, "image-level_tp_fp.png")
+    plt.savefig(tp_fp_chart_file, dpi=800)
+    plt.close()
+    print(
+        "The diagrammatic score-recall_rate-false_positive_rate is saved as {}".
+        format(tp_fp_chart_file))
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--model_dir",
+        default="./output/faster_rcnn_r50_vd_dcn/best_model/",
+        type=str,
+        help="The model directory path.")
+    parser.add_argument(
+        "--dataset_dir",
+        default="./aluminum_inspection",
+        type=str,
+        help="The VOC-format dataset directory path.")
+    parser.add_argument(
+        "--background_image_dir",
+        default="./aluminum_inspection/val_wu_xia_ci",
+        type=str,
+        help="The directory path of background images.")
+    parser.add_argument(
+        "--save_dir",
+        default="./visualize/faster_rcnn_r50_vd_dcn",
+        type=str,
+        help="The directory path of result.")
+
+    args = parser.parse_args()
+
+    if not osp.exists(args.save_dir):
+        os.makedirs(args.save_dir)
+
+    model = pdx.load_model(args.model_dir)
+
+    tp_list = cal_image_level_recall_rate(model, args.dataset_dir)
+    fp_list = cal_image_level_false_positive_rate(model,
+                                                  args.background_image_dir)
+    result2textfile(tp_list, fp_list, args.save_dir)
+    result2imagefile(tp_list, fp_list, args.save_dir)

+ 128 - 100
examples/industrial_quality_inspection/compare.py

@@ -15,6 +15,7 @@
 
 # 环境变量配置,用于控制是否使用GPU
 # 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
+import argparse
 import os
 os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 
@@ -24,109 +25,136 @@ import re
 import xml.etree.ElementTree as ET
 import paddlex as pdx
 
-data_dir = 'aluminum_inspection/'
-file_list = 'aluminum_inspection/val_list.txt'
-model_dir = 'output/faster_rcnn_r50_vd_dcn/best_model/'
-save_dir = './visualize/compare'
-# 设置置信度阈值
-score_threshold = 0.1
 
-if not os.path.exists(save_dir):
-    os.makedirs(save_dir)
+def parse_xml_file(xml_file):
+    tree = ET.parse(xml_file)
+    pattern = re.compile('<object>', re.IGNORECASE)
+    obj_match = pattern.findall(str(ET.tostringlist(tree.getroot())))
+    if len(obj_match) == 0:
+        return False
+    obj_tag = obj_match[0][1:-1]
+    objs = tree.findall(obj_tag)
+    pattern = re.compile('<size>', re.IGNORECASE)
+    size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:-1]
+    size_element = tree.find(size_tag)
+    pattern = re.compile('<width>', re.IGNORECASE)
+    width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
+    im_w = float(size_element.find(width_tag).text)
+    pattern = re.compile('<height>', re.IGNORECASE)
+    height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:-1]
+    im_h = float(size_element.find(height_tag).text)
+    gt_bbox = []
+    gt_class = []
+    for i, obj in enumerate(objs):
+        pattern = re.compile('<name>', re.IGNORECASE)
+        name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
+        cname = obj.find(name_tag).text.strip()
+        gt_class.append(cname)
+        pattern = re.compile('<difficult>', re.IGNORECASE)
+        diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
+        try:
+            _difficult = int(obj.find(diff_tag).text)
+        except Exception:
+            _difficult = 0
+        pattern = re.compile('<bndbox>', re.IGNORECASE)
+        box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
+        box_element = obj.find(box_tag)
+        pattern = re.compile('<xmin>', re.IGNORECASE)
+        xmin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+        x1 = float(box_element.find(xmin_tag).text)
+        pattern = re.compile('<ymin>', re.IGNORECASE)
+        ymin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+        y1 = float(box_element.find(ymin_tag).text)
+        pattern = re.compile('<xmax>', re.IGNORECASE)
+        xmax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+        x2 = float(box_element.find(xmax_tag).text)
+        pattern = re.compile('<ymax>', re.IGNORECASE)
+        ymax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][1:-1]
+        y2 = float(box_element.find(ymax_tag).text)
+        x1 = max(0, x1)
+        y1 = max(0, y1)
+        if im_w > 0.5 and im_h > 0.5:
+            x2 = min(im_w - 1, x2)
+            y2 = min(im_h - 1, y2)
+        gt_bbox.append([x1, y1, x2, y2])
+    gts = []
+    for bbox, name in zip(gt_bbox, gt_class):
+        x1, y1, x2, y2 = bbox
+        w = x2 - x1 + 1
+        h = y2 - y1 + 1
+        gt = {
+            'category_id': 0,
+            'category': name,
+            'bbox': [x1, y1, w, h],
+            'score': 1
+        }
+        gts.append(gt)
 
-model = pdx.load_model(model_dir)
+    return gts
 
-with open(file_list, 'r') as fr:
-    while True:
-        line = fr.readline()
-        if not line:
-            break
-        img_file, xml_file = [osp.join(data_dir, x) \
-                for x in line.strip().split()[:2]]
-        if not osp.exists(img_file):
-            continue
-        if not osp.exists(xml_file):
-            continue
 
-        res = model.predict(img_file)
-        det_vis = pdx.det.visualize(
-            img_file, res, threshold=score_threshold, save_dir=None)
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument(
+        "--model_dir",
+        default="./output/faster_rcnn_r50_vd_dcn/best_model/",
+        type=str,
+        help="The model directory path.")
+    parser.add_argument(
+        "--dataset_dir",
+        default="./aluminum_inspection",
+        type=str,
+        help="The VOC-format dataset directory path.")
+    parser.add_argument(
+        "--save_dir",
+        default="./visualize/compare",
+        type=str,
+        help="The directory path of result.")
+    parser.add_argument(
+        "--score_threshold",
+        default=0.1,
+        type=float,
+        help="The predicted bbox whose score is lower than score_threshold is filtered."
+    )
 
-        tree = ET.parse(xml_file)
-        pattern = re.compile('<object>', re.IGNORECASE)
-        obj_match = pattern.findall(str(ET.tostringlist(tree.getroot())))
-        if len(obj_match) == 0:
-            continue
-        obj_tag = obj_match[0][1:-1]
-        objs = tree.findall(obj_tag)
-        pattern = re.compile('<size>', re.IGNORECASE)
-        size_tag = pattern.findall(str(ET.tostringlist(tree.getroot())))[0][1:
-                                                                            -1]
-        size_element = tree.find(size_tag)
-        pattern = re.compile('<width>', re.IGNORECASE)
-        width_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:
-                                                                           -1]
-        im_w = float(size_element.find(width_tag).text)
-        pattern = re.compile('<height>', re.IGNORECASE)
-        height_tag = pattern.findall(str(ET.tostringlist(size_element)))[0][1:
-                                                                            -1]
-        im_h = float(size_element.find(height_tag).text)
-        gt_bbox = []
-        gt_class = []
-        for i, obj in enumerate(objs):
-            pattern = re.compile('<name>', re.IGNORECASE)
-            name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
-            cname = obj.find(name_tag).text.strip()
-            gt_class.append(cname)
-            pattern = re.compile('<difficult>', re.IGNORECASE)
-            diff_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
-            try:
-                _difficult = int(obj.find(diff_tag).text)
-            except Exception:
-                _difficult = 0
-            pattern = re.compile('<bndbox>', re.IGNORECASE)
-            box_tag = pattern.findall(str(ET.tostringlist(obj)))[0][1:-1]
-            box_element = obj.find(box_tag)
-            pattern = re.compile('<xmin>', re.IGNORECASE)
-            xmin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][
-                1:-1]
-            x1 = float(box_element.find(xmin_tag).text)
-            pattern = re.compile('<ymin>', re.IGNORECASE)
-            ymin_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][
-                1:-1]
-            y1 = float(box_element.find(ymin_tag).text)
-            pattern = re.compile('<xmax>', re.IGNORECASE)
-            xmax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][
-                1:-1]
-            x2 = float(box_element.find(xmax_tag).text)
-            pattern = re.compile('<ymax>', re.IGNORECASE)
-            ymax_tag = pattern.findall(str(ET.tostringlist(box_element)))[0][
-                1:-1]
-            y2 = float(box_element.find(ymax_tag).text)
-            x1 = max(0, x1)
-            y1 = max(0, y1)
-            if im_w > 0.5 and im_h > 0.5:
-                x2 = min(im_w - 1, x2)
-                y2 = min(im_h - 1, y2)
-            gt_bbox.append([x1, y1, x2, y2])
-        gts = []
-        for bbox, name in zip(gt_bbox, gt_class):
-            x1, y1, x2, y2 = bbox
-            w = x2 - x1 + 1
-            h = y2 - y1 + 1
-            gt = {
-                'category_id': 0,
-                'category': name,
-                'bbox': [x1, y1, w, h],
-                'score': 1
-            }
-            gts.append(gt)
-        gt_vis = pdx.det.visualize(
-            img_file, gts, threshold=score_threshold, save_dir=None)
-        vis = cv2.hconcat([gt_vis, det_vis])
-        cv2.imwrite(os.path.join(save_dir, os.path.split(img_file)[-1]), vis)
-        print('The comparison has been made for {}'.format(img_file))
+    args = parser.parse_args()
 
-print(
-    "The visualized ground-truths and predictions are saved in {}. Ground-truth is on the left, prediciton is on the right".
-    format(save_dir))
+    if not os.path.exists(args.save_dir):
+        os.makedirs(args.save_dir)
+    file_list = osp.join(args.dataset_dir, 'val_list.txt')
+
+    model = pdx.load_model(args.model_dir)
+
+    with open(file_list, 'r') as fr:
+        while True:
+            line = fr.readline()
+            if not line:
+                break
+            img_file, xml_file = [osp.join(args.dataset_dir, x) \
+                    for x in line.strip().split()[:2]]
+            if not osp.exists(img_file):
+                continue
+            if not osp.exists(xml_file):
+                continue
+
+            res = model.predict(img_file)
+            gts = parse_xml_file(xml_file)
+
+            det_vis = pdx.det.visualize(
+                img_file, res, threshold=args.score_threshold, save_dir=None)
+            if gts == False:
+                gts = cv2.imread(img_file)
+            else:
+                gt_vis = pdx.det.visualize(
+                    img_file,
+                    gts,
+                    threshold=args.score_threshold,
+                    save_dir=None)
+            vis = cv2.hconcat([gt_vis, det_vis])
+            cv2.imwrite(
+                os.path.join(args.save_dir, os.path.split(img_file)[-1]), vis)
+            print('The comparison has been made for {}'.format(img_file))
+
+    print(
+        "The visualized ground-truths and predictions are saved in {}. Ground-truth is on the left, prediciton is on the right".
+        format(save_dir))

+ 10 - 10
examples/industrial_quality_inspection/dataset.md

@@ -2,13 +2,13 @@
 
 | 序号 | 瑕疵名称 | 瑕疵成因 | 瑕疵图片示例 | 图片数量 |
 | -- | -- | -- | -- | -- |
-| 1 | 擦花(擦伤)| 表面处理(喷涂)后有轻微擦到其它的东西,造成痕迹 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-be236181808cd68e3ad941efdc16f11d4a04dd35) | 128 |
-| 2 | 杂色 | 喷涂换颜料的时候,装颜料的容器未清洗干净,造成喷涂时有少量其它颜色掺入 |![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-58d7fe6d1a0b72cfd735fa9e192f4f04e58c0901) |365 |
-| 3 | 漏底 | 喷粉效果不好,铝材大量底色露出 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-16d7005e897900aa916ea639cc49810fe77fa982) | 538 |
-| 4 | 不导电 | 直接喷不到铝材表面上去 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-3c886ff349280c22796a46f296fedd3296ae4120) | 390 |
-|5 | 桔皮 | 表面处理后涂层表面粗糙,大颗粒 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-726ce1ab1ac7aa47dd4c8bff89df85b4e3b7ae4d) | 173 |
-| 6 | 喷流| 喷涂时油漆稀从上流下来,有流动痕迹 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-1b245e25169d618a083669acde2c293bb548bea6) | 86 |
-| 7 |漆泡 | 喷涂后表面起泡,小而多| ![](https://agroup-bos-bj.cdn.bcebos.com/bj-00ed3f730ce41f7f18a4d6a5402fd3e61bfa0db9) | 82 |
-| 8 | 起坑 | 型材模具问题,做出来的型材一整条都有一条凹下去的部分 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-739abb8a6d9144560290cd8f2eaa05b3aa183e17) | 407 |
-| 9 | 脏点 | 表面处理时,有灰尘或一些脏东西未能擦掉,导致涂层有颗粒比较突出 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-487a749bfbb149294f540cc3710aacee30e154f2) | 261 |
-| 10 | 角位漏底 | 在型材角落出现的露底 | ![图片](https://agroup-bos-bj.cdn.bcebos.com/bj-91a9fe0f3a69f1b4ab6006bae870e5f687fab111) | 346 |
+| 1 | 擦花(擦伤)| 表面处理(喷涂)后有轻微擦到其它的东西,造成痕迹 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/ca_hua_example.png) | 128 |
+| 2 | 杂色 | 喷涂换颜料的时候,装颜料的容器未清洗干净,造成喷涂时有少量其它颜色掺入 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/za_se_example.png) |365 |
+| 3 | 漏底 | 喷粉效果不好,铝材大量底色露出 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/lou_di_example.png) | 538 |
+| 4 | 不导电 | 直接喷不到铝材表面上去 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/bu_dao_dian_example.png) | 390 |
+|5 | 桔皮 | 表面处理后涂层表面粗糙,大颗粒 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/ju_pi_example.png) | 173 |
+| 6 | 喷流| 喷涂时油漆稀从上流下来,有流动痕迹 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/pen_liu_example.png) | 86 |
+| 7 |漆泡 | 喷涂后表面起泡,小而多| ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/qi_pao_example.png) | 82 |
+| 8 | 起坑 | 型材模具问题,做出来的型材一整条都有一条凹下去的部分 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/qi_keng_example.png.png) | 407 |
+| 9 | 脏点 | 表面处理时,有灰尘或一些脏东西未能擦掉,导致涂层有颗粒比较突出 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/zang_dian_example.png) | 261 |
+| 10 | 角位漏底 | 在型材角落出现的露底 | ![](https://bj.bcebos.com/paddlex/examples/industrial_quality_inspection/datasets/jiao_wei_lou_di_example.png) | 346 |

+ 9 - 4
examples/industrial_quality_inspection/cal_sensitivities_file.py → examples/industrial_quality_inspection/params_analysis.py

@@ -19,7 +19,7 @@ os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 import paddlex as pdx
 
 
-def cal_sensitivies_file(model_dir, dataset, save_file):
+def params_analysis(model_dir, dataset, batch_size, save_file):
     # 加载模型
     model = pdx.load_model(model_dir)
 
@@ -30,8 +30,11 @@ def cal_sensitivies_file(model_dir, dataset, save_file):
         label_list=os.path.join(dataset, 'labels.txt'),
         transforms=model.eval_transforms)
 
-    pdx.slim.cal_params_sensitivities(
-        model, save_file, eval_dataset, batch_size=8)
+    pdx.slim.prune.analysis(
+        model,
+        dataset=eval_dataset,
+        batch_size=batch_size,
+        save_file=save_file)
 
 
 if __name__ == '__main__':
@@ -46,6 +49,7 @@ if __name__ == '__main__':
         default="./aluminum_inspection",
         type=str,
         help="The model path.")
+    parser.add_argument("--batch_size", default=8, type=int, help="Batch size")
     parser.add_argument(
         "--save_file",
         default="./sensitivities.data",
@@ -53,4 +57,5 @@ if __name__ == '__main__':
         help="The sensitivities file path.")
 
     args = parser.parse_args()
-    cal_sensitivies_file(args.model_dir, args.dataset, args.save_file)
+    params_analysis(args.model_dir, args.dataset, args.batch_size,
+                    args.save_file)

+ 1 - 2
examples/industrial_quality_inspection/predict.py

@@ -32,5 +32,4 @@ if not os.path.exists(save_dir):
 
 model = pdx.load_model(model_dir)
 res = model.predict(img_file)
-det_vis = pdx.det.visualize(
-    img_file, res, threshold=score_threshold, save_dir=save_dir)
+pdx.det.visualize(img_file, res, threshold=score_threshold, save_dir=save_dir)

+ 58 - 0
examples/industrial_quality_inspection/train_pruned_yolov3.py

@@ -0,0 +1,58 @@
+# 环境变量配置,用于控制是否使用GPU
+# 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
+
+from paddlex.det import transforms
+import paddlex as pdx
+
+# 定义训练和验证时的transforms
+# API说明 https://paddlex.readthedocs.io/zh_CN/develop/apis/transforms/det_transforms.html
+train_transforms = transforms.Compose([
+    transforms.MixupImage(mixup_epoch=250), transforms.RandomDistort(),
+    transforms.RandomExpand(), transforms.RandomCrop(), transforms.Resize(
+        target_size=608, interp='RANDOM'), transforms.RandomHorizontalFlip(),
+    transforms.Normalize()
+])
+
+eval_transforms = transforms.Compose([
+    transforms.Resize(
+        target_size=608, interp='CUBIC'), transforms.Normalize()
+])
+
+# 定义训练和验证所用的数据集
+# API说明:https://paddlex.readthedocs.io/zh_CN/develop/apis/datasets.html#paddlex-datasets-vocdetection
+train_dataset = pdx.datasets.VOCDetection(
+    data_dir='aluminum_inspection',
+    file_list='aluminum_inspection/train_list.txt',
+    label_list='aluminum_inspection/labels.txt',
+    transforms=train_transforms,
+    shuffle=True)
+eval_dataset = pdx.datasets.VOCDetection(
+    data_dir='aluminum_inspection',
+    file_list='aluminum_inspection/val_list.txt',
+    label_list='aluminum_inspection/labels.txt',
+    transforms=eval_transforms)
+
+# 初始化模型,并进行训练
+# 可使用VisualDL查看训练指标,参考https://paddlex.readthedocs.io/zh_CN/develop/train/visualdl.html
+num_classes = len(train_dataset.labels)
+
+# API说明: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#paddlex-det-yolov3
+model = pdx.det.YOLOv3(num_classes=num_classes, backbone='MobileNetV3_large')
+
+# API说明: https://paddlex.readthedocs.io/zh_CN/develop/apis/models/detection.html#train
+# 各参数介绍与调整说明:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html
+model.train(
+    num_epochs=400,
+    train_dataset=train_dataset,
+    train_batch_size=8,
+    eval_dataset=eval_dataset,
+    warmup_steps=4000,
+    learning_rate=0.000125,
+    lr_decay_epochs=[240, 320],
+    pretrain_weights='output/yolov3_mobilenetv3/best_model',
+    save_dir='output/yolov3_mobilenetv3_prune',
+    use_vdl=True,
+    sensitivities_file='./sensitivities.data',
+    eval_metric_loss=0.05)

+ 4 - 4
examples/industrial_quality_inspection/train_rcnn.py

@@ -1,7 +1,7 @@
 # 环境变量配置,用于控制是否使用GPU
 # 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
 import os
-os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4'
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 
 from paddlex.det import transforms
 import paddlex as pdx
@@ -65,10 +65,10 @@ model = pdx.det.FasterRCNN(
 model.train(
     num_epochs=80,
     train_dataset=train_dataset,
-    train_batch_size=10,
+    train_batch_size=2,
     eval_dataset=eval_dataset,
-    learning_rate=0.0125,
+    learning_rate=0.0025,
     lr_decay_epochs=[60, 70],
-    warmup_steps=1000,
+    warmup_steps=5000,
     save_dir='output/faster_rcnn_r50_vd_dcn',
     use_vdl=True)

+ 4 - 3
examples/industrial_quality_inspection/train_yolov3.py

@@ -1,7 +1,7 @@
 # 环境变量配置,用于控制是否使用GPU
 # 说明文档:https://paddlex.readthedocs.io/zh_CN/develop/appendix/parameters.html#gpu
 import os
-os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2,3,4,5,6,7'
+os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 
 from paddlex.det import transforms
 import paddlex as pdx
@@ -50,9 +50,10 @@ model = pdx.det.YOLOv3(num_classes=num_classes, backbone='MobileNetV3_large')
 model.train(
     num_epochs=400,
     train_dataset=train_dataset,
-    train_batch_size=64,
+    train_batch_size=8,
     eval_dataset=eval_dataset,
-    learning_rate=0.001,
+    warmup_steps=4000,
+    learning_rate=0.000125,
     lr_decay_epochs=[240, 320],
     save_dir='output/yolov3_mobilenetv3',
     use_vdl=True)

+ 0 - 2
paddlex/cv/datasets/voc.py

@@ -146,8 +146,6 @@ class VOCDetection(Dataset):
                     name_tag = pattern.findall(str(ET.tostringlist(obj)))[0][
                         1:-1]
                     cname = obj.find(name_tag).text.strip()
-                    if cname in ['bu_dao_dian', 'jiao_wei_lou_di']:
-                        cname = 'lou_di'
                     gt_class[i][0] = cname2cid[cname]
                     pattern = re.compile('<difficult>', re.IGNORECASE)
                     diff_tag = pattern.findall(str(ET.tostringlist(obj)))

+ 1 - 1
paddlex/cv/models/ppyolo.py

@@ -166,7 +166,7 @@ class PPYOLO(BaseAPI):
             use_matrix_nms=self.use_matrix_nms,
             use_fine_grained_loss=self.use_fine_grained_loss,
             use_iou_loss=self.use_iou_loss,
-            batch_size=getattr(self, 'batch_size_per_gpu', 8),
+            batch_size=getattr(self, 'batch_size_per_gpu', None),
             input_channel=self.input_channel)
         if mode == 'train' and self.use_iou_loss or self.use_iou_aware:
             model.max_height = self.max_height

+ 51 - 2
paddlex/cv/models/utils/detection_eval.py

@@ -1,3 +1,4 @@
+# coding: utf8
 # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -771,6 +772,21 @@ class DetectionMAP(object):
 
 
 def makeplot(rs, ps, outDir, class_name, iou_type):
+    """针对某个特定类别,绘制不同评估要求下的准确率和召回率。
+       绘制结果说明参考COCODataset官网给出分析工具说明https://cocodataset.org/#detection-eval。
+
+       Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/coco_error_analysis.py
+
+       Args:
+           rs (np.array): 在不同置信度阈值下计算得到的召回率。
+           ps (np.array): 在不同置信度阈值下计算得到的准确率。ps与rs相同位置下的数值为同一个置信度阈值
+               计算得到的准确率与召回率。
+           outDir (str): 图表保存的路径。
+           class_name (str): 类别名。
+           iou_type (str): iou计算方式,若为检测框,则设置为'bbox',若为像素级分割结果,则设置为'segm'。
+
+    """
+
     import matplotlib.pyplot as plt
 
     cs = np.vstack([
@@ -809,6 +825,24 @@ def makeplot(rs, ps, outDir, class_name, iou_type):
 
 
 def analyze_individual_category(k, cocoDt, cocoGt, catId, iou_type):
+    """针对某个特定类别,分析忽略亚类混淆和类别混淆时的准确率。
+
+       Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/coco_error_analysis.py
+
+       Args:
+           k (int): 待分析类别的序号。
+           cocoDt (pycocotols.coco.COCO): 按COCO类存放的预测结果。
+           cocoGt (pycocotols.coco.COCO): 按COCO类存放的真值。
+           catId (int): 待分析类别在数据集中的类别id。
+           iou_type (str): iou计算方式,若为检测框,则设置为'bbox',若为像素级分割结果,则设置为'segm'。
+
+       Returns:
+           int:
+           dict: 有关键字'ps_supercategory'和'ps_allcategory'。关键字'ps_supercategory'的键值是忽略亚类间
+               混淆时的准确率,关键字'ps_allcategory'的键值是忽略类别间混淆时的准确率。
+
+    """
+
     from pycocotools.coco import COCO
     from pycocotools.cocoeval import COCOeval
 
@@ -868,8 +902,23 @@ def coco_error_analysis(eval_details_file=None,
                         pred_bbox=None,
                         pred_mask=None,
                         save_dir='./output'):
-    """
-    Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/coco_error_analysis.py
+    """逐个分析模型预测错误的原因,并将分析结果以图表的形式展示。
+       分析结果说明参考COCODataset官网给出分析工具说明https://cocodataset.org/#detection-eval。
+
+       Refer to https://github.com/open-mmlab/mmdetection/blob/master/tools/coco_error_analysis.py
+
+       Args:
+           eval_details_file (str):  模型评估结果的保存路径,包含真值信息和预测结果。
+           gt (list): 数据集的真值信息。默认值为None。
+           pred_bbox (list): 模型在数据集上的预测框。默认值为None。
+           pred_mask (list): 模型在数据集上的预测mask。默认值为None。
+           save_dir (str): 可视化结果保存路径。默认值为'./output'。
+
+        Note:
+           eval_details_file的优先级更高,只要eval_details_file不为None,
+           就会从eval_details_file提取真值信息和预测结果做分析。
+           当eval_details_file为None时,则用gt、pred_mask、pred_mask做分析。
+
     """
 
     from multiprocessing import Pool