浏览代码

add analysis

FlyingQianMM 5 年之前
父节点
当前提交
1093eacda6

+ 1 - 0
docs/apis/index.rst

@@ -6,6 +6,7 @@ API接口说明
 
    transforms/index.rst
    datasets.md
+   analysis.md
    models/index.rst
    slim.md
    visualize.md

+ 17 - 4
docs/apis/transforms/seg_transforms.md

@@ -78,16 +78,19 @@ paddlex.seg.transforms.ResizeStepScaling(min_scale_factor=0.75, max_scale_factor
 
 ## Normalize
 ```python
-paddlex.seg.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
+paddlex.seg.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0])
 ```
 对图像进行标准化。
 
-1.图像像素归一化到区间 [0.0, 1.0]。
-2.对图像进行减均值除以标准差操作。
+1.像素值减去min_val
+2.像素值除以(max_val-min_val), 归一化到区间 [0.0, 1.0]。
+3.对图像进行减均值除以标准差操作。
+
 ### 参数
 * **mean** (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。
 * **std** (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。
-
+* **min_val** (list): 图像数据集的最小值。默认值[0, 0, 0]。
+* **max_val** (list): 图像数据集的最大值。默认值[255.0, 255.0, 255.0]。
 
 ## Padding
 ```python
@@ -167,6 +170,16 @@ paddlex.seg.transforms.RandomDistort(brightness_range=0.5, brightness_prob=0.5,
 * **hue_range** (int): 色调因子的范围。默认为18。
 * **hue_prob** (float): 随机调整色调的概率。默认为0.5。
 
+## Clip
+```python
+paddlex.seg.transforms.Clip(min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0])
+```
+对图像上超出一定范围的数据进行截断。
+
+### 参数
+* **min_val** (list): 裁剪的下限,小于min_val的数值均设为min_val. 默认值0。
+* **max_val** (list): 裁剪的上限,大于max_val的数值均设为max_val. 默认值255.0。
+
 <!--
 ## ComposedSegTransforms
 ```python

+ 2 - 1
docs/apis/visualize.md

@@ -27,7 +27,7 @@ pdx.det.visualize('./xiaoduxiong_epoch_12/xiaoduxiong.jpeg', result, save_dir='.
 ## paddlex.seg.visualize
 > **语义分割模型预测结果可视化**  
 ```
-paddlex.seg.visualize(image, result, weight=0.6, save_dir='./')
+paddlex.seg.visualize(image, result, weight=0.6, save_dir='./', color=None)
 ```
 将语义分割模型预测得到的Mask在原图上进行可视化。
 
@@ -36,6 +36,7 @@ paddlex.seg.visualize(image, result, weight=0.6, save_dir='./')
 > * **result** (str): 模型预测结果。
 > * **weight**(float): mask可视化结果与原图权重因子,weight表示原图的权重。默认0.6。
 > * **save_dir**(str): 可视化结果保存路径。若为None,则表示不保存,该函数将可视化的结果以np.ndarray的形式返回;若设为目录路径,则将可视化结果保存至该目录下。默认值为'./'。
+> * **color** (list): 各类别的BGR颜色值组成的列表。例如两类时可设置为[255, 255, 255, 0, 0, 0]。默认值为None,则使用默认生成的颜色列表。
 
 ### 使用示例
 > 点击下载如下示例中的[模型](https://bj.bcebos.com/paddlex/models/cityscape_deeplab.tar.gz)和[测试图片](https://bj.bcebos.com/paddlex/datasets/city.png)

+ 1 - 0
docs/examples/index.rst

@@ -12,3 +12,4 @@ PaddleX精选飞桨视觉开发套件在产业实践中的成熟模型结构,
    solutions.md
    meter_reader.md
    human_segmentation.md
+   multi-channel_remote_sensing/README.md

+ 125 - 0
docs/examples/multi-channel_remote_sensing/README.md

@@ -0,0 +1,125 @@
+# 多通道遥感影像分割
+遥感影像分割是图像分割领域中的重要应用场景,广泛应用于土地测绘、环境监测、城市建设等领域。遥感影像分割的目标多种多样,有诸如积雪、农作物、道路、建筑、水源等地物目标,也有例如云层的空中目标。
+
+本案例基于PaddleX实现多通道遥感影像分割,涵盖数据分析、模型训练、模型预测等流程,旨在帮助用户利用深度学习技术解决多通道遥感影像分割问题。
+
+
+## 前置依赖
+* Paddle paddle >= 1.8.4
+* Python >= 3.5
+* PaddleX >= 1.1.0
+
+安装的相关问题参考[PaddleX安装](../../install.md)
+
+**另外还需安装gdal**, 使用pip安装gdal可能出错,推荐使用conda进行安装:
+
+```
+conda install gdal
+```
+
+下载PaddleX源码:  
+
+```  
+git clone https://github.com/PaddlePaddle/PaddleX
+```
+
+该案例所有脚本均位于`PaddleX/examples/channel_remote_sensing/`,进入该目录:  
+
+```
+cd PaddleX/examples/channel_remote_sensing/  
+```
+
+## 数据准备
+遥感影像的格式多种多样,不同传感器产生的数据格式也可能不同。PaddleX现已兼容以下4种格式图片读取:
+
+- `tif`
+- `png`
+- `img`
+- `npy`
+
+标注图要求必须为单通道的png格式图像,像素值即为对应的类别,像素标注类别需要从0开始递增。例如0,1,2,3表示有4种类别,255用于指定不参与训练和评估的像素,标注类别最多为256类。
+
+本案例使用[L8 SPARCS公开数据集](https://www.usgs.gov/land-resources/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs-validation)进行云雪分割,该数据集包含80张卫星影像,涵盖10个波段。原始标注图片包含7个类别,分别是`cloud`, `cloud shadow`, `shadow over water`, `snow/ice`, `water`, `land`和`flooded`。由于`flooded`和`shadow over water`2个类别占比仅为`1.8%`和`0.24%`,我们将其进行合并,`flooded`归为`land`,`shadow over water`归为`shadow`,合并后标注包含5个类别。
+
+数值、类别、颜色对应表:
+
+|Pixel value|Class|Color|
+|---|---|---|
+|0|cloud|white|
+|1|shadow|black|
+|2|snow/ice|cyan|
+|3|water|blue|
+|4|land|grey|
+
+ ![](../../../examples/multi-channel_remote_sensing/docs/images/dataset.png)
+
+
+执行以下命令下载并解压经过类别合并后的数据集:
+```shell script
+mkdir dataset && cd dataset
+wget https://paddleseg.bj.bcebos.com/dataset/remote_sensing_seg.zip
+unzip remote_sensing_seg.zip
+cd ..
+```
+其中`data`目录存放遥感影像,`data_vis`目录存放彩色合成预览图,`mask`目录存放标注图。
+
+## 数据分析  
+遥感影像往往由许多波段组成,不同波段数据分布可能大相径庭,例如可见光波段和热红外波段分布十分不同。为了更深入了解数据的分布来优化模型训练效果,需要对数据进行分析。
+
+参考文档[数据分析](./analysis.md)对训练集进行统计分析,确定图像像素值的截断范围,并统计截断后的均值和方差。
+
+## 模型训练
+本案例选择`UNet`语义分割模型完成云雪分割,运行以下步骤完成模型训练,模型的最优精度`miou`为`77.99%`。
+
+* 设置GPU卡号
+```shell script
+export CUDA_VISIBLE_DEVICES=0
+```
+
+* 运行以下脚本开始训练
+```shell script
+python train.py --data_dir dataset/remote_sensing_seg \
+--train_file_list dataset/remote_sensing_seg/train.txt \
+--eval_file_list dataset/remote_sensing_seg/val.txt \
+--label_list dataset/remote_sensing_seg/labels.txt \
+--save_dir saved_model/remote_sensing_unet \
+--num_classes 5 \
+--channel 10 \
+--lr 0.01 \
+--clip_min_value 7172 6561 5777 5103 4291 4000 4000 4232 6934 7199 \
+--clip_max_value 50000 50000 50000 50000 50000 40000 30000 18000 40000 36000 \
+--mean 0.14311188522260637 0.14288498042151332 0.14812997807748615 0.16377211813814938 0.2737538363784552 0.2740934379398823 0.27749601919204 0.07767443032935262 0.5694699410349131 0.5549716085195542 \
+--std 0.09101632762467489 0.09600705942721106 0.096193618606776 0.10371446736389771 0.10911951586604118 0.11043593115173281 0.12648042598739268 0.027746262217260665 0.06822348076384514 0.062377591186668725 \
+--num_epochs 500 \
+--train_batch_size 3
+```
+
+也可以跳过模型训练步骤,下载预训练模型直接进行模型预测:
+
+```
+wget https://bj.bcebos.com/paddlex/examples/multi-channel_remote_sensing/models/l8sparcs_remote_model.tar.gz
+tar -xvf l8sparcs_remote_model.tar.gz
+```
+
+## 模型预测
+运行以下脚本,对遥感图像进行预测并可视化预测结果,相应地也将对应的标注文件进行可视化,以比较预测效果。
+
+```shell script
+export CUDA_VISIBLE_DEVICES=0
+python predict.py
+```
+可视化效果如下所示:
+
+
+![](../../../examples/multi-channel_remote_sensing/docs/images/prediction.jpg)
+
+
+数值、类别、颜色对应表:
+
+|Pixel value|Class|Color|
+|---|---|---|
+|0|cloud|white|
+|1|shadow|black|
+|2|snow/ice|cyan|
+|3|water|blue|
+|4|land|grey|

+ 140 - 0
examples/multi-channel_remote_sensing/README.md

@@ -0,0 +1,140 @@
+# 多通道遥感影像分割
+遥感影像分割是图像分割领域中的重要应用场景,广泛应用于土地测绘、环境监测、城市建设等领域。遥感影像分割的目标多种多样,有诸如积雪、农作物、道路、建筑、水源等地物目标,也有例如云层的空中目标。
+
+本案例基于PaddleX实现多通道遥感影像分割,涵盖数据分析、模型训练、模型预测等流程,旨在帮助用户利用深度学习技术解决多通道遥感影像分割问题。
+
+## 目录
+* [前置依赖](#1)
+* [数据准备](#2)
+* [数据分析](#3)
+* [模型训练](#4)
+* [模型预测](#5)
+
+## <h2 id="1">前置依赖</h2>
+
+* Paddle paddle >= 1.8.4
+* Python >= 3.5
+* PaddleX >= 1.1.0
+
+安装的相关问题参考[PaddleX安装](../../docs/install.md)
+
+**另外还需安装gdal**, 使用pip安装gdal可能出错,推荐使用conda进行安装:
+
+```
+conda install gdal
+```
+
+下载PaddleX源码:  
+
+```  
+git clone https://github.com/PaddlePaddle/PaddleX
+```
+
+该案例所有脚本均位于`PaddleX/examples/channel_remote_sensing/`,进入该目录:  
+
+```
+cd PaddleX/examples/channel_remote_sensing/  
+```
+
+##  <h2 id="2">数据准备</h2>
+
+遥感影像的格式多种多样,不同传感器产生的数据格式也可能不同。PaddleX现已兼容以下4种格式图片读取:
+
+- `tif`
+- `png`
+- `img`
+- `npy`
+
+标注图要求必须为单通道的png格式图像,像素值即为对应的类别,像素标注类别需要从0开始递增。例如0,1,2,3表示有4种类别,255用于指定不参与训练和评估的像素,标注类别最多为256类。
+
+本案例使用[L8 SPARCS公开数据集](https://www.usgs.gov/land-resources/nli/landsat/spatial-procedures-automated-removal-cloud-and-shadow-sparcs-validation)进行云雪分割,该数据集包含80张卫星影像,涵盖10个波段。原始标注图片包含7个类别,分别是`cloud`, `cloud shadow`, `shadow over water`, `snow/ice`, `water`, `land`和`flooded`。由于`flooded`和`shadow over water`2个类别占比仅为`1.8%`和`0.24%`,我们将其进行合并,`flooded`归为`land`,`shadow over water`归为`shadow`,合并后标注包含5个类别。
+
+数值、类别、颜色对应表:
+
+|Pixel value|Class|Color|
+|---|---|---|
+|0|cloud|white|
+|1|shadow|black|
+|2|snow/ice|cyan|
+|3|water|blue|
+|4|land|grey|
+
+<p align="center">
+ <img src="./docs/images/dataset.png" align="middle"
+</p>
+
+<p align='center'>
+ L8 SPARCS数据集示例
+</p>
+
+执行以下命令下载并解压经过类别合并后的数据集:
+```shell script
+mkdir dataset && cd dataset
+wget https://paddleseg.bj.bcebos.com/dataset/remote_sensing_seg.zip
+unzip remote_sensing_seg.zip
+cd ..
+```
+其中`data`目录存放遥感影像,`data_vis`目录存放彩色合成预览图,`mask`目录存放标注图。
+
+## <h2 id="2">数据分析</h2>  
+
+遥感影像往往由许多波段组成,不同波段数据分布可能大相径庭,例如可见光波段和热红外波段分布十分不同。为了更深入了解数据的分布来优化模型训练效果,需要对数据进行分析。
+
+参考文档[数据分析](./docs/analysis.md)对训练集进行统计分析,确定图像像素值的截断范围,并统计截断后的均值和方差。
+
+## <h2 id="2">模型训练</h2>
+
+本案例选择`UNet`语义分割模型完成云雪分割,运行以下步骤完成模型训练,模型的最优精度`miou`为`77.99%`。
+
+* 设置GPU卡号
+```shell script
+export CUDA_VISIBLE_DEVICES=0
+```
+
+* 运行以下脚本开始训练
+```shell script
+python train.py --data_dir dataset/remote_sensing_seg \
+--train_file_list dataset/remote_sensing_seg/train.txt \
+--eval_file_list dataset/remote_sensing_seg/val.txt \
+--label_list dataset/remote_sensing_seg/labels.txt \
+--save_dir saved_model/remote_sensing_unet \
+--num_classes 5 \
+--channel 10 \
+--lr 0.01 \
+--clip_min_value 7172 6561 5777 5103 4291 4000 4000 4232 6934 7199 \
+--clip_max_value 50000 50000 50000 50000 50000 40000 30000 18000 40000 36000 \
+--mean 0.14311188522260637 0.14288498042151332 0.14812997807748615 0.16377211813814938 0.2737538363784552 0.2740934379398823 0.27749601919204 0.07767443032935262 0.5694699410349131 0.5549716085195542 \
+--std 0.09101632762467489 0.09600705942721106 0.096193618606776 0.10371446736389771 0.10911951586604118 0.11043593115173281 0.12648042598739268 0.027746262217260665 0.06822348076384514 0.062377591186668725 \
+--num_epochs 500 \
+--train_batch_size 3
+```
+
+也可以跳过模型训练步骤,下载预训练模型直接进行模型预测:
+
+```
+wget https://bj.bcebos.com/paddlex/examples/multi-channel_remote_sensing/models/l8sparcs_remote_model.tar.gz
+tar -xvf l8sparcs_remote_model.tar.gz
+```
+
+## <h2 id="2">模型预测</h2>
+运行以下脚本,对遥感图像进行预测并可视化预测结果,相应地也将对应的标注文件进行可视化,以比较预测效果。
+
+```shell script
+export CUDA_VISIBLE_DEVICES=0
+python predict.py
+```
+可视化效果如下所示:
+
+
+<img src="./docs/images/prediction.jpg" alt="预测图" align=center />
+
+
+数值、类别、颜色对应表:
+
+|Pixel value|Class|Color|
+|---|---|---|
+|0|cloud|white|
+|1|shadow|black|
+|2|snow/ice|cyan|
+|3|water|blue|
+|4|land|grey|

+ 125 - 0
examples/multi-channel_remote_sensing/docs/analysis.md

@@ -0,0 +1,125 @@
+# 数据分析
+
+遥感影像往往由许多波段组成,不同波段数据分布可能大相径庭,例如可见光波段和热红外波段分布十分不同。为了更深入了解数据的分布来优化模型训练效果,需要对数据进行分析。
+
+## 目录
+* [1. 统计分析](#1)
+* [2. 确定像素值截断范围](#2)
+* [3. 统计截断后的均值和方差](#3)
+
+## <h2 id="1">统计分析</h2>
+执行以下脚本,对训练集进行统计分析,屏幕会输出分析结果,同时结果也会保存至文件`train_information.pkl`中:
+
+```
+python tools/analysis.py
+```
+
+数据统计分析内容如下:
+
+* 图像数量
+
+例如统计出训练集中有64张图片:
+```
+64 samples in file dataset/remote_sensing_seg/train.txt
+```
+* 图像最大和最小的尺寸
+
+例如统计出训练集中最大的高宽和最小的高宽分别是(1000, 1000)和(1000, 1000):
+```
+Minimal image height: 1000 Minimal image width: 1000.
+Maximal image height: 1000 Maximal image width: 1000.
+```
+* 图像通道数量
+
+例如统计出图像的通道数量为10:
+
+```
+Image channel is 10.
+```
+* 图像各通道的最小值和最大值
+
+最小值和最大值分别以列表的形式输出,按照通道从小到大排列。例如:
+
+```
+Minimal image value: [7.172e+03 6.561e+03 5.777e+03 5.103e+03 4.291e+03 1.000e+00 1.000e+00 4.232e+03 6.934e+03 7.199e+03]
+Maximal image value: [65535. 65535. 65535. 65535. 65535. 65535. 65535. 56534. 65535. 63215.]
+
+```
+* 图像各通道的像素值分布
+
+针对各个通道,统计出各像素值的数量,并以柱状图的形式呈现在以'distribute.png'结尾的图片中。**需要注意的是,为便于观察,纵坐标为对数坐标**。用户可以查看这些图片来选择是否需要对分布在头部和尾部的像素值进行截断。
+
+```
+Image pixel distribution of each channel is saved with 'distribute.png' in the dataset/remote_sensing_seg
+```
+
+* 图像各通道归一化后的均值和方差
+
+各通道归一化系数为各通道最大值与最小值之差,均值和方差以列别形式输出,按照通道从小到大排列。例如:
+
+```
+Image mean value: [0.23417574 0.22283101 0.2119595  0.2119887  0.27910388 0.21294892 0.17294037 0.10158925 0.43623915 0.41019192]
+Image standard deviation: [0.06831269 0.07243951 0.07284761 0.07875261 0.08120818 0.0609302 0.05110716 0.00696064 0.03849307 0.03205579]
+```
+
+* 标注图中各类别的数量及比重
+
+统计各类别的像素数量和在数据集全部像素的占比,以(类别值,该类别的数量,该类别的占比)的格式输出。例如:
+
+```
+Label pixel information is shown in a format of (label_id, the number of label_id, the ratio of label_id):
+(0, 13302870, 0.20785734374999995)
+(1, 4577005, 0.07151570312500002)
+(2, 3955012, 0.0617970625)
+(3, 2814243, 0.04397254687499999)
+(4, 39350870, 0.6148573437500001)
+
+```
+
+## <h2 id="2">2 确定像素值截断范围</h2>
+
+遥感影像数据分布范围广,往往存在一些异常值,这会影响算法对实际数据分布的拟合效果。为更好地对数据进行归一化,可以抑制遥感影像中少量的异常值。根据`图像各通道的像素值分布`来确定像素值的截断范围,并在后续图像预处理过程中对超出范围的像素值通过截断进行校正,从而去除异常值带来的干扰。**注意:该步骤是否执行根据数据集实际分布来决定。**
+
+例如各通道的像素值分布可视化效果如下:
+
+<img src="./images/image_pixel_distribution.png" width = "600" height = "600" alt="像素值分布图" align=center />
+
+
+对于上述分布,我们选取的截断范围是(按照通道从小到大排列):
+
+```
+截断范围最小值: clip_min_value = [7172,  6561,  5777, 5103, 4291, 4000, 4000, 4232, 6934, 7199]
+截断范围最大值: clip_max_value = [50000, 50000, 50000, 50000, 50000, 40000, 30000, 18000, 40000, 36000]
+```
+
+## <h2 id="3">3 确定像素值截断范围</h2>
+
+为避免数据截断范围选取不当带来的影响,应该统计异常值像素占比,确保受影响的像素比例不要过高。接着对截断后的数据计算归一化后的均值和方差,**用于后续模型训练时的图像预处理参数设置**。
+
+执行以下脚本:
+```
+python tools/cal_clipped_mean_std.py
+```
+
+截断像素占比统计结果如下:
+
+```
+Channel 0, the ratio of pixels to be clipped = 0.00054778125
+Channel 1, the ratio of pixels to be clipped = 0.0011129375
+Channel 2, the ratio of pixels to be clipped = 0.000843703125
+Channel 3, the ratio of pixels to be clipped = 0.00127125
+Channel 4, the ratio of pixels to be clipped = 0.001330140625
+Channel 5, the ratio of pixels to be clipped = 8.1375e-05
+Channel 6, the ratio of pixels to be clipped = 0.0007348125
+Channel 7, the ratio of pixels to be clipped = 6.5625e-07
+Channel 8, the ratio of pixels to be clipped = 0.000185921875
+Channel 9, the ratio of pixels to be clipped = 0.000139671875
+```
+可看出,被截断像素占比均不超过0.2%。
+
+裁剪后数据的归一化系数如下:
+```
+Image mean value: [0.15163569 0.15142828 0.15574491 0.1716084  0.2799778  0.27652043 0.28195933 0.07853807 0.56333154 0.5477584 ]
+Image standard deviation: [0.09301891 0.09818967 0.09831126 0.1057784  0.10842132 0.11062996 0.12791838 0.02637859 0.0675052  0.06168227]
+(normalized by (clip_max_value - clip_min_value), arranged in 0-10 channel order)
+```

二进制
examples/multi-channel_remote_sensing/docs/images/dataset.png


二进制
examples/multi-channel_remote_sensing/docs/images/image_pixel_distribution.png


二进制
examples/multi-channel_remote_sensing/docs/images/prediction.jpg


+ 21 - 0
examples/multi-channel_remote_sensing/predict.py

@@ -0,0 +1,21 @@
+import numpy as np
+from PIL import Image
+
+import paddlex as pdx
+
+model_dir = "saved_model/remote_sensing_unet/best_model/"
+img_file = "dataset/remote_sensing_seg/data/LC80150242014146LGN00_23_data.tif"
+label_file = "dataset/remote_sensing_seg/mask/LC80150242014146LGN00_23_mask.png"
+color = [255, 255, 255, 0, 0, 0, 255, 255, 0, 255, 0, 0, 150, 150, 150]
+
+# 预测并可视化预测结果
+model = pdx.load_model(model_dir)
+pred = model.predict(img_file)
+pdx.seg.visualize(
+    img_file, pred, weight=0., save_dir='./output/pred', color=color)
+
+# 可视化标注文件
+label = np.asarray(Image.open(label_file))
+pred = {'label_map': label}
+pdx.seg.visualize(
+    img_file, pred, weight=0., save_dir='./output/gt', color=color)

+ 8 - 0
examples/multi-channel_remote_sensing/tools/analysis.py

@@ -0,0 +1,8 @@
+import paddlex as pdx
+
+train_analysis = pdx.datasets.analysis.Seg(
+    data_dir='dataset/remote_sensing_seg',
+    file_list='dataset/remote_sensing_seg/train.txt',
+    label_list='dataset/remote_sensing_seg/labels.txt')
+
+train_analysis.analysis()

+ 15 - 0
examples/multi-channel_remote_sensing/tools/cal_clipped_mean_std.py

@@ -0,0 +1,15 @@
+import paddlex as pdx
+
+clip_min_value = [7172, 6561, 5777, 5103, 4291, 4000, 4000, 4232, 6934, 7199]
+clip_max_value = [
+    50000, 50000, 50000, 50000, 50000, 40000, 30000, 18000, 40000, 36000
+]
+data_info_file = 'dataset/remote_sensing_seg/train_infomation.pkl'
+
+train_analysis = pdx.datasets.analysis.Seg(
+    data_dir='dataset/remote_sensing_seg',
+    file_list='dataset/remote_sensing_seg/train.txt',
+    label_list='dataset/remote_sensing_seg/labels.txt')
+
+train_analysis.cal_clipped_mean_std(clip_min_value, clip_max_value,
+                                    data_info_file)

+ 23 - 11
paddlex/remotesensing/train_demo.py → examples/multi-channel_remote_sensing/train.py

@@ -16,7 +16,6 @@
 import os.path as osp
 import argparse
 from paddlex.seg import transforms
-import paddlex.remotesensing.transforms as rs_transforms
 import paddlex as pdx
 
 
@@ -29,6 +28,24 @@ def parse_args():
         default=None,
         type=str)
     parser.add_argument(
+        '--train_file_list',
+        dest='train_file_list',
+        help='train file_list',
+        default=None,
+        type=str)
+    parser.add_argument(
+        '--eval_file_list',
+        dest='eval_file_list',
+        help='eval file_list',
+        default=None,
+        type=str)
+    parser.add_argument(
+        '--label_list',
+        dest='label_list',
+        help='label_list file',
+        default=None,
+        type=str)
+    parser.add_argument(
         '--save_dir',
         dest='save_dir',
         help='model save directory',
@@ -93,6 +110,9 @@ def parse_args():
 
 args = parse_args()
 data_dir = args.data_dir
+train_list = args.train_file_list
+val_list = args.eval_file_list
+label_list = args.label_list
 save_dir = args.save_dir
 num_classes = args.num_classes
 channel = args.channel
@@ -110,27 +130,19 @@ train_transforms = transforms.Compose([
     transforms.RandomHorizontalFlip(0.5),
     transforms.ResizeStepScaling(0.5, 2.0, 0.25),
     transforms.RandomPaddingCrop(im_padding_value=[1000] * channel),
-    rs_transforms.Clip(
+    transforms.Clip(
         min_val=clip_min_value, max_val=clip_max_value),
     transforms.Normalize(
         min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
 ])
 
-train_transforms.decode_image = rs_transforms.decode_image
-
 eval_transforms = transforms.Compose([
-    rs_transforms.Clip(
+    transforms.Clip(
         min_val=clip_min_value, max_val=clip_max_value),
     transforms.Normalize(
         min_val=clip_min_value, max_val=clip_max_value, mean=mean, std=std),
 ])
 
-eval_transforms.decode_image = rs_transforms.decode_image
-
-train_list = osp.join(data_dir, 'train.txt')
-val_list = osp.join(data_dir, 'val.txt')
-label_list = osp.join(data_dir, 'labels.txt')
-
 train_dataset = pdx.datasets.SegDataset(
     data_dir=data_dir,
     file_list=train_list,

+ 90 - 125
paddlex/cv/datasets/analysis.py

@@ -23,6 +23,7 @@ import multiprocessing as mp
 
 import paddlex.utils.logging as logging
 from paddlex.utils import path_normalization
+from paddlex.cv.transforms.seg_transforms import Compose
 from .dataset import get_encoding
 
 
@@ -57,38 +58,6 @@ class Seg:
                 self.file_list.append([full_path_im, full_path_label])
         self.num_samples = len(self.file_list)
 
-    @staticmethod
-    def decode_image(im, label):
-        if isinstance(im, np.ndarray):
-            if len(im.shape) != 3:
-                raise Exception(
-                    "im should be 3-dimensions, but now is {}-dimensions".
-                    format(len(im.shape)))
-        else:
-            try:
-                im = cv2.imread(im)
-            except:
-                raise ValueError('Can\'t read The image file {}!'.format(im))
-        im = im.astype('float32')
-        if label is not None:
-            if isinstance(label, np.ndarray):
-                if len(label.shape) != 2:
-                    raise Exception(
-                        "label should be 2-dimensions, but now is {}-dimensions".
-                        format(len(label.shape)))
-
-            else:
-                try:
-                    label = np.asarray(Image.open(label))
-                except:
-                    ValueError('Can\'t read The label file {}!'.format(label))
-        im_height, im_width, _ = im.shape
-        label_height, label_width = label.shape
-        if im_height != label_height or im_width != label_width:
-            raise Exception(
-                "The height or width of the image is not same as the label")
-        return (im, label)
-
     def _get_shape(self):
         max_height = max(self.im_height_list)
         max_width = max(self.im_width_list)
@@ -127,48 +96,25 @@ class Seg:
                         im_pixel_info[c][v] = n
                     else:
                         im_pixel_info[c][v] += n
-        mode = osp.split(self.file_list_path)[-1].split('.')[0]
-        with open(
-                osp.join(self.data_dir,
-                         '{}_image_pixel_info.pkl'.format(mode)), 'wb') as f:
-            pickle.dump(im_pixel_info, f)
-
-        import matplotlib.pyplot as plt
-        plot_id = (channel // 3 + 1) * 100 + 31
-        for c in range(channel):
-            if c > 8:
-                continue
-            plt.subplot(plot_id + c)
-            plt.bar(im_pixel_info[c].keys(),
-                    im_pixel_info[c].values(),
-                    width=1,
-                    log=True)
-            plt.xlabel('image pixel value')
-            plt.ylabel('number')
-            plt.title('channel={}'.format(c))
-        plt.savefig(
-            osp.join(self.data_dir, '{}_image_pixel_info.png'.format(mode)),
-            dpi=800)
-        plt.close()
         return im_pixel_info
 
     def _get_mean_std(self):
         im_mean = np.asarray(self.im_mean_list)
         im_mean = im_mean.sum(axis=0)
         im_mean = im_mean / len(self.file_list)
-        im_mean /= 255.
+        im_mean /= self.max_im_value - self.min_im_value
 
         im_std = np.asarray(self.im_std_list)
         im_std = im_std.sum(axis=0)
         im_std = im_std / len(self.file_list)
-        im_std /= 255.
+        im_std /= self.max_im_value - self.min_im_value
 
         return (im_mean, im_std)
 
     def _get_image_info(self, start, end):
         for id in range(start, end):
             full_path_im, full_path_label = self.file_list[id]
-            image, label = self.decode_image(full_path_im, full_path_label)
+            image, label = Compose.decode_image(full_path_im, full_path_label)
 
             height, width, channel = image.shape
             self.im_height_list[id] = height
@@ -176,9 +122,9 @@ class Seg:
             self.im_channel_list[id] = channel
 
             self.im_mean_list[
-                id] = [np.mean(image[:, :, c]) for c in range(channel)]
+                id] = [image[:, :, c].mean() for c in range(channel)]
             self.im_std_list[
-                id] = [np.mean(image[:, :, c]) for c in range(channel)]
+                id] = [image[:, :, c].std() for c in range(channel)]
             for c in range(channel):
                 unique, counts = np.unique(image[:, :, c], return_counts=True)
                 self.im_value_list[id].extend([unique])
@@ -192,7 +138,7 @@ class Seg:
                               clip_max_value):
         for id in range(start, end):
             full_path_im, full_path_label = self.file_list[id]
-            image, label = self.decode_image(full_path_im, full_path_label)
+            image, label = Compose.decode_image(full_path_im, full_path_label)
             for c in range(self.channel_num):
                 np.clip(
                     image[:, :, c],
@@ -219,7 +165,6 @@ class Seg:
         self.label_value_num_list = [[] for i in range(len(self.file_list))]
 
         num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
-        num_workers = 6
         threads = []
         one_worker_file = len(self.file_list) // num_workers
         for i in range(num_workers):
@@ -228,39 +173,41 @@ class Seg:
                 i + 1) if i < num_workers - 1 else len(self.file_list)
             t = threading.Thread(
                 target=self._get_image_info, args=(start, end))
-            print("====", len(self.file_list), start, end)
-            #t.daemon = True
             threads.append(t)
         for t in threads:
             t.start()
         for t in threads:
             t.join()
-        print('ok')
-        import time
-        import sys
-        sys.exit(0)
-        time.sleep(1000000)
-        return
-
-        #self._get_image_info(0, len(self.file_list))
+
         unique, counts = np.unique(self.im_channel_list, return_counts=True)
-        print('==== unique')
         if len(unique) > 1:
             raise Exception("There are {} kinds of image channels: {}.".format(
                 len(unique), unique[:]))
         self.channel_num = unique[0]
         shape_info = self._get_shape()
-        print('==== shape_info')
         self.max_height = shape_info['max_height']
         self.max_width = shape_info['max_width']
         self.min_height = shape_info['min_height']
         self.min_width = shape_info['min_width']
         self.label_pixel_info = self._get_label_pixel_info()
-        print('==== label_pixel_info')
         self.im_pixel_info = self._get_image_pixel_info()
-        print('==== im_pixel_info')
-        im_mean, im_std = self._get_mean_std()
-        print('==== get_mean_std')
+        mode = osp.split(self.file_list_path)[-1].split('.')[0]
+        import matplotlib.pyplot as plt
+        for c in range(self.channel_num):
+            plt.figure()
+            plt.bar(self.im_pixel_info[c].keys(),
+                    self.im_pixel_info[c].values(),
+                    width=1,
+                    log=True)
+            plt.xlabel('image pixel value')
+            plt.ylabel('number')
+            plt.title('channel={}'.format(c))
+            plt.savefig(
+                osp.join(self.data_dir,
+                         '{}_channel{}_distribute.png'.format(mode, c)),
+                dpi=100)
+            plt.close()
+
         max_im_value = list()
         min_im_value = list()
         for c in range(self.channel_num):
@@ -269,70 +216,78 @@ class Seg:
         self.max_im_value = np.asarray(max_im_value)
         self.min_im_value = np.asarray(min_im_value)
 
+        im_mean, im_std = self._get_mean_std()
+
+        info = {
+            'channel_num': self.channel_num,
+            'image_pixel': self.im_pixel_info,
+            'label_pixel': self.label_pixel_info,
+            'file_num': len(self.file_list),
+            'max_height': self.max_height,
+            'max_width': self.max_width,
+            'min_height': self.min_height,
+            'min_width': self.min_width,
+            'max_image_value': self.max_im_value,
+            'min_image_value': self.min_im_value
+        }
+        saved_pkl_file = osp.join(self.data_dir,
+                                  '{}_infomation.pkl'.format(mode))
+        with open(osp.join(saved_pkl_file), 'wb') as f:
+            pickle.dump(info, f)
+
         logging.info(
             "############## The analysis results are as follows ##############\n"
         )
         logging.info("{} samples in file {}\n".format(
             len(self.file_list), self.file_list_path))
-        logging.info("Maximal image height: {} Maximal image width: {}.\n".
-                     format(self.max_height, self.max_width))
         logging.info("Minimal image height: {} Minimal image width: {}.\n".
                      format(self.min_height, self.min_width))
+        logging.info("Maximal image height: {} Maximal image width: {}.\n".
+                     format(self.max_height, self.max_width))
         logging.info("Image channel is {}.\n".format(self.channel_num))
         logging.info(
-            "Image mean value: {} Image standard deviation: {} (normalized by 255, sorted by a BGR format).\n".
-            format(im_mean, im_std))
+            "Minimal image value: {} Maximal image value: {} (arranged in 0-{} channel order) \n".
+            format(self.min_im_value, self.max_im_value, self.channel_num))
+        logging.info(
+            "Image pixel distribution of each channel is saved with 'distribute.png' in the {}"
+            .format(self.data_dir))
+        logging.info(
+            "Image mean value: {} Image standard deviation: {} (normalized by the (max_im_value - min_im_value), arranged in 0-{} channel order).\n".
+            format(im_mean, im_std, self.channel_num))
         logging.info(
             "Label pixel information is shown in a format of (label_id, the number of label_id, the ratio of label_id):"
         )
         for v, (n, r) in self.label_pixel_info.items():
             logging.info("({}, {}, {})".format(v, n, r))
-        mode = osp.split(self.file_list_path)[-1].split('.')[0]
-        saved_pkl_file = osp.join(self.data_dir,
-                                  '{}_image_pixel_info.pkl'.format(mode))
-        saved_png_file = osp.join(self.data_dir,
-                                  '{}_image_pixel_info.png'.format(mode))
-        logging.info(
-            "Image pixel information is saved in the file '{}' and shown in the file '{}'".
-            format(saved_pkl_file, saved_png_file))
 
-    def cal_clipvalue_ratio(self, clip_min_value, clip_max_value):
-        if len(clip_min_value) != self.channel_num or len(
-                clip_max_value) != self.channel_num:
+        logging.info("Dataset information is saved in {}".format(
+            saved_pkl_file))
+
+    def cal_clipped_mean_std(self, clip_min_value, clip_max_value,
+                             data_info_file):
+        with open(data_info_file, 'rb') as f:
+            im_info = pickle.load(f)
+        channel_num = im_info['channel_num']
+        min_im_value = im_info['min_image_value']
+        max_im_value = im_info['max_image_value']
+        im_pixel_info = im_info['image_pixel']
+
+        if len(clip_min_value) != channel_num or len(
+                clip_max_value) != channel_num:
             raise Exception(
                 "The length of clip_min_value or clip_max_value should be equal to the number of image channel {}."
-                .format(self.channle_num))
-        for c in range(self.channel_num):
-            if clip_min_value[c] < self.min_im_value[c] or clip_min_value[
-                    c] > self.max_im_value[c]:
-                raise Exception(
-                    "Clip_min_value of the channel {} is not in [{}, {}]".
-                    format(c, self.min_im_value[c], self.max_im_value[c]))
-            if clip_max_value[c] < self.min_im_value[c] or clip_max_value[
-                    c] > self.max_im_value[c]:
-                raise Exception(
-                    "Clip_max_value of the channel {} is not in [{}, {}]".
-                    format(c, self.min_im_value[c], self.max_im_value[c]))
-            clip_pixel_num = 0
-            pixel_num = sum(self.im_pixel_info[c].values())
-            for v, n in self.im_pixel_info[c].items():
-                if v < clip_min_value[c] or v > clip_max_value[c]:
-                    clip_pixel_num += n
-            logging.info("Channel {}, the ratio of pixels to be clipped = {}".
-                         format(c, clip_pixel_num / pixel_num))
-
-    def cal_clipped_mean_std(self, clip_min_value, clip_max_value):
-        for c in range(self.channel_num):
-            if clip_min_value[c] < self.min_im_value[c] or clip_min_value[
-                    c] > self.max_im_value[c]:
+                .format(channle_num))
+        for c in range(channel_num):
+            if clip_min_value[c] < min_im_value[c] or clip_min_value[
+                    c] > max_im_value[c]:
                 raise Exception(
                     "Clip_min_value of the channel {} is not in [{}, {}]".
-                    format(c, self.min_im_value[c], self.max_im_value[c]))
-            if clip_max_value[c] < self.min_im_value[c] or clip_max_value[
-                    c] > self.max_im_value[c]:
+                    format(c, min_im_value[c], max_im_value[c]))
+            if clip_max_value[c] < min_im_value[c] or clip_max_value[
+                    c] > max_im_value[c]:
                 raise Exception(
                     "Clip_max_value of the channel {} is not in [{}, {}]".
-                    format(c, self.min_im_value[c], self.max_im_value[c]))
+                    format(c, min_im_value[c], self.max_im_value[c]))
 
         self.clipped_im_mean_list = [[] for i in range(len(self.file_list))]
         self.clipped_im_std_list = [[] for i in range(len(self.file_list))]
@@ -340,6 +295,7 @@ class Seg:
         num_workers = mp.cpu_count() // 2 if mp.cpu_count() // 2 < 8 else 8
         threads = []
         one_worker_file = len(self.file_list) // num_workers
+        self.channel_num = channel_num
         for i in range(num_workers):
             start = one_worker_file * i
             end = one_worker_file * (
@@ -349,9 +305,9 @@ class Seg:
                 args=(start, end, clip_min_value, clip_max_value))
             threads.append(t)
         for t in threads:
-            t.setDaemon(True)
             t.start()
-        t.join()
+        for t in threads:
+            t.join()
 
         im_mean = np.asarray(self.clipped_im_mean_list)
         im_mean = im_mean.sum(axis=0)
@@ -361,6 +317,15 @@ class Seg:
         im_std = im_std.sum(axis=0)
         im_std = im_std / len(self.file_list)
 
+        for c in range(channel_num):
+            clip_pixel_num = 0
+            pixel_num = sum(im_pixel_info[c].values())
+            for v, n in im_pixel_info[c].items():
+                if v < clip_min_value[c] or v > clip_max_value[c]:
+                    clip_pixel_num += n
+            logging.info("Channel {}, the ratio of pixels to be clipped = {}".
+                         format(c, clip_pixel_num / pixel_num))
+
         logging.info(
-            "Image mean value: {} Image standard deviation: {} (normalized by (clip_max_value - clip_min_value)).\n".
-            format(im_mean, im_std))
+            "Image mean value: {} Image standard deviation: {} (normalized by (clip_max_value - clip_min_value), arranged in 0-{} channel order).\n".
+            format(im_mean, im_std, self.channel_num))

+ 27 - 3
paddlex/cv/models/utils/visualize.py

@@ -20,6 +20,7 @@ import numpy as np
 import time
 import paddlex.utils.logging as logging
 from .detection_eval import fixed_linspace, backup_linspace, loadRes
+from paddlex.cv.datasets.dataset import is_pic
 
 
 def visualize_detection(image, result, threshold=0.5, save_dir='./'):
@@ -44,7 +45,11 @@ def visualize_detection(image, result, threshold=0.5, save_dir='./'):
         return image
 
 
-def visualize_segmentation(image, result, weight=0.6, save_dir='./'):
+def visualize_segmentation(image,
+                           result,
+                           weight=0.6,
+                           save_dir='./',
+                           color=None):
     """
     Convert segment result to color image, and save added image.
     Args:
@@ -52,10 +57,14 @@ def visualize_segmentation(image, result, weight=0.6, save_dir='./'):
         result: the predict result of image
         weight: the image weight of visual image, and the result weight is (1 - weight)
         save_dir: the directory for saving visual image
+        color: the list of a BGR-mode color for each label.
     """
     label_map = result['label_map']
     color_map = get_color_map_list(256)
+    if color is not None:
+        color_map[0:len(color) // 3][:] = color
     color_map = np.array(color_map).astype("uint8")
+
     # Use OpenCV LUT for color mapping
     c1 = cv2.LUT(label_map, color_map[:, 0])
     c2 = cv2.LUT(label_map, color_map[:, 1])
@@ -65,11 +74,26 @@ def visualize_segmentation(image, result, weight=0.6, save_dir='./'):
     if isinstance(image, np.ndarray):
         im = image
         image_name = str(int(time.time() * 1000)) + '.jpg'
+        if image.shape[2] != 3:
+            logging.info(
+                "The image is not 3-channel array, so predicted label map is shown as a pseudo color image."
+            )
+            weight = 0.
     else:
         image_name = os.path.split(image)[-1]
-        im = cv2.imread(image)
+        if not is_pic(image):
+            logging.info(
+                "The image cannot be opened by opencv, so predicted label map is shown as a pseudo color image."
+            )
+            image_name = image_name.split('.')[0] + '.jpg'
+            weight = 0.
+        else:
+            im = cv2.imread(image)
 
-    vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
+    if abs(weight) < 1e-5:
+        vis_result = pseudo_img
+    else:
+        vis_result = cv2.addWeighted(im, weight, pseudo_img, 1 - weight, 0)
 
     if save_dir is not None:
         if not os.path.exists(save_dir):

+ 64 - 8
paddlex/cv/transforms/seg_transforms.py

@@ -20,7 +20,9 @@ import os.path as osp
 import numpy as np
 from PIL import Image
 import cv2
+import imghdr
 from collections import OrderedDict
+
 import paddlex.utils.logging as logging
 
 
@@ -61,6 +63,30 @@ class Compose(SegTransform):
                     )
 
     @staticmethod
+    def read_img(img_path):
+        img_format = imghdr.what(img_path)
+        name, ext = osp.splitext(img_path)
+        if img_format == 'tiff' or ext == '.img':
+            import gdal
+            gdal.UseExceptions()
+            gdal.PushErrorHandler('CPLQuietErrorHandler')
+
+            try:
+                dataset = gdal.Open(img_path)
+            except:
+                logging.error(gdal.GetLastErrorMsg())
+            if dataset == None:
+                raise Exception('Can not open', img_path)
+            im_data = dataset.ReadAsArray()
+            return im_data.transpose((1, 2, 0))
+        elif img_format == 'png':
+            return np.asarray(Image.open(img_path))
+        elif ext == '.npy':
+            return np.load(img_path)
+        else:
+            raise Exception('Image format {} is not supported!'.format(ext))
+
+    @staticmethod
     def decode_image(im, label):
         if isinstance(im, np.ndarray):
             if len(im.shape) != 3:
@@ -69,7 +95,7 @@ class Compose(SegTransform):
                     format(len(im.shape)))
         else:
             try:
-                im = cv2.imread(im)
+                im = Compose.read_img(im)
             except:
                 raise ValueError('Can\'t read The image file {}!'.format(im))
         im = im.astype('float32')
@@ -85,11 +111,11 @@ class Compose(SegTransform):
                     label = np.asarray(Image.open(label))
                 except:
                     ValueError('Can\'t read The label file {}!'.format(label))
-        im_height, im_width, _ = im.shape
-        label_height, label_width = label.shape
-        if im_height != label_height or im_width != label_width:
-            raise Exception(
-                "The height or width of the image is not same as the label")
+            im_height, im_width, _ = im.shape
+            label_height, label_width = label.shape
+            if im_height != label_height or im_width != label_width:
+                raise Exception(
+                    "The height or width of the image is not same as the label")
         return (im, label)
 
     def __call__(self, im, im_info=None, label=None):
@@ -570,12 +596,15 @@ class ResizeStepScaling(SegTransform):
 
 class Normalize(SegTransform):
     """对图像进行标准化。
-    1.尺度缩放到 [0,1]。
-    2.对图像进行减均值除以标准差操作。
+    1.像素值减去min_val
+    2.像素值除以(max_val-min_val)
+    3.对图像进行减均值除以标准差操作。
 
     Args:
         mean (list): 图像数据集的均值。默认值[0.5, 0.5, 0.5]。
         std (list): 图像数据集的标准差。默认值[0.5, 0.5, 0.5]。
+        min_val (list): 图像数据集的最小值。默认值[0, 0, 0]。
+        max_val (list): 图像数据集的最大值。默认值[255.0, 255.0, 255.0]。
 
     Raises:
         ValueError: mean或std不是list对象。std包含0。
@@ -1099,6 +1128,33 @@ class RandomDistort(SegTransform):
             return (im, im_info, label)
 
 
+class Clip(SegTransform):
+    """
+    对图像上超出一定范围的数据进行截断。
+
+    Args:
+        min_val (list): 裁剪的下限,小于min_val的数值均设为min_val. 默认值0.
+        max_val (list): 裁剪的上限,大于max_val的数值均设为max_val. 默认值255.0.
+    """
+
+    def __init__(self, min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0]):
+        self.min_val = min_val
+        self.max_val = max_val
+        if not (isinstance(self.min_val, list) and isinstance(self.max_val,
+                                                              list)):
+            raise ValueError("{}: input type is invalid.".format(self))
+
+    def __call__(self, im, im_info=None, label=None):
+        for k in range(im.shape[2]):
+            np.clip(
+                im[:, :, k], self.min_val[k], self.max_val[k], out=im[:, :, k])
+
+        if label is None:
+            return (im, im_info)
+        else:
+            return (im, im_info, label)
+
+
 class ArrangeSegmenter(SegTransform):
     """获取训练/验证/预测所需的信息。
 

+ 0 - 0
paddlex/remotesensing/__init__.py


+ 0 - 91
paddlex/remotesensing/transforms.py

@@ -1,91 +0,0 @@
-import os
-import os.path as osp
-import imghdr
-import gdal
-gdal.UseExceptions()
-gdal.PushErrorHandler('CPLQuietErrorHandler')
-import numpy as np
-from PIL import Image
-
-from paddlex.seg import transforms
-import paddlex.utils.logging as logging
-
-
-def read_img(img_path):
-    img_format = imghdr.what(img_path)
-    name, ext = osp.splitext(img_path)
-    if img_format == 'tiff' or ext == '.img':
-        try:
-            dataset = gdal.Open(img_path)
-        except:
-            logging.error(gdal.GetLastErrorMsg())
-        if dataset == None:
-            raise Exception('Can not open', img_path)
-        im_data = dataset.ReadAsArray()
-        return im_data.transpose((1, 2, 0))
-    elif img_format == 'png':
-        return np.asarray(Image.open(img_path))
-    elif ext == '.npy':
-        return np.load(img_path)
-    else:
-        raise Exception('Image format {} is not supported!'.format(ext))
-
-
-def decode_image(im, label):
-    if isinstance(im, np.ndarray):
-        if len(im.shape) != 3:
-            raise Exception(
-                "im should be 3-dimensions, but now is {}-dimensions".format(
-                    len(im.shape)))
-    else:
-        try:
-            im = read_img(im)
-        except:
-            raise ValueError('Can\'t read The image file {}!'.format(im))
-    im = im.astype('float32')
-
-    if label is not None:
-        if isinstance(label, np.ndarray):
-            if len(label.shape) != 2:
-                raise Exception(
-                    "label should be 2-dimensions, but now is {}-dimensions".
-                    format(len(label.shape)))
-
-        else:
-            try:
-                label = np.asarray(Image.open(label))
-            except:
-                ValueError('Can\'t read The label file {}!'.format(label))
-    im_height, im_width, _ = im.shape
-    label_height, label_width = label.shape
-    if im_height != label_height or im_width != label_width:
-        raise Exception(
-            "The height or width of the image is not same as the label")
-    return (im, label)
-
-
-class Clip(transforms.SegTransform):
-    """
-    对图像上超出一定范围的数据进行裁剪。
-
-    Args:
-        min_val (list): 裁剪的下限,小于min_val的数值均设为min_val. 默认值0.
-        max_val (list): 裁剪的上限,大于max_val的数值均设为max_val. 默认值255.0.
-    """
-
-    def __init__(self, min_val=[0, 0, 0], max_val=[255.0, 255.0, 255.0]):
-        self.min_val = min_val
-        self.max_val = max_val
-        if not (isinstance(self.min_val, list) and isinstance(self.max_val,
-                                                              list)):
-            raise ValueError("{}: input type is invalid.".format(self))
-
-    def __call__(self, im, im_info=None, label=None):
-        for k in range(im.shape[2]):
-            np.clip(
-                im[:, :, k], self.min_val[k], self.max_val[k], out=im[:, :, k])
-
-        if label is None:
-            return (im, im_info)
-        else:
-            return (im, im_info, label)

+ 0 - 0
paddlex/remotesensing/utils/__init__.py


+ 0 - 506
paddlex/remotesensing/utils/analyse.py

@@ -1,506 +0,0 @@
-# coding: utf8
-# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-import os
-import os.path as osp
-import sys
-import argparse
-from PIL import Image
-from tqdm import tqdm
-import imghdr
-import logging
-import pickle
-import gdal
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(
-        description='Data analyse and data check before training.')
-    parser.add_argument(
-        '--data_dir',
-        dest='data_dir',
-        help='Dataset directory',
-        default=None,
-        type=str)
-    parser.add_argument(
-        '--num_classes',
-        dest='num_classes',
-        help='Number of classes',
-        default=None,
-        type=int)
-    parser.add_argument(
-        '--separator',
-        dest='separator',
-        help='file list separator',
-        default=" ",
-        type=str)
-    parser.add_argument(
-        '--ignore_index',
-        dest='ignore_index',
-        help='Ignored class index',
-        default=255,
-        type=int)
-    if len(sys.argv) == 1:
-        parser.print_help()
-        sys.exit(1)
-    return parser.parse_args()
-
-
-def read_img(img_path):
-    img_format = imghdr.what(img_path)
-    name, ext = osp.splitext(img_path)
-    if img_format == 'tiff' or ext == '.img':
-        dataset = gdal.Open(img_path)
-        if dataset == None:
-            raise Exception('Can not open', img_path)
-        im_data = dataset.ReadAsArray()
-        return im_data.transpose((1, 2, 0))
-    elif ext == '.npy':
-        return np.load(img_path)
-    else:
-        raise Exception('Not support {} image format!'.format(ext))
-
-
-def img_pixel_statistics(img, img_value_num, img_min_value, img_max_value):
-    channel = img.shape[2]
-    means = np.zeros(channel)
-    stds = np.zeros(channel)
-    for k in range(channel):
-        img_k = img[:, :, k]
-
-        # count mean, std
-        means[k] = np.mean(img_k)
-        stds[k] = np.std(img_k)
-
-        # count min, max
-        min_value = np.min(img_k)
-        max_value = np.max(img_k)
-        if img_max_value[k] < max_value:
-            img_max_value[k] = max_value
-        if img_min_value[k] > min_value:
-            img_min_value[k] = min_value
-
-        # count the distribution of image value, value number
-        unique, counts = np.unique(img_k, return_counts=True)
-        add_num = []
-        max_unique = np.max(unique)
-        add_len = max_unique - len(img_value_num[k]) + 1
-        if add_len > 0:
-            img_value_num[k] += ([0] * add_len)
-        for i in range(len(unique)):
-            value = unique[i]
-            img_value_num[k][value] += counts[i]
-
-        img_value_num[k] += add_num
-    return means, stds, img_min_value, img_max_value, img_value_num
-
-
-def data_distribution_statistics(data_dir, img_value_num, logger):
-    """count the distribution of image value, value number
-    """
-    logger.info(
-        "\n-----------------------------\nThe whole dataset statistics...")
-
-    if not img_value_num:
-        return
-    logger.info("\nImage pixel statistics:")
-    total_ratio = []
-    [total_ratio.append([]) for i in range(len(img_value_num))]
-    for k in range(len(img_value_num)):
-        total_num = sum(img_value_num[k])
-        total_ratio[k] = [i / total_num for i in img_value_num[k]]
-        total_ratio[k] = np.around(total_ratio[k], decimals=4)
-    with open(os.path.join(data_dir, 'img_pixel_statistics.pkl'), 'wb') as f:
-        pickle.dump([total_ratio, img_value_num], f)
-
-
-def data_range_statistics(img_min_value, img_max_value, logger):
-    """print min value, max value
-    """
-    logger.info("value range: \nimg_min_value = {} \nimg_max_value = {}".
-                format(img_min_value, img_max_value))
-
-
-def cal_normalize_coefficient(total_means, total_stds, total_img_num, logger):
-    """count mean, std
-    """
-    total_means = total_means / total_img_num
-    total_stds = total_stds / total_img_num
-    logger.info("\nCount the channel-by-channel mean and std of the image:\n"
-                "mean = {}\nstd = {}".format(total_means, total_stds))
-
-
-def error_print(str):
-    return "".join(["\nNOT PASS ", str])
-
-
-def correct_print(str):
-    return "".join(["\nPASS ", str])
-
-
-def pil_imread(file_path):
-    """read pseudo-color label"""
-    im = Image.open(file_path)
-    return np.asarray(im)
-
-
-def get_img_shape_range(img, max_width, max_height, min_width, min_height):
-    """获取图片最大和最小宽高"""
-    img_shape = img.shape
-    height, width = img_shape[0], img_shape[1]
-    max_height = max(height, max_height)
-    max_width = max(width, max_width)
-    min_height = min(height, min_height)
-    min_width = min(width, min_width)
-    return max_width, max_height, min_width, min_height
-
-
-def get_img_channel_num(img, img_channels):
-    """获取图像的通道数"""
-    img_shape = img.shape
-    if img_shape[-1] not in img_channels:
-        img_channels.append(img_shape[-1])
-    return img_channels
-
-
-def is_label_single_channel(label):
-    """判断标签是否为灰度图"""
-    label_shape = label.shape
-    if len(label_shape) == 2:
-        return True
-    else:
-        return False
-
-
-def image_label_shape_check(img, label):
-    """
-    验证图像和标注的大小是否匹配
-    """
-
-    flag = True
-    img_height = img.shape[0]
-    img_width = img.shape[1]
-    label_height = label.shape[0]
-    label_width = label.shape[1]
-
-    if img_height != label_height or img_width != label_width:
-        flag = False
-    return flag
-
-
-def ground_truth_check(label, label_path):
-    """
-    验证标注图像的格式
-    统计标注图类别和像素数
-    params:
-        label: 标注图
-        label_path: 标注图路径
-    return:
-        png_format: 返回是否是png格式图片
-        unique: 返回标注类别
-        counts: 返回标注的像素数
-    """
-    if imghdr.what(label_path) == "png":
-        png_format = True
-    else:
-        png_format = False
-
-    unique, counts = np.unique(label, return_counts=True)
-
-    return png_format, unique, counts
-
-
-def sum_label_check(label_classes, num_of_each_class, ignore_index,
-                    num_classes, total_label_classes, total_num_of_each_class):
-    """
-    统计所有标注图上的类别和每个类别的像素数
-    params:
-        label_classes: 标注类别
-        num_of_each_class: 各个类别的像素数目
-    """
-    is_label_correct = True
-
-    if ignore_index in label_classes:
-        label_classes2 = np.delete(label_classes,
-                                   np.where(label_classes == ignore_index))
-    else:
-        label_classes2 = label_classes
-    if min(label_classes2) < 0 or max(label_classes2) > num_classes - 1:
-        is_label_correct = False
-    add_class = []
-    add_num = []
-    for i in range(len(label_classes)):
-        gi = label_classes[i]
-        if gi in total_label_classes:
-            j = total_label_classes.index(gi)
-            total_num_of_each_class[j] += num_of_each_class[i]
-        else:
-            add_class.append(gi)
-            add_num.append(num_of_each_class[i])
-    total_num_of_each_class += add_num
-    total_label_classes += add_class
-    return is_label_correct, total_num_of_each_class, total_label_classes
-
-
-def label_class_check(num_classes, total_label_classes,
-                      total_num_of_each_class, wrong_labels, logger):
-    """
-    检查实际标注类别是否和配置参数`num_classes`,`ignore_index`匹配。
-
-    **NOTE:**
-    标注图像类别数值必须在[0~(`num_classes`-1)]范围内或者为`ignore_index`。
-    标注类别最好从0开始,否则可能影响精度。
-    """
-    total_ratio = total_num_of_each_class / sum(total_num_of_each_class)
-    total_ratio = np.around(total_ratio, decimals=4)
-    total_nc = sorted(
-        zip(total_label_classes, total_ratio, total_num_of_each_class))
-    if len(wrong_labels) == 0 and not total_nc[0][0]:
-        logger.info(correct_print("label class check!"))
-    else:
-        logger.info(error_print("label class check!"))
-        if total_nc[0][0]:
-            logger.info("Warning: label classes should start from 0")
-        if len(wrong_labels) > 0:
-            logger.info("fatal error: label class is out of range [0, {}]".
-                        format(num_classes - 1))
-            for i in wrong_labels:
-                logger.debug(i)
-    return total_nc
-
-
-def label_class_statistics(total_nc, logger):
-    """
-    对标注图像进行校验,输出校验结果
-    """
-    logger.info("\nLabel class statistics:\n"
-                "(label class, percentage, total pixel number) = {} ".format(
-                    total_nc))
-
-
-def shape_check(shape_unequal_image, logger):
-    """输出shape校验结果"""
-    if len(shape_unequal_image) == 0:
-        logger.info(correct_print("shape check"))
-        logger.info("All images are the same shape as the labels")
-    else:
-        logger.info(error_print("shape check"))
-        logger.info(
-            "Some images are not the same shape as the labels as follow: ")
-        for i in shape_unequal_image:
-            logger.debug(i)
-
-
-def separator_check(wrong_lines, file_list, separator, logger):
-    """检查分割符是否复合要求"""
-    if len(wrong_lines) == 0:
-        logger.info(
-            correct_print(
-                file_list.split(os.sep)[-1] + " DATASET.separator check"))
-    else:
-        logger.info(
-            error_print(
-                file_list.split(os.sep)[-1] + " DATASET.separator check"))
-        logger.info("The following list is not separated by {}".format(
-            separator))
-        for i in wrong_lines:
-            logger.debug(i)
-
-
-def imread_check(imread_failed, logger):
-    if len(imread_failed) == 0:
-        logger.info(correct_print("dataset reading check"))
-        logger.info("All images can be read successfully")
-    else:
-        logger.info(error_print("dataset reading check"))
-        logger.info("Failed to read {} images".format(len(imread_failed)))
-        for i in imread_failed:
-            logger.debug(i)
-
-
-def single_channel_label_check(label_not_single_channel, logger):
-    if len(label_not_single_channel) == 0:
-        logger.info(correct_print("label single_channel check"))
-        logger.info("All label images are single_channel")
-    else:
-        logger.info(error_print("label single_channel check"))
-        logger.info(
-            "{} label images are not single_channel\nLabel pixel statistics may be insignificant"
-            .format(len(label_not_single_channel)))
-        for i in label_not_single_channel:
-            logger.debug(i)
-
-
-def img_shape_range_statistics(max_width, min_width, max_height, min_height,
-                               logger):
-    logger.info("\nImage size statistics:")
-    logger.info(
-        "max width = {}  min width = {}  max height = {}  min height = {}".
-        format(max_width, min_width, max_height, min_height))
-
-
-def img_channels_statistics(img_channels, logger):
-    logger.info("\nImage channels statistics\nImage channels = {}".format(
-        np.unique(img_channels)))
-
-
-def data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
-                           logger):
-    train_file_list = osp.join(data_dir, 'train.txt')
-    val_file_list = osp.join(data_dir, 'val.txt')
-    test_file_list = osp.join(data_dir, 'test.txt')
-    total_img_num = 0
-    has_label = False
-    for file_list in [train_file_list, val_file_list, test_file_list]:
-        # initialization
-        imread_failed = []
-        max_width = 0
-        max_height = 0
-        min_width = sys.float_info.max
-        min_height = sys.float_info.max
-        label_not_single_channel = []
-        shape_unequal_image = []
-        wrong_labels = []
-        wrong_lines = []
-        total_label_classes = []
-        total_num_of_each_class = []
-        img_channels = []
-
-        with open(file_list, 'r') as fid:
-            logger.info("\n-----------------------------\nCheck {}...".format(
-                file_list))
-            lines = fid.readlines()
-            if not lines:
-                logger.info("File list is empty!")
-                continue
-            for line in tqdm(lines):
-                line = line.strip()
-                parts = line.split(separator)
-                if len(parts) == 1:
-                    if file_list == train_file_list or file_list == val_file_list:
-                        logger.info("Train or val list must have labels!")
-                        break
-                    img_name = parts
-                    img_path = os.path.join(data_dir, img_name[0])
-                    try:
-                        img = read_img(img_path)
-                    except Exception as e:
-                        imread_failed.append((line, str(e)))
-                        continue
-                elif len(parts) == 2:
-                    has_label = True
-                    img_name, label_name = parts[0], parts[1]
-                    img_path = os.path.join(data_dir, img_name)
-                    label_path = os.path.join(data_dir, label_name)
-                    try:
-                        img = read_img(img_path)
-                        label = pil_imread(label_path)
-                    except Exception as e:
-                        imread_failed.append((line, str(e)))
-                        continue
-
-                    is_single_channel = is_label_single_channel(label)
-                    if not is_single_channel:
-                        label_not_single_channel.append(line)
-                        continue
-                    is_equal_img_label_shape = image_label_shape_check(img,
-                                                                       label)
-                    if not is_equal_img_label_shape:
-                        shape_unequal_image.append(line)
-                    png_format, label_classes, num_of_each_class = ground_truth_check(
-                        label, label_path)
-                    is_label_correct, total_num_of_each_class, total_label_classes = sum_label_check(
-                        label_classes, num_of_each_class, ignore_index,
-                        num_classes, total_label_classes,
-                        total_num_of_each_class)
-                    if not is_label_correct:
-                        wrong_labels.append(line)
-                else:
-                    wrong_lines.append(lines)
-                    continue
-
-                if total_img_num == 0:
-                    channel = img.shape[2]
-                    total_means = np.zeros(channel)
-                    total_stds = np.zeros(channel)
-                    img_min_value = [sys.float_info.max] * channel
-                    img_max_value = [0] * channel
-                    img_value_num = []
-                    [img_value_num.append([]) for i in range(channel)]
-                means, stds, img_min_value, img_max_value, img_value_num = img_pixel_statistics(
-                    img, img_value_num, img_min_value, img_max_value)
-                total_means += means
-                total_stds += stds
-                max_width, max_height, min_width, min_height = get_img_shape_range(
-                    img, max_width, max_height, min_width, min_height)
-                img_channels = get_img_channel_num(img, img_channels)
-                total_img_num += 1
-
-            # data check
-            separator_check(wrong_lines, file_list, separator, logger)
-            imread_check(imread_failed, logger)
-            if has_label:
-                single_channel_label_check(label_not_single_channel, logger)
-                shape_check(shape_unequal_image, logger)
-                total_nc = label_class_check(num_classes, total_label_classes,
-                                             total_num_of_each_class,
-                                             wrong_labels, logger)
-
-            # data analyse on train, validation, test set.
-            img_channels_statistics(img_channels, logger)
-            img_shape_range_statistics(max_width, min_width, max_height,
-                                       min_height, logger)
-            if has_label:
-                label_class_statistics(total_nc, logger)
-    # data analyse on the whole dataset.
-    data_range_statistics(img_min_value, img_max_value, logger)
-    data_distribution_statistics(data_dir, img_value_num, logger)
-    cal_normalize_coefficient(total_means, total_stds, total_img_num, logger)
-
-
-def main():
-    args = parse_args()
-    data_dir = args.data_dir
-    ignore_index = args.ignore_index
-    num_classes = args.num_classes
-    separator = args.separator
-
-    logger = logging.getLogger()
-    logger.setLevel('DEBUG')
-    BASIC_FORMAT = "%(message)s"
-    formatter = logging.Formatter(BASIC_FORMAT)
-    sh = logging.StreamHandler()
-    sh.setFormatter(formatter)
-    sh.setLevel('INFO')
-    th = logging.FileHandler(
-        os.path.join(data_dir, 'data_analyse_and_check.log'), 'w')
-    th.setFormatter(formatter)
-    logger.addHandler(sh)
-    logger.addHandler(th)
-
-    data_analyse_and_check(data_dir, num_classes, separator, ignore_index,
-                           logger)
-
-    print("\nDetailed error information can be viewed in {}.".format(
-        os.path.join(data_dir, 'data_analyse_and_check.log')))
-
-
-if __name__ == "__main__":
-    main()