Bläddra i källkod

add more comments for meter reader case

FlyingQianMM 4 år sedan
förälder
incheckning
27d4e58669

+ 6 - 0
examples/meter_reader/deploy/cpp/meter_reader/global.cpp

@@ -23,10 +23,16 @@
 
 #include "meter_reader/global.h"
 
+// The size of inputting images of the detector
 std::vector<int> IMAGE_SHAPE = {1920, 1080};
+// The size of visualized prediction
 std::vector<int> RESULT_SHAPE = {1280, 720};
+// The size of inputting images of the segmenter,
+// also the size of circular meters
 std::vector<int> METER_SHAPE = {512, 512};
 
+// The configuration information of a meter,
+// composed of scale value, range, unit
 #define METER_TYPE_NUM 2
 MeterConfig_T meter_config[METER_TYPE_NUM] = {
 {25.0f/50.0f, 25.0f,  "(MPa)"},

+ 9 - 1
examples/meter_reader/deploy/cpp/meter_reader/global.h

@@ -16,15 +16,23 @@
 
 #include <vector>
 
+// The configuration information of a meter, composed of scale value,
+// range, unit.
 typedef struct MeterConfig {
   float scale_value;
   float range;
   char  str[10];
 } MeterConfig_T;
 
+// The size of inputting images of the detector
 extern std::vector<int> IMAGE_SHAPE;
+// The size of visualized prediction
 extern std::vector<int> RESULT_SHAPE;
+// The size of inputting images of the segmenter,
+// also the size of circular meters.
 extern std::vector<int> METER_SHAPE;
 extern MeterConfig_T meter_config[];
-
+// The type of a meter is estimated by a threshold. If the number of scales
+// in a meter is greater than or equal to the threshold, the meter is
+// belong to the former type. Otherwize, the latter.
 #define TYPE_THRESHOLD 40

+ 10 - 0
examples/meter_reader/deploy/cpp/meter_reader/meter_reader.cpp

@@ -59,9 +59,11 @@ void predict(const cv::Mat &input_image, PaddleX::Model *det_model,
              const std::string image_path, const bool use_erode,
              const int erode_kernel, const int thread_num,
              const int seg_batch_size, const double threshold) {
+  // Get detection results
   PaddleX::DetResult det_result;
   det_model->predict(input_image, &det_result);
 
+  // Filter bbox whose score is lower than score_threshold
   PaddleX::DetResult filter_result;
   int num_bboxes = det_result.boxes.size();
   for (int i = 0; i < num_bboxes; ++i) {
@@ -90,6 +92,7 @@ void predict(const cv::Mat &input_image, PaddleX::Model *det_model,
     int batch_thread_num = std::min(thread_num, im_vec_size - i);
     #pragma omp parallel for num_threads(batch_thread_num)
     for (int j = i; j < im_vec_size; ++j) {
+      // Crop the bbox area
       int left = static_cast<int>(filter_result.boxes[j].coordinate[0]);
       int top = static_cast<int>(filter_result.boxes[j].coordinate[1]);
       int width = static_cast<int>(filter_result.boxes[j].coordinate[2]);
@@ -99,6 +102,7 @@ void predict(const cv::Mat &input_image, PaddleX::Model *det_model,
 
       cv::Mat sub_image = input_image(
         cv::Range(top, bottom + 1), cv::Range(left, right + 1));
+      // Resize the image with shape (METER_SHAPE, METER_SHAPE)
       float scale_x =
         static_cast<float>(METER_SHAPE[0]) / static_cast<float>(sub_image.cols);
       float scale_y =
@@ -111,10 +115,12 @@ void predict(const cv::Mat &input_image, PaddleX::Model *det_model,
                  cv::INTER_LINEAR);
       meters_image[j - i] = std::move(sub_image);
     }
+    // Segment scales and point in each meter area
     std::vector<PaddleX::SegResult> batch_result(im_vec_size - i);
     seg_model->predict(meters_image, &batch_result, batch_thread_num);
     #pragma omp parallel for num_threads(batch_thread_num)
     for (int j = i; j < im_vec_size; ++j) {
+      // Do image erosion for the predicted label map of each meter
       if (use_erode) {
         cv::Mat kernel(4, 4, CV_8U, cv::Scalar(1));
         std::vector<uint8_t> label_map(
@@ -144,10 +150,13 @@ void predict(const cv::Mat &input_image, PaddleX::Model *det_model,
 
   std::vector<READ_RESULT> read_results(meter_num);
   int all_thread_num = std::min(thread_num, meter_num);
+  // The postprocess are done to get the point location relative to the scales
   read_process(seg_result, &read_results, all_thread_num);
 
   cv::Mat output_image = input_image.clone();
   for (int i = 0; i < meter_num; i++) {
+    // Provide a digital readout according to point location relative
+    // to the scales
     float result = 0;;
     if (read_results[i].scale_num > TYPE_THRESHOLD) {
       result = read_results[i].scales * meter_config[0].scale_value;
@@ -158,6 +167,7 @@ void predict(const cv::Mat &input_image, PaddleX::Model *det_model,
               << " -- result: " << result
               << " --" << std::endl;
 
+    // Visualize the results
     int lx = static_cast<int>(filter_result.boxes[i].coordinate[0]);
     int ly = static_cast<int>(filter_result.boxes[i].coordinate[1]);
     int w = static_cast<int>(filter_result.boxes[i].coordinate[2]);

+ 25 - 3
examples/meter_reader/deploy/cpp/meter_reader/postprocess.cpp

@@ -29,12 +29,21 @@
 
 using namespace std::chrono;  // NOLINT
 
+// The size of inputting images (SEG_IMAGE_SIZE x SEG_IMAGE_SIZE) of
+// the segmenter.
 #define SEG_IMAGE_SIZE 512
+// During the postprocess phase, annulus formed by the radius from
+// 130 to 250 of a circular meter will be converted to a rectangle.
+// So the height of the rectangle is 120.
 #define LINE_HEIGHT 120
+// The width of the rectangle is 1570, that is to say the perimeter
+// of a circular meter.
 #define LINE_WIDTH 1570
+// Radius of a circular meter
 #define CIRCLE_RADIUS 250
 
 const float pi = 3.1415926536f;
+// Center of a circular meter
 const int circle_center[] = {256, 256};
 
 
@@ -45,14 +54,17 @@ void creat_line_image(const std::vector<int64_t> &seg_image,
   int image_x;
   int image_y;
 
+  // The minimum scale value is at the bottom left, the maximum scale value
+  // is at the bottom right, so the vertical down axis is the starting axis and
+  // rotates around the meter ceneter counterclockwise.
   for (int row = 0; row < LINE_HEIGHT; row++) {
     for (int col = 0; col < LINE_WIDTH; col++) {
       theta = pi * 2 / LINE_WIDTH * (col + 1);
       rho = CIRCLE_RADIUS - row - 1;
-      image_x = static_cast<int>(circle_center[0] + rho * cos(theta) + 0.5);
-      image_y = static_cast<int>(circle_center[1] - rho * sin(theta) + 0.5);
+      image_y = static_cast<int>(circle_center[0] + rho * cos(theta) + 0.5);
+      image_x = static_cast<int>(circle_center[1] - rho * sin(theta) + 0.5);
       (*output)[row * LINE_WIDTH + col] =
-        seg_image[image_x * SEG_IMAGE_SIZE + image_y];
+        seg_image[image_y * SEG_IMAGE_SIZE + image_x];
     }
   }
 
@@ -62,6 +74,8 @@ void creat_line_image(const std::vector<int64_t> &seg_image,
 void convert_1D_data(const std::vector<unsigned char> &line_image,
                      std::vector<unsigned int> *scale_data,
                      std::vector<unsigned int> *pointer_data) {
+  // Accumulte the number of positions whose label is 1 along the height axis.
+  // Accumulte the number of positions whose label is 2 along the height axis.
   for (int col = 0; col < LINE_WIDTH; col++) {
     (*scale_data)[col] = 0;
     (*pointer_data)[col] = 0;
@@ -172,15 +186,23 @@ void read_process(const std::vector<std::vector<int64_t>> &seg_image,
     int read_num = seg_image.size();
     #pragma omp parallel for num_threads(thread_num)
     for (int i_read = 0; i_read < read_num; i_read++) {
+        // Convert the circular meter into a rectangular meter
         std::vector<unsigned char> line_result(LINE_WIDTH*LINE_HEIGHT, 0);
         creat_line_image(seg_image[i_read], &line_result);
 
+        // Get two one-dimension data where 0 represents background and
+        // >0 represents a scale or a pointer
         std::vector<unsigned int> scale_data(LINE_WIDTH);
         std::vector<unsigned int> pointer_data(LINE_WIDTH);
         convert_1D_data(line_result, &scale_data, &pointer_data);
+        // Fliter scale data whose value is lower than the mean value
         std::vector<unsigned int> scale_mean_data(LINE_WIDTH);
         scale_mean_filtration(scale_data, &scale_mean_data);
 
+        // Get the number of scales,the pointer location relative to the
+        // scales, the ratio between the distance from the pointer to the
+        // starting scale and distance from the ending scale to the
+        // starting scale.
         READ_RESULT result;
         get_meter_reader(scale_mean_data, pointer_data, &result);
 

+ 4 - 0
examples/meter_reader/deploy/cpp/meter_reader/postprocess.h

@@ -18,8 +18,12 @@
 #include <vector>
 
 struct READ_RESULT {
+  // the number of scales
   int scale_num;
+  // the pointer location relative to the scales
   float scales;
+  // the ratio between from the pointer to the starting scale and
+  // distance from the ending scale to the starting scale
   float ratio;
 };
 

+ 105 - 9
examples/meter_reader/deploy/python/reader_deploy.py

@@ -23,13 +23,25 @@ import argparse
 from paddlex.seg import transforms
 import paddlex as pdx
 
+# The size of inputting images (METER_SHAPE x METER_SHAPE) of the segmenter,
+# also the size of circular meters.
 METER_SHAPE = 512
+# Center of a circular meter
 CIRCLE_CENTER = [256, 256]
+# Radius of a circular meter
 CIRCLE_RADIUS = 250
 PI = 3.1415926536
+# During the postprocess phase, annulus formed by the radius from
+# 130 to 250 of a circular meter will be converted to a rectangle.
+# So the height of the rectangle is 120.
 LINE_HEIGHT = 120
+# The width of the rectangle is 1570, that is to say the perimeter of a circular meter
 LINE_WIDTH = 1570
+# The type of a meter is estimated by a threshold. If the number of scales in a meter is
+# greater than or equal to the threshold, the meter is belong to the former type.
+# Otherwize, the latter.
 TYPE_THRESHOLD = 40
+# The configuration information of a meter, composed of scale value, range, unit.
 METER_CONFIG = [{
     'scale_value': 25.0 / 50.0,
     'range': 25.0,
@@ -118,6 +130,14 @@ def is_pic(img_name):
 
 
 class MeterReader:
+    """Find the meters in images and provide a digital readout of each meter.
+
+    Args:
+        detector_dir(str): directory of the detector.
+        segmenter_dir(str): directory of the segmenter.
+
+    """
+
     def __init__(self, detector_dir, segmenter_dir):
         if not osp.exists(detector_dir):
             raise Exception("Model path {} does not exist".format(
@@ -138,6 +158,20 @@ class MeterReader:
                 erode_kernel=4,
                 score_threshold=0.5,
                 seg_batch_size=2):
+        """Detect meters in a image, segment scales and points in these meters, the postprocess are
+        done to provide a digital readout according to scale and point location.
+
+        Args:
+            im_file (str):  the path of a image to be predicted.
+            save_dir (str): the directory to save the visual prediction. Default: './'.
+            use_erode (bool, optional): whether to do image erosion by using a specific structuring element for
+                the label map output from the segmenter. Default: True.
+            erode_kernel (int, optional): structuring element used for erosion. Default: 4.
+            score_threshold (float, optional): detected meters whose scores are not lower than `score_threshold`
+                will be fed into the following segmenter. Default: 0.5.
+            seg_batch_size (int, optional): batch size of meters when do segmentation. Default: 2.
+
+        """
         if isinstance(im_file, str):
             im = cv2.imread(im_file).astype('float32')
         else:
@@ -181,9 +215,10 @@ class MeterReader:
             meter_images = list()
             for j in range(i, im_size):
                 meter_images.append(resized_meters[j - i])
+            # Segment scales and point in each meter area
             result = self.segmenter.batch_predict(
-                transforms=self.seg_transforms,
-                img_file_list=meter_images)
+                transforms=self.seg_transforms, img_file_list=meter_images)
+            # Do image erosion for the predicted label map of each meter
             if use_erode:
                 kernel = np.ones((erode_kernel, erode_kernel), np.uint8)
                 for i in range(len(result)):
@@ -192,10 +227,12 @@ class MeterReader:
             seg_results.extend(result)
 
         results = list()
+        # The postprocess are done to get the point location relative to the scales
         for i, seg_result in enumerate(seg_results):
             result = self.read_process(seg_result['label_map'])
             results.append(result)
 
+        # Provide a digital readout according to point location relative to the scales
         meter_values = list()
         for i, result in enumerate(results):
             if result['scale_num'] > TYPE_THRESHOLD:
@@ -205,7 +242,7 @@ class MeterReader:
             meter_values.append(value)
             print("-- Meter {} -- result: {} --\n".format(i, value))
 
-        # visualize the results
+        # Visualize the results
         visual_results = list()
         for i, res in enumerate(filtered_results):
             # Use `score` to represent the meter value
@@ -214,30 +251,68 @@ class MeterReader:
         pdx.det.visualize(im_file, visual_results, -1, save_dir=save_dir)
 
     def read_process(self, label_maps):
-        # Convert the circular meter into rectangular meter
+        """Get the pointer location relative to the scales.
+
+        Args:
+            label_maps (np.array): the label map output from a segmeter for a meter.
+
+        """
+        # Convert the circular meter into a rectangular meter
         line_images = self.creat_line_image(label_maps)
-        # Convert the 2d meter into 1d meter
+        # Get two one-dimension data where 0 represents background and >0 represents
+        # a scale or a pointer
         scale_data, pointer_data = self.convert_1d_data(line_images)
         # Fliter scale data whose value is lower than the mean value
         self.scale_mean_filtration(scale_data)
-        # Get scale_num, scales and ratio of meters
+        # Get the number of scales,the pointer location relative to the scales, the ratio between
+        # the distance from the pointer to the starting scale and distance from the ending scale to the
+        # starting scale.
         result = self.get_meter_reader(scale_data, pointer_data)
         return result
 
     def creat_line_image(self, meter_image):
+        """Convert the circular meter into a rectangular meter.
+
+        The minimum scale value is at the bottom left, the maximum scale value
+        is at the bottom right, so the vertical down axis is the starting axis and
+        rotates around the meter ceneter counterclockwise.
+
+        Args:
+            meter_image (np.array): the label map output from a segmeter for a meter.
+
+        Returns:
+            line_image (np.array): a rectangular meter.
+        """
+
         line_image = np.zeros((LINE_HEIGHT, LINE_WIDTH), dtype=np.uint8)
         for row in range(LINE_HEIGHT):
             for col in range(LINE_WIDTH):
                 theta = PI * 2 / LINE_WIDTH * (col + 1)
                 rho = CIRCLE_RADIUS - row - 1
-                x = int(CIRCLE_CENTER[0] + rho * math.cos(theta) + 0.5)
-                y = int(CIRCLE_CENTER[1] - rho * math.sin(theta) + 0.5)
-                line_image[row, col] = meter_image[x, y]
+                y = int(CIRCLE_CENTER[0] + rho * math.cos(theta) + 0.5)
+                x = int(CIRCLE_CENTER[1] - rho * math.sin(theta) + 0.5)
+                line_image[row, col] = meter_image[y, x]
         return line_image
 
     def convert_1d_data(self, meter_image):
+        """Get two one-dimension data where 0 represents background and >0 represents
+           a scale or a pointer from the rectangular meter.
+
+        Args:
+            meter_image (np.array): the two-dimension rectangular meter output
+                from function creat_line_image().
+
+        Returns:
+            scale_data (np.array): a one-dimension data where 0 represents background and
+                >0 represents scales.
+            pointer_data (np.array): a one-dimension data where 0 represents background and
+                >0 represents a pointer.
+        """
+
         scale_data = np.zeros((LINE_WIDTH), dtype=np.uint8)
         pointer_data = np.zeros((LINE_WIDTH), dtype=np.uint8)
+        # Accumulte the number of positions whose label is 1 along the height axis.
+        # Accumulte the number of positions whose label is 2 along the height axis.
         for col in range(LINE_WIDTH):
             for row in range(LINE_HEIGHT):
                 if meter_image[row, col] == 1:
@@ -247,12 +322,33 @@ class MeterReader:
         return scale_data, pointer_data
 
     def scale_mean_filtration(self, scale_data):
+        """Set the element in the scale data which is lower than its mean value to 0.
+
+        Args:
+            scale_data (np.array): the scale data output from function convert_1d_data().
+        """
         mean_data = np.mean(scale_data)
         for col in range(LINE_WIDTH):
             if scale_data[col] < mean_data:
                 scale_data[col] = 0
 
     def get_meter_reader(self, scale_data, pointer_data):
+        """Calculate the number of scales,the pointer location relative to the scales, the ratio between
+        the distance from the pointer to the starting scale and distance from the ending scale to the
+        starting scale.
+
+        Args:
+            scale_data (np.array): a scale data output from function scale_mean_filtration().
+            pointer_data (np.array): a pointer data output from function convert_1d_data().
+
+        Returns:
+            Dict (keys: 'scale_num', 'scales', 'ratio'):
+                The value of key 'scale_num' (int): the number of scales;
+                The value of 'scales' (float): the pointer location relative to the scales;
+                the value of 'ratio' (float): the ratio between from the pointer to the starting scale and
+                distance from the ending scale to the starting scale.
+
+        """
         scale_flag = False
         pointer_flag = False
         one_scale_start = 0

+ 105 - 9
examples/meter_reader/reader_infer.py

@@ -23,13 +23,25 @@ import argparse
 from paddlex.seg import transforms
 import paddlex as pdx
 
+# The size of inputting images (METER_SHAPE x METER_SHAPE) of the segmenter,
+# also the size of circular meters.
 METER_SHAPE = 512
+# Center of a circular meter
 CIRCLE_CENTER = [256, 256]
+# Radius of a circular meter
 CIRCLE_RADIUS = 250
 PI = 3.1415926536
+# During the postprocess phase, annulus formed by the radius from
+# 130 to 250 of a circular meter will be converted to a rectangle.
+# So the height of the rectangle is 120.
 LINE_HEIGHT = 120
+# The width of the rectangle is 1570, that is to say the perimeter of a circular meter
 LINE_WIDTH = 1570
+# The type of a meter is estimated by a threshold. If the number of scales in a meter is
+# greater than or equal to the threshold, the meter is belong to the former type.
+# Otherwize, the latter.
 TYPE_THRESHOLD = 40
+# The configuration information of a meter, composed of scale value, range, unit.
 METER_CONFIG = [{
     'scale_value': 25.0 / 50.0,
     'range': 25.0,
@@ -118,6 +130,14 @@ def is_pic(img_name):
 
 
 class MeterReader:
+    """Find the meters in images and provide a digital readout of each meter.
+
+    Args:
+        detector_dir(str): directory of the detector.
+        segmenter_dir(str): directory of the segmenter.
+
+    """
+
     def __init__(self, detector_dir, segmenter_dir):
         if not osp.exists(detector_dir):
             raise Exception("Model path {} does not exist".format(
@@ -138,6 +158,20 @@ class MeterReader:
                 erode_kernel=4,
                 score_threshold=0.5,
                 seg_batch_size=2):
+        """Detect meters in a image, segment scales and points in these meters, the postprocess are
+        done to provide a digital readout according to scale and point location.
+
+        Args:
+            im_file (str):  the path of a image to be predicted.
+            save_dir (str): the directory to save the visual prediction. Default: './'.
+            use_erode (bool, optional): whether to do image erosion by using a specific structuring element for
+                the label map output from the segmenter. Default: True.
+            erode_kernel (int, optional): structuring element used for erosion. Default: 4.
+            score_threshold (float, optional): detected meters whose scores are not lower than `score_threshold`
+                will be fed into the following segmenter. Default: 0.5.
+            seg_batch_size (int, optional): batch size of meters when do segmentation. Default: 2.
+
+        """
         if isinstance(im_file, str):
             im = cv2.imread(im_file).astype('float32')
         else:
@@ -181,9 +215,10 @@ class MeterReader:
             meter_images = list()
             for j in range(i, im_size):
                 meter_images.append(resized_meters[j - i])
+            # Segment scales and point in each meter area
             result = self.segmenter.batch_predict(
-                transforms=self.seg_transforms,
-                img_file_list=meter_images)
+                transforms=self.seg_transforms, img_file_list=meter_images)
+            # Do image erosion for the predicted label map of each meter
             if use_erode:
                 kernel = np.ones((erode_kernel, erode_kernel), np.uint8)
                 for i in range(len(result)):
@@ -192,10 +227,12 @@ class MeterReader:
             seg_results.extend(result)
 
         results = list()
+        # The postprocess are done to get the point location relative to the scales
         for i, seg_result in enumerate(seg_results):
             result = self.read_process(seg_result['label_map'])
             results.append(result)
 
+        # Provide a digital readout according to point location relative to the scales
         meter_values = list()
         for i, result in enumerate(results):
             if result['scale_num'] > TYPE_THRESHOLD:
@@ -205,7 +242,7 @@ class MeterReader:
             meter_values.append(value)
             print("-- Meter {} -- result: {} --\n".format(i, value))
 
-        # visualize the results
+        # Visualize the results
         visual_results = list()
         for i, res in enumerate(filtered_results):
             # Use `score` to represent the meter value
@@ -214,30 +251,68 @@ class MeterReader:
         pdx.det.visualize(im_file, visual_results, -1, save_dir=save_dir)
 
     def read_process(self, label_maps):
-        # Convert the circular meter into rectangular meter
+        """Get the pointer location relative to the scales.
+
+        Args:
+            label_maps (np.array): the label map output from a segmeter for a meter.
+
+        """
+        # Convert the circular meter into a rectangular meter
         line_images = self.creat_line_image(label_maps)
-        # Convert the 2d meter into 1d meter
+        # Get two one-dimension data where 0 represents background and >0 represents
+        # a scale or a pointer
         scale_data, pointer_data = self.convert_1d_data(line_images)
         # Fliter scale data whose value is lower than the mean value
         self.scale_mean_filtration(scale_data)
-        # Get scale_num, scales and ratio of meters
+        # Get the number of scales,the pointer location relative to the scales, the ratio between
+        # the distance from the pointer to the starting scale and distance from the ending scale to the
+        # starting scale.
         result = self.get_meter_reader(scale_data, pointer_data)
         return result
 
     def creat_line_image(self, meter_image):
+        """Convert the circular meter into a rectangular meter.
+
+        The minimum scale value is at the bottom left, the maximum scale value
+        is at the bottom right, so the vertical down axis is the starting axis and
+        rotates around the meter ceneter counterclockwise.
+
+        Args:
+            meter_image (np.array): the label map output from a segmeter for a meter.
+
+        Returns:
+            line_image (np.array): a rectangular meter.
+        """
+
         line_image = np.zeros((LINE_HEIGHT, LINE_WIDTH), dtype=np.uint8)
         for row in range(LINE_HEIGHT):
             for col in range(LINE_WIDTH):
                 theta = PI * 2 / LINE_WIDTH * (col + 1)
                 rho = CIRCLE_RADIUS - row - 1
-                x = int(CIRCLE_CENTER[0] + rho * math.cos(theta) + 0.5)
-                y = int(CIRCLE_CENTER[1] - rho * math.sin(theta) + 0.5)
-                line_image[row, col] = meter_image[x, y]
+                y = int(CIRCLE_CENTER[0] + rho * math.cos(theta) + 0.5)
+                x = int(CIRCLE_CENTER[1] - rho * math.sin(theta) + 0.5)
+                line_image[row, col] = meter_image[y, x]
         return line_image
 
     def convert_1d_data(self, meter_image):
+        """Get two one-dimension data where 0 represents background and >0 represents
+           a scale or a pointer from the rectangular meter.
+
+        Args:
+            meter_image (np.array): the two-dimension rectangular meter output
+                from function creat_line_image().
+
+        Returns:
+            scale_data (np.array): a one-dimension data where 0 represents background and
+                >0 represents scales.
+            pointer_data (np.array): a one-dimension data where 0 represents background and
+                >0 represents a pointer.
+        """
+
         scale_data = np.zeros((LINE_WIDTH), dtype=np.uint8)
         pointer_data = np.zeros((LINE_WIDTH), dtype=np.uint8)
+        # Accumulte the number of positions whose label is 1 along the height axis.
+        # Accumulte the number of positions whose label is 2 along the height axis.
         for col in range(LINE_WIDTH):
             for row in range(LINE_HEIGHT):
                 if meter_image[row, col] == 1:
@@ -247,12 +322,33 @@ class MeterReader:
         return scale_data, pointer_data
 
     def scale_mean_filtration(self, scale_data):
+        """Set the element in the scale data which is lower than its mean value to 0.
+
+        Args:
+            scale_data (np.array): the scale data output from function convert_1d_data().
+        """
         mean_data = np.mean(scale_data)
         for col in range(LINE_WIDTH):
             if scale_data[col] < mean_data:
                 scale_data[col] = 0
 
     def get_meter_reader(self, scale_data, pointer_data):
+        """Calculate the number of scales,the pointer location relative to the scales, the ratio between
+        the distance from the pointer to the starting scale and distance from the ending scale to the
+        starting scale.
+
+        Args:
+            scale_data (np.array): a scale data output from function scale_mean_filtration().
+            pointer_data (np.array): a pointer data output from function convert_1d_data().
+
+        Returns:
+            Dict (keys: 'scale_num', 'scales', 'ratio'):
+                The value of key 'scale_num' (int): the number of scales;
+                The value of 'scales' (float): the pointer location relative to the scales;
+                the value of 'ratio' (float): the ratio between from the pointer to the starting scale and
+                distance from the ending scale to the starting scale.
+
+        """
         scale_flag = False
         pointer_flag = False
         one_scale_start = 0