mask_rcnn.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475
  1. # copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. import math
  16. import tqdm
  17. import numpy as np
  18. from multiprocessing.pool import ThreadPool
  19. import paddle.fluid as fluid
  20. import paddlex.utils.logging as logging
  21. import paddlex
  22. import copy
  23. import os.path as osp
  24. from paddlex.cv.transforms import arrange_transforms
  25. from collections import OrderedDict
  26. from .faster_rcnn import FasterRCNN
  27. from .utils.detection_eval import eval_results, bbox2out, mask2out
  28. class MaskRCNN(FasterRCNN):
  29. """构建MaskRCNN,并实现其训练、评估、预测和模型导出。
  30. Args:
  31. num_classes (int): 包含了背景类的类别数。默认为81。
  32. backbone (str): MaskRCNN的backbone网络,取值范围为['ResNet18', 'ResNet50',
  33. 'ResNet50_vd', 'ResNet101', 'ResNet101_vd', 'HRNet_W18']。默认为'ResNet50'。
  34. with_fpn (bool): 是否使用FPN结构。默认为True。
  35. aspect_ratios (list): 生成anchor高宽比的可选值。默认为[0.5, 1.0, 2.0]。
  36. anchor_sizes (list): 生成anchor大小的可选值。默认为[32, 64, 128, 256, 512]。
  37. input_channel (int): 输入图像的通道数量。默认为3。
  38. """
  39. def __init__(self,
  40. num_classes=81,
  41. backbone='ResNet50',
  42. with_fpn=True,
  43. aspect_ratios=[0.5, 1.0, 2.0],
  44. anchor_sizes=[32, 64, 128, 256, 512],
  45. input_channel=3):
  46. self.init_params = locals()
  47. backbones = [
  48. 'ResNet18', 'ResNet50', 'ResNet50_vd', 'ResNet101', 'ResNet101_vd',
  49. 'HRNet_W18'
  50. ]
  51. assert backbone in backbones, "backbone should be one of {}".format(
  52. backbones)
  53. super(FasterRCNN, self).__init__('detector')
  54. self.backbone = backbone
  55. self.num_classes = num_classes
  56. self.with_fpn = with_fpn
  57. self.anchor_sizes = anchor_sizes
  58. self.labels = None
  59. if with_fpn:
  60. self.mask_head_resolution = 28
  61. else:
  62. self.mask_head_resolution = 14
  63. self.fixed_input_shape = None
  64. self.input_channel = input_channel
  65. def build_net(self, mode='train'):
  66. train_pre_nms_top_n = 2000 if self.with_fpn else 12000
  67. test_pre_nms_top_n = 1000 if self.with_fpn else 6000
  68. num_convs = 4 if self.with_fpn else 0
  69. model = paddlex.cv.nets.detection.MaskRCNN(
  70. backbone=self._get_backbone(self.backbone),
  71. num_classes=self.num_classes,
  72. mode=mode,
  73. with_fpn=self.with_fpn,
  74. train_pre_nms_top_n=train_pre_nms_top_n,
  75. test_pre_nms_top_n=test_pre_nms_top_n,
  76. num_convs=num_convs,
  77. mask_head_resolution=self.mask_head_resolution,
  78. fixed_input_shape=self.fixed_input_shape,
  79. input_channel=self.input_channel)
  80. inputs = model.generate_inputs()
  81. if mode == 'train':
  82. model_out = model.build_net(inputs)
  83. loss = model_out['loss']
  84. self.optimizer.minimize(loss)
  85. outputs = OrderedDict(
  86. [('loss', model_out['loss']),
  87. ('loss_cls', model_out['loss_cls']),
  88. ('loss_bbox', model_out['loss_bbox']),
  89. ('loss_mask', model_out['loss_mask']),
  90. ('loss_rpn_cls', model_out['loss_rpn_cls']), (
  91. 'loss_rpn_bbox', model_out['loss_rpn_bbox'])])
  92. else:
  93. outputs = model.build_net(inputs)
  94. return inputs, outputs
  95. def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr,
  96. lr_decay_epochs, lr_decay_gamma,
  97. num_steps_each_epoch):
  98. if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
  99. logging.error(
  100. "In function train(), parameters should satisfy: warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
  101. exit=False)
  102. logging.error(
  103. "See this doc for more information: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
  104. exit=False)
  105. logging.error(
  106. "warmup_steps should less than {} or lr_decay_epochs[0] greater than {}, please modify 'lr_decay_epochs' or 'warmup_steps' in train function".
  107. format(lr_decay_epochs[0] * num_steps_each_epoch, warmup_steps
  108. // num_steps_each_epoch))
  109. boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
  110. values = [(lr_decay_gamma**i) * learning_rate
  111. for i in range(len(lr_decay_epochs) + 1)]
  112. lr_decay = fluid.layers.piecewise_decay(
  113. boundaries=boundaries, values=values)
  114. lr_warmup = fluid.layers.linear_lr_warmup(
  115. learning_rate=lr_decay,
  116. warmup_steps=warmup_steps,
  117. start_lr=warmup_start_lr,
  118. end_lr=learning_rate)
  119. optimizer = fluid.optimizer.Momentum(
  120. learning_rate=lr_warmup,
  121. momentum=0.9,
  122. regularization=fluid.regularizer.L2Decay(1e-04))
  123. return optimizer
  124. def train(self,
  125. num_epochs,
  126. train_dataset,
  127. train_batch_size=1,
  128. eval_dataset=None,
  129. save_interval_epochs=1,
  130. log_interval_steps=2,
  131. save_dir='output',
  132. pretrain_weights='IMAGENET',
  133. optimizer=None,
  134. learning_rate=1.0 / 800,
  135. warmup_steps=500,
  136. warmup_start_lr=1.0 / 2400,
  137. lr_decay_epochs=[8, 11],
  138. lr_decay_gamma=0.1,
  139. metric=None,
  140. use_vdl=False,
  141. early_stop=False,
  142. early_stop_patience=5,
  143. resume_checkpoint=None):
  144. """训练。
  145. Args:
  146. num_epochs (int): 训练迭代轮数。
  147. train_dataset (paddlex.datasets): 训练数据读取器。
  148. train_batch_size (int): 训练或验证数据batch大小。目前检测仅支持单卡评估,训练数据batch大小与
  149. 显卡数量之商为验证数据batch大小。默认值为1。
  150. eval_dataset (paddlex.datasets): 验证数据读取器。
  151. save_interval_epochs (int): 模型保存间隔(单位:迭代轮数)。默认为1。
  152. log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为20。
  153. save_dir (str): 模型保存路径。默认值为'output'。
  154. pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',
  155. 则自动下载在ImageNet图片数据上预训练的模型权重;若为字符串'COCO',
  156. 则自动下载在COCO数据集上预训练的模型权重;若为None,则不使用预训练模型。默认为None。
  157. optimizer (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:
  158. fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
  159. learning_rate (float): 默认优化器的学习率。默认为1.0/800。
  160. warmup_steps (int): 默认优化器进行warmup过程的步数。默认为500。
  161. warmup_start_lr (int): 默认优化器warmup的起始学习率。默认为1.0/2400。
  162. lr_decay_epochs (list): 默认优化器的学习率衰减轮数。默认为[8, 11]。
  163. lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。
  164. metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。
  165. use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。
  166. early_stop (bool): 是否使用提前终止训练策略。默认值为False。
  167. early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
  168. 连续下降或持平,则终止训练。默认值为5。
  169. resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
  170. Raises:
  171. ValueError: 评估类型不在指定列表中。
  172. ValueError: 模型从inference model进行加载。
  173. """
  174. if metric is None:
  175. if isinstance(train_dataset, paddlex.datasets.CocoDetection) or \
  176. isinstance(train_dataset, paddlex.datasets.EasyDataDet):
  177. metric = 'COCO'
  178. else:
  179. raise Exception(
  180. "train_dataset should be datasets.COCODetection or datasets.EasyDataDet."
  181. )
  182. assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
  183. self.metric = metric
  184. if not self.trainable:
  185. raise Exception("Model is not trainable from load_model method.")
  186. self.labels = copy.deepcopy(train_dataset.labels)
  187. self.labels.insert(0, 'background')
  188. # 构建训练网络
  189. if optimizer is None:
  190. # 构建默认的优化策略
  191. num_steps_each_epoch = train_dataset.num_samples // train_batch_size
  192. optimizer = self.default_optimizer(
  193. learning_rate=learning_rate,
  194. warmup_steps=warmup_steps,
  195. warmup_start_lr=warmup_start_lr,
  196. lr_decay_epochs=lr_decay_epochs,
  197. lr_decay_gamma=lr_decay_gamma,
  198. num_steps_each_epoch=num_steps_each_epoch)
  199. self.optimizer = optimizer
  200. # 构建训练、验证、测试网络
  201. self.build_program()
  202. fuse_bn = True
  203. if self.with_fpn and self.backbone in [
  204. 'ResNet18', 'ResNet50', 'HRNet_W18'
  205. ]:
  206. fuse_bn = False
  207. self.net_initialize(
  208. startup_prog=fluid.default_startup_program(),
  209. pretrain_weights=pretrain_weights,
  210. fuse_bn=fuse_bn,
  211. save_dir=save_dir,
  212. resume_checkpoint=resume_checkpoint)
  213. # 训练
  214. self.train_loop(
  215. num_epochs=num_epochs,
  216. train_dataset=train_dataset,
  217. train_batch_size=train_batch_size,
  218. eval_dataset=eval_dataset,
  219. save_interval_epochs=save_interval_epochs,
  220. log_interval_steps=log_interval_steps,
  221. save_dir=save_dir,
  222. use_vdl=use_vdl,
  223. early_stop=early_stop,
  224. early_stop_patience=early_stop_patience)
  225. def evaluate(self,
  226. eval_dataset,
  227. batch_size=1,
  228. epoch_id=None,
  229. metric=None,
  230. return_details=False):
  231. """评估。
  232. Args:
  233. eval_dataset (paddlex.datasets): 验证数据读取器。
  234. batch_size (int): 验证数据批大小。默认为1。当前只支持设置为1。
  235. epoch_id (int): 当前评估模型所在的训练轮数。
  236. metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,
  237. 根据用户传入的Dataset自动选择,如为VOCDetection,则metric为'VOC';
  238. 如为COCODetection,则metric为'COCO'。
  239. return_details (bool): 是否返回详细信息。默认值为False。
  240. Returns:
  241. tuple (metrics, eval_details) /dict (metrics): 当return_details为True时,返回(metrics, eval_details),
  242. 当return_details为False时,返回metrics。metrics为dict,包含关键字:'bbox_mmap'和'segm_mmap'
  243. 或者’bbox_map‘和'segm_map',分别表示预测框和分割区域平均准确率平均值在
  244. 各个IoU阈值下的结果取平均值的结果(mmAP)、平均准确率平均值(mAP)。eval_details为dict,
  245. 包含bbox、mask和gt三个关键字。其中关键字bbox的键值是一个列表,列表中每个元素代表一个预测结果,
  246. 一个预测结果是一个由图像id,预测框类别id, 预测框坐标,预测框得分组成的列表。
  247. 关键字mask的键值是一个列表,列表中每个元素代表各预测框内物体的分割结果,分割结果由图像id、
  248. 预测框类别id、表示预测框内各像素点是否属于物体的二值图、预测框得分。
  249. 而关键字gt的键值是真实标注框的相关信息。
  250. """
  251. input_channel = getattr(self, 'input_channel', 3)
  252. arrange_transforms(
  253. model_type=self.model_type,
  254. class_name=self.__class__.__name__,
  255. transforms=eval_dataset.transforms,
  256. mode='eval',
  257. input_channel=input_channel)
  258. if metric is None:
  259. if hasattr(self, 'metric') and self.metric is not None:
  260. metric = self.metric
  261. else:
  262. if isinstance(eval_dataset, paddlex.datasets.CocoDetection):
  263. metric = 'COCO'
  264. else:
  265. raise Exception(
  266. "eval_dataset should be datasets.COCODetection.")
  267. assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
  268. if batch_size > 1:
  269. batch_size = 1
  270. logging.warning(
  271. "Mask RCNN supports batch_size=1 only during evaluating, so batch_size is forced to be set to 1."
  272. )
  273. data_generator = eval_dataset.generator(
  274. batch_size=batch_size, drop_last=False)
  275. total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
  276. results = list()
  277. logging.info(
  278. "Start to evaluating(total_samples={}, total_steps={})...".format(
  279. eval_dataset.num_samples, total_steps))
  280. for step, data in tqdm.tqdm(
  281. enumerate(data_generator()), total=total_steps):
  282. images = np.array([d[0] for d in data]).astype('float32')
  283. im_infos = np.array([d[1] for d in data]).astype('float32')
  284. im_shapes = np.array([d[3] for d in data]).astype('float32')
  285. feed_data = {
  286. 'image': images,
  287. 'im_info': im_infos,
  288. 'im_shape': im_shapes,
  289. }
  290. with fluid.scope_guard(self.scope):
  291. outputs = self.exe.run(
  292. self.test_prog,
  293. feed=[feed_data],
  294. fetch_list=list(self.test_outputs.values()),
  295. return_numpy=False)
  296. res = {
  297. 'bbox': (np.array(outputs[0]),
  298. outputs[0].recursive_sequence_lengths()),
  299. 'mask': (np.array(outputs[1]),
  300. outputs[1].recursive_sequence_lengths())
  301. }
  302. res_im_id = [d[2] for d in data]
  303. res['im_info'] = (im_infos, [])
  304. res['im_shape'] = (im_shapes, [])
  305. res['im_id'] = (np.array(res_im_id), [])
  306. results.append(res)
  307. logging.debug("[EVAL] Epoch={}, Step={}/{}".format(epoch_id, step +
  308. 1, total_steps))
  309. ap_stats, eval_details = eval_results(
  310. results,
  311. 'COCO',
  312. eval_dataset.coco_gt,
  313. with_background=True,
  314. resolution=self.mask_head_resolution)
  315. if metric == 'VOC':
  316. if isinstance(ap_stats[0], np.ndarray) and isinstance(ap_stats[1],
  317. np.ndarray):
  318. metrics = OrderedDict(
  319. zip(['bbox_map', 'segm_map'],
  320. [ap_stats[0][1], ap_stats[1][1]]))
  321. else:
  322. metrics = OrderedDict(
  323. zip(['bbox_map', 'segm_map'], [0.0, 0.0]))
  324. elif metric == 'COCO':
  325. if isinstance(ap_stats[0], np.ndarray) and isinstance(ap_stats[1],
  326. np.ndarray):
  327. metrics = OrderedDict(
  328. zip(['bbox_mmap', 'segm_mmap'],
  329. [ap_stats[0][0], ap_stats[1][0]]))
  330. else:
  331. metrics = OrderedDict(
  332. zip(['bbox_mmap', 'segm_mmap'], [0.0, 0.0]))
  333. if return_details:
  334. return metrics, eval_details
  335. return metrics
  336. @staticmethod
  337. def _postprocess(res, batch_size, num_classes, mask_head_resolution,
  338. labels):
  339. clsid2catid = dict({i: i for i in range(num_classes)})
  340. xywh_results = bbox2out([res], clsid2catid)
  341. segm_results = mask2out([res], clsid2catid, mask_head_resolution)
  342. preds = [[] for i in range(batch_size)]
  343. import pycocotools.mask as mask_util
  344. for index, xywh_res in enumerate(xywh_results):
  345. image_id = xywh_res['image_id']
  346. del xywh_res['image_id']
  347. xywh_res['mask'] = mask_util.decode(segm_results[index][
  348. 'segmentation'])
  349. xywh_res['category'] = labels[xywh_res['category_id']]
  350. preds[image_id].append(xywh_res)
  351. return preds
  352. def predict(self, img_file, transforms=None):
  353. """预测。
  354. Args:
  355. img_file(str|np.ndarray): 预测图像路径,或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
  356. transforms (paddlex.det.transforms): 数据预处理操作。
  357. Returns:
  358. lict: 预测结果列表,每个预测结果由预测框类别标签、预测框类别名称、
  359. 预测框坐标(坐标格式为[xmin, ymin, w, h])、
  360. 原图大小的预测二值图(1表示预测框类别,0表示背景类)、
  361. 预测框得分组成。
  362. """
  363. if transforms is None and not hasattr(self, 'test_transforms'):
  364. raise Exception("transforms need to be defined, now is None.")
  365. if isinstance(img_file, (str, np.ndarray)):
  366. images = [img_file]
  367. else:
  368. raise Exception("img_file must be str/np.ndarray")
  369. if transforms is None:
  370. transforms = self.test_transforms
  371. input_channel = getattr(self, 'input_channel', 3)
  372. im, im_resize_info, im_shape = FasterRCNN._preprocess(
  373. images,
  374. transforms,
  375. self.model_type,
  376. self.__class__.__name__,
  377. input_channel=input_channel)
  378. with fluid.scope_guard(self.scope):
  379. result = self.exe.run(self.test_prog,
  380. feed={
  381. 'image': im,
  382. 'im_info': im_resize_info,
  383. 'im_shape': im_shape
  384. },
  385. fetch_list=list(self.test_outputs.values()),
  386. return_numpy=False,
  387. use_program_cache=True)
  388. res = {
  389. k: (np.array(v), v.recursive_sequence_lengths())
  390. for k, v in zip(list(self.test_outputs.keys()), result)
  391. }
  392. res['im_id'] = (np.array(
  393. [[i] for i in range(len(images))]).astype('int32'), [])
  394. res['im_shape'] = (np.array(im_shape), [])
  395. preds = MaskRCNN._postprocess(res,
  396. len(images), self.num_classes,
  397. self.mask_head_resolution, self.labels)
  398. return preds[0]
  399. def batch_predict(self, img_file_list, transforms=None):
  400. """预测。
  401. Args:
  402. img_file_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
  403. 也可以是解码后的排列格式为(H,W,C)且类型为float32且为BGR格式的数组。
  404. transforms (paddlex.det.transforms): 数据预处理操作。
  405. Returns:
  406. dict: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个预测结果由预测框类别标签、预测框类别名称、
  407. 预测框坐标(坐标格式为[xmin, ymin, w, h])、
  408. 原图大小的预测二值图(1表示预测框类别,0表示背景类)、
  409. 预测框得分组成。
  410. """
  411. if transforms is None and not hasattr(self, 'test_transforms'):
  412. raise Exception("transforms need to be defined, now is None.")
  413. if not isinstance(img_file_list, (list, tuple)):
  414. raise Exception("im_file must be list/tuple")
  415. if transforms is None:
  416. transforms = self.test_transforms
  417. input_channel = getattr(self, 'input_channel', 3)
  418. im, im_resize_info, im_shape = FasterRCNN._preprocess(
  419. img_file_list,
  420. transforms,
  421. self.model_type,
  422. self.__class__.__name__,
  423. self.thread_pool,
  424. input_channel=input_channel)
  425. with fluid.scope_guard(self.scope):
  426. result = self.exe.run(self.test_prog,
  427. feed={
  428. 'image': im,
  429. 'im_info': im_resize_info,
  430. 'im_shape': im_shape
  431. },
  432. fetch_list=list(self.test_outputs.values()),
  433. return_numpy=False,
  434. use_program_cache=True)
  435. res = {
  436. k: (np.array(v), v.recursive_sequence_lengths())
  437. for k, v in zip(list(self.test_outputs.keys()), result)
  438. }
  439. res['im_id'] = (np.array(
  440. [[i] for i in range(len(img_file_list))]).astype('int32'), [])
  441. res['im_shape'] = (np.array(im_shape), [])
  442. preds = MaskRCNN._postprocess(res,
  443. len(img_file_list), self.num_classes,
  444. self.mask_head_resolution, self.labels)
  445. return preds