# ppyolo.py
#
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from __future__ import absolute_import
  15. import math
  16. import tqdm
  17. import os.path as osp
  18. import numpy as np
  19. from multiprocessing.pool import ThreadPool
  20. import paddle
  21. import paddle.fluid as fluid
  22. from paddle.fluid.layers.learning_rate_scheduler import _decay_step_counter
  23. from paddle.fluid.optimizer import ExponentialMovingAverage
  24. import paddlex.utils.logging as logging
  25. import paddlex
  26. import copy
  27. from paddlex.cv.transforms import arrange_transforms
  28. from paddlex.cv.datasets import generate_minibatch
  29. from .base import BaseAPI
  30. from collections import OrderedDict
  31. from .utils.detection_eval import eval_results, bbox2out
  32. class PPYOLO(BaseAPI):
  33. """构建PPYOLO,并实现其训练、评估、预测和模型导出。
  34. Args:
  35. num_classes (int): 类别数。默认为80。
  36. backbone (str): PPYOLO的backbone网络,取值范围为['ResNet50_vd_ssld']。默认为'ResNet50_vd_ssld'。
  37. with_dcn_v2 (bool): Backbone是否使用DCNv2结构。默认为True。
  38. anchors (list|tuple): anchor框的宽度和高度,为None时表示使用默认值
  39. [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  40. [59, 119], [116, 90], [156, 198], [373, 326]]。
  41. anchor_masks (list|tuple): 在计算PPYOLO损失时,使用anchor的mask索引,为None时表示使用默认值
  42. [[6, 7, 8], [3, 4, 5], [0, 1, 2]]。
  43. use_coord_conv (bool): 是否使用CoordConv。默认值为True。
  44. use_iou_aware (bool): 是否使用IoU Aware分支。默认值为True。
  45. use_spp (bool): 是否使用Spatial Pyramid Pooling结构。默认值为True。
  46. use_drop_block (bool): 是否使用Drop Block。默认值为True。
  47. scale_x_y (float): 调整中心点位置时的系数因子。默认值为1.05。
  48. use_iou_loss (bool): 是否使用IoU loss。默认值为True。
  49. use_matrix_nms (bool): 是否使用Matrix NMS。默认值为True。
  50. ignore_threshold (float): 在计算PPYOLO损失时,IoU大于`ignore_threshold`的预测框的置信度被忽略。默认为0.7。
  51. nms_score_threshold (float): 检测框的置信度得分阈值,置信度得分低于阈值的框应该被忽略。默认为0.01。
  52. nms_topk (int): 进行NMS时,根据置信度保留的最大检测框数。默认为1000。
  53. nms_keep_topk (int): 进行NMS后,每个图像要保留的总检测框数。默认为100。
  54. nms_iou_threshold (float): 进行NMS时,用于剔除检测框IOU的阈值。默认为0.45。
  55. label_smooth (bool): 是否使用label smooth。默认值为False。
  56. train_random_shapes (list|tuple): 训练时从列表中随机选择图像大小。默认值为[320, 352, 384, 416, 448, 480, 512, 544, 576, 608]。
  57. """
  58. def __init__(
  59. self,
  60. num_classes=80,
  61. backbone='ResNet50_vd_ssld',
  62. with_dcn_v2=True,
  63. # YOLO Head
  64. anchors=None,
  65. anchor_masks=None,
  66. use_coord_conv=True,
  67. use_iou_aware=True,
  68. use_spp=True,
  69. use_drop_block=True,
  70. scale_x_y=1.05,
  71. # PPYOLO Loss
  72. ignore_threshold=0.7,
  73. label_smooth=False,
  74. use_iou_loss=True,
  75. # NMS
  76. use_matrix_nms=True,
  77. nms_score_threshold=0.01,
  78. nms_topk=1000,
  79. nms_keep_topk=100,
  80. nms_iou_threshold=0.45,
  81. train_random_shapes=[
  82. 320, 352, 384, 416, 448, 480, 512, 544, 576, 608
  83. ]):
  84. self.init_params = locals()
  85. super(PPYOLO, self).__init__('detector')
  86. backbones = ['ResNet50_vd_ssld']
  87. assert backbone in backbones, "backbone should be one of {}".format(
  88. backbones)
  89. self.backbone = backbone
  90. self.num_classes = num_classes
  91. self.anchors = anchors
  92. self.anchor_masks = anchor_masks
  93. if anchors is None:
  94. self.anchors = [[10, 13], [16, 30], [33, 23], [30, 61], [62, 45],
  95. [59, 119], [116, 90], [156, 198], [373, 326]]
  96. if anchor_masks is None:
  97. self.anchor_masks = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
  98. self.ignore_threshold = ignore_threshold
  99. self.nms_score_threshold = nms_score_threshold
  100. self.nms_topk = nms_topk
  101. self.nms_keep_topk = nms_keep_topk
  102. self.nms_iou_threshold = nms_iou_threshold
  103. self.label_smooth = label_smooth
  104. self.sync_bn = True
  105. self.train_random_shapes = train_random_shapes
  106. self.fixed_input_shape = None
  107. self.use_fine_grained_loss = False
  108. if use_coord_conv or use_iou_aware or use_spp or use_drop_block or use_iou_loss:
  109. self.use_fine_grained_loss = True
  110. self.use_coord_conv = use_coord_conv
  111. self.use_iou_aware = use_iou_aware
  112. self.use_spp = use_spp
  113. self.use_drop_block = use_drop_block
  114. self.use_iou_loss = use_iou_loss
  115. self.scale_x_y = scale_x_y
  116. self.max_height = 608
  117. self.max_width = 608
  118. self.use_matrix_nms = use_matrix_nms
  119. self.use_ema = False
  120. self.with_dcn_v2 = with_dcn_v2
  121. if paddle.__version__ < '1.8.4' and paddle.__version__ != '0.0.0':
  122. raise Exception(
  123. "PPYOLO requires paddlepaddle or paddlepaddle-gpu >= 1.8.4")
  124. def _get_backbone(self, backbone_name):
  125. if backbone_name.startswith('ResNet50_vd'):
  126. backbone = paddlex.cv.nets.ResNet(
  127. norm_type='sync_bn',
  128. layers=50,
  129. freeze_norm=False,
  130. norm_decay=0.,
  131. feature_maps=[3, 4, 5],
  132. freeze_at=0,
  133. variant='d',
  134. dcn_v2_stages=[5] if self.with_dcn_v2 else [])
  135. return backbone
  136. def build_net(self, mode='train'):
  137. model = paddlex.cv.nets.detection.YOLOv3(
  138. backbone=self._get_backbone(self.backbone),
  139. num_classes=self.num_classes,
  140. mode=mode,
  141. anchors=self.anchors,
  142. anchor_masks=self.anchor_masks,
  143. ignore_threshold=self.ignore_threshold,
  144. label_smooth=self.label_smooth,
  145. nms_score_threshold=self.nms_score_threshold,
  146. nms_topk=self.nms_topk,
  147. nms_keep_topk=self.nms_keep_topk,
  148. nms_iou_threshold=self.nms_iou_threshold,
  149. fixed_input_shape=self.fixed_input_shape,
  150. coord_conv=self.use_coord_conv,
  151. iou_aware=self.use_iou_aware,
  152. scale_x_y=self.scale_x_y,
  153. spp=self.use_spp,
  154. drop_block=self.use_drop_block,
  155. use_matrix_nms=self.use_matrix_nms,
  156. use_fine_grained_loss=self.use_fine_grained_loss,
  157. use_iou_loss=self.use_iou_loss,
  158. batch_size=self.batch_size_per_gpu
  159. if hasattr(self, 'batch_size_per_gpu') else 8)
  160. if mode == 'train' and self.use_iou_loss or self.use_iou_aware:
  161. model.max_height = self.max_height
  162. model.max_width = self.max_width
  163. inputs = model.generate_inputs()
  164. model_out = model.build_net(inputs)
  165. outputs = OrderedDict([('bbox', model_out)])
  166. if mode == 'train':
  167. self.optimizer.minimize(model_out)
  168. outputs = OrderedDict([('loss', model_out)])
  169. if self.use_ema:
  170. global_steps = _decay_step_counter()
  171. self.ema = ExponentialMovingAverage(
  172. self.ema_decay, thres_steps=global_steps)
  173. self.ema.update()
  174. return inputs, outputs
  175. def default_optimizer(self, learning_rate, warmup_steps, warmup_start_lr,
  176. lr_decay_epochs, lr_decay_gamma,
  177. num_steps_each_epoch):
  178. if warmup_steps > lr_decay_epochs[0] * num_steps_each_epoch:
  179. logging.error(
  180. "In function train(), parameters should satisfy: warmup_steps <= lr_decay_epochs[0]*num_samples_in_train_dataset",
  181. exit=False)
  182. logging.error(
  183. "See this doc for more information: https://github.com/PaddlePaddle/PaddleX/blob/develop/docs/appendix/parameters.md#notice",
  184. exit=False)
  185. logging.error(
  186. "warmup_steps should less than {} or lr_decay_epochs[0] greater than {}, please modify 'lr_decay_epochs' or 'warmup_steps' in train function".
  187. format(lr_decay_epochs[0] * num_steps_each_epoch, warmup_steps
  188. // num_steps_each_epoch))
  189. boundaries = [b * num_steps_each_epoch for b in lr_decay_epochs]
  190. values = [(lr_decay_gamma**i) * learning_rate
  191. for i in range(len(lr_decay_epochs) + 1)]
  192. lr_decay = fluid.layers.piecewise_decay(
  193. boundaries=boundaries, values=values)
  194. lr_warmup = fluid.layers.linear_lr_warmup(
  195. learning_rate=lr_decay,
  196. warmup_steps=warmup_steps,
  197. start_lr=warmup_start_lr,
  198. end_lr=learning_rate)
  199. optimizer = fluid.optimizer.Momentum(
  200. learning_rate=lr_warmup,
  201. momentum=0.9,
  202. regularization=fluid.regularizer.L2DecayRegularizer(5e-04))
  203. return optimizer
  204. def train(self,
  205. num_epochs,
  206. train_dataset,
  207. train_batch_size=8,
  208. eval_dataset=None,
  209. save_interval_epochs=20,
  210. log_interval_steps=2,
  211. save_dir='output',
  212. pretrain_weights='IMAGENET',
  213. optimizer=None,
  214. learning_rate=1.0 / 8000,
  215. warmup_steps=1000,
  216. warmup_start_lr=0.0,
  217. lr_decay_epochs=[213, 240],
  218. lr_decay_gamma=0.1,
  219. metric=None,
  220. use_vdl=False,
  221. sensitivities_file=None,
  222. eval_metric_loss=0.05,
  223. early_stop=False,
  224. early_stop_patience=5,
  225. resume_checkpoint=None,
  226. use_ema=True,
  227. ema_decay=0.9998):
  228. """训练。
  229. Args:
  230. num_epochs (int): 训练迭代轮数。
  231. train_dataset (paddlex.datasets): 训练数据读取器。
  232. train_batch_size (int): 训练数据batch大小。目前检测仅支持单卡评估,训练数据batch大小与显卡
  233. 数量之商为验证数据batch大小。默认值为8。
  234. eval_dataset (paddlex.datasets): 验证数据读取器。
  235. save_interval_epochs (int): 模型保存间隔(单位:迭代轮数)。默认为20。
  236. log_interval_steps (int): 训练日志输出间隔(单位:迭代次数)。默认为10。
  237. save_dir (str): 模型保存路径。默认值为'output'。
  238. pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',
  239. 则自动下载在ImageNet图片数据上预训练的模型权重;若为字符串'COCO',
  240. 则自动下载在COCO数据集上预训练的模型权重;若为None,则不使用预训练模型。默认为'IMAGENET'。
  241. optimizer (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:
  242. fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
  243. learning_rate (float): 默认优化器的学习率。默认为1.0/8000。
  244. warmup_steps (int): 默认优化器进行warmup过程的步数。默认为1000。
  245. warmup_start_lr (int): 默认优化器warmup的起始学习率。默认为0.0。
  246. lr_decay_epochs (list): 默认优化器的学习率衰减轮数。默认为[213, 240]。
  247. lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。
  248. metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认值为None。
  249. use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。
  250. sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
  251. 则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
  252. eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
  253. early_stop (bool): 是否使用提前终止训练策略。默认值为False。
  254. early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
  255. 连续下降或持平,则终止训练。默认值为5。
  256. resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
  257. use_ema (bool): 是否使用指数衰减计算参数的滑动平均值。默认值为True。
  258. ema_decay (float): 指数衰减率。默认值为0.9998。
  259. Raises:
  260. ValueError: 评估类型不在指定列表中。
  261. ValueError: 模型从inference model进行加载。
  262. """
  263. if not self.trainable:
  264. raise ValueError("Model is not trainable from load_model method.")
  265. if metric is None:
  266. if isinstance(train_dataset, paddlex.datasets.CocoDetection):
  267. metric = 'COCO'
  268. elif isinstance(train_dataset, paddlex.datasets.VOCDetection) or \
  269. isinstance(train_dataset, paddlex.datasets.EasyDataDet):
  270. metric = 'VOC'
  271. else:
  272. raise ValueError(
  273. "train_dataset should be datasets.VOCDetection or datasets.COCODetection or datasets.EasyDataDet."
  274. )
  275. assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
  276. self.metric = metric
  277. self.labels = train_dataset.labels
  278. # 构建训练网络
  279. if optimizer is None:
  280. # 构建默认的优化策略
  281. num_steps_each_epoch = train_dataset.num_samples // train_batch_size
  282. optimizer = self.default_optimizer(
  283. learning_rate=learning_rate,
  284. warmup_steps=warmup_steps,
  285. warmup_start_lr=warmup_start_lr,
  286. lr_decay_epochs=lr_decay_epochs,
  287. lr_decay_gamma=lr_decay_gamma,
  288. num_steps_each_epoch=num_steps_each_epoch)
  289. self.optimizer = optimizer
  290. self.use_ema = use_ema
  291. self.ema_decay = ema_decay
  292. self.batch_size_per_gpu = int(train_batch_size /
  293. paddlex.env_info['num'])
  294. if self.use_fine_grained_loss:
  295. for transform in train_dataset.transforms.transforms:
  296. if isinstance(transform, paddlex.det.transforms.Resize):
  297. self.max_height = transform.target_size
  298. self.max_width = transform.target_size
  299. break
  300. if train_dataset.transforms.batch_transforms is None:
  301. train_dataset.transforms.batch_transforms = list()
  302. define_random_shape = False
  303. for bt in train_dataset.transforms.batch_transforms:
  304. if isinstance(bt, paddlex.det.transforms.BatchRandomShape):
  305. define_random_shape = True
  306. if not define_random_shape:
  307. if isinstance(self.train_random_shapes,
  308. (list, tuple)) and len(self.train_random_shapes) > 0:
  309. train_dataset.transforms.batch_transforms.append(
  310. paddlex.det.transforms.BatchRandomShape(
  311. random_shapes=self.train_random_shapes))
  312. if self.use_fine_grained_loss:
  313. self.max_height = max(self.max_height,
  314. max(self.train_random_shapes))
  315. self.max_width = max(self.max_width,
  316. max(self.train_random_shapes))
  317. if self.use_fine_grained_loss:
  318. define_generate_target = False
  319. for bt in train_dataset.transforms.batch_transforms:
  320. if isinstance(bt, paddlex.det.transforms.GenerateYoloTarget):
  321. define_generate_target = True
  322. if not define_generate_target:
  323. train_dataset.transforms.batch_transforms.append(
  324. paddlex.det.transforms.GenerateYoloTarget(
  325. anchors=self.anchors,
  326. anchor_masks=self.anchor_masks,
  327. num_classes=self.num_classes,
  328. downsample_ratios=[32, 16, 8]))
  329. # 构建训练、验证、预测网络
  330. self.build_program()
  331. # 初始化网络权重
  332. self.net_initialize(
  333. startup_prog=fluid.default_startup_program(),
  334. pretrain_weights=pretrain_weights,
  335. save_dir=save_dir,
  336. sensitivities_file=sensitivities_file,
  337. eval_metric_loss=eval_metric_loss,
  338. resume_checkpoint=resume_checkpoint)
  339. # 训练
  340. self.train_loop(
  341. num_epochs=num_epochs,
  342. train_dataset=train_dataset,
  343. train_batch_size=train_batch_size,
  344. eval_dataset=eval_dataset,
  345. save_interval_epochs=save_interval_epochs,
  346. log_interval_steps=log_interval_steps,
  347. save_dir=save_dir,
  348. use_vdl=use_vdl,
  349. early_stop=early_stop,
  350. early_stop_patience=early_stop_patience)
  351. def evaluate(self,
  352. eval_dataset,
  353. batch_size=1,
  354. epoch_id=None,
  355. metric=None,
  356. return_details=False):
  357. """评估。
  358. Args:
  359. eval_dataset (paddlex.datasets): 验证数据读取器。
  360. batch_size (int): 验证数据批大小。默认为1。
  361. epoch_id (int): 当前评估模型所在的训练轮数。
  362. metric (bool): 训练过程中评估的方式,取值范围为['COCO', 'VOC']。默认为None,
  363. 根据用户传入的Dataset自动选择,如为VOCDetection,则metric为'VOC';
  364. 如为COCODetection,则metric为'COCO'。
  365. return_details (bool): 是否返回详细信息。
  366. Returns:
  367. tuple (metrics, eval_details) | dict (metrics): 当return_details为True时,返回(metrics, eval_details),
  368. 当return_details为False时,返回metrics。metrics为dict,包含关键字:'bbox_mmap'或者’bbox_map‘,
  369. 分别表示平均准确率平均值在各个IoU阈值下的结果取平均值的结果(mmAP)、平均准确率平均值(mAP)。
  370. eval_details为dict,包含bbox和gt两个关键字。其中关键字bbox的键值是一个列表,列表中每个元素代表一个预测结果,
  371. 一个预测结果是一个由图像id,预测框类别id, 预测框坐标,预测框得分组成的列表。而关键字gt的键值是真实标注框的相关信息。
  372. """
  373. arrange_transforms(
  374. model_type=self.model_type,
  375. class_name=self.__class__.__name__,
  376. transforms=eval_dataset.transforms,
  377. mode='eval')
  378. if metric is None:
  379. if hasattr(self, 'metric') and self.metric is not None:
  380. metric = self.metric
  381. else:
  382. if isinstance(eval_dataset, paddlex.datasets.CocoDetection):
  383. metric = 'COCO'
  384. elif isinstance(eval_dataset, paddlex.datasets.VOCDetection):
  385. metric = 'VOC'
  386. else:
  387. raise Exception(
  388. "eval_dataset should be datasets.VOCDetection or datasets.COCODetection."
  389. )
  390. assert metric in ['COCO', 'VOC'], "Metric only support 'VOC' or 'COCO'"
  391. total_steps = math.ceil(eval_dataset.num_samples * 1.0 / batch_size)
  392. results = list()
  393. data_generator = eval_dataset.generator(
  394. batch_size=batch_size, drop_last=False)
  395. logging.info(
  396. "Start to evaluating(total_samples={}, total_steps={})...".format(
  397. eval_dataset.num_samples, total_steps))
  398. for step, data in tqdm.tqdm(
  399. enumerate(data_generator()), total=total_steps):
  400. images = np.array([d[0] for d in data])
  401. im_sizes = np.array([d[1] for d in data])
  402. feed_data = {'image': images, 'im_size': im_sizes}
  403. with fluid.scope_guard(self.scope):
  404. outputs = self.exe.run(
  405. self.test_prog,
  406. feed=[feed_data],
  407. fetch_list=list(self.test_outputs.values()),
  408. return_numpy=False)
  409. res = {
  410. 'bbox': (np.array(outputs[0]),
  411. outputs[0].recursive_sequence_lengths())
  412. }
  413. res_id = [np.array([d[2]]) for d in data]
  414. res['im_id'] = (res_id, [])
  415. if metric == 'VOC':
  416. res_gt_box = [d[3].reshape(-1, 4) for d in data]
  417. res_gt_label = [d[4].reshape(-1, 1) for d in data]
  418. res_is_difficult = [d[5].reshape(-1, 1) for d in data]
  419. res_id = [np.array([d[2]]) for d in data]
  420. res['gt_box'] = (res_gt_box, [])
  421. res['gt_label'] = (res_gt_label, [])
  422. res['is_difficult'] = (res_is_difficult, [])
  423. results.append(res)
  424. logging.debug("[EVAL] Epoch={}, Step={}/{}".format(epoch_id, step +
  425. 1, total_steps))
  426. box_ap_stats, eval_details = eval_results(
  427. results, metric, eval_dataset.coco_gt, with_background=False)
  428. evaluate_metrics = OrderedDict(
  429. zip(['bbox_mmap'
  430. if metric == 'COCO' else 'bbox_map'], box_ap_stats))
  431. if return_details:
  432. return evaluate_metrics, eval_details
  433. return evaluate_metrics
  434. @staticmethod
  435. def _preprocess(images,
  436. transforms,
  437. model_type,
  438. class_name,
  439. thread_pool=None):
  440. arrange_transforms(
  441. model_type=model_type,
  442. class_name=class_name,
  443. transforms=transforms,
  444. mode='test')
  445. if thread_pool is not None:
  446. batch_data = thread_pool.map(transforms, images)
  447. else:
  448. batch_data = list()
  449. for image in images:
  450. batch_data.append(transforms(image))
  451. padding_batch = generate_minibatch(batch_data)
  452. im = np.array(
  453. [data[0] for data in padding_batch],
  454. dtype=padding_batch[0][0].dtype)
  455. im_size = np.array([data[1] for data in padding_batch], dtype=np.int32)
  456. return im, im_size
  457. @staticmethod
  458. def _postprocess(res, batch_size, num_classes, labels):
  459. clsid2catid = dict({i: i for i in range(num_classes)})
  460. xywh_results = bbox2out([res], clsid2catid)
  461. preds = [[] for i in range(batch_size)]
  462. for xywh_res in xywh_results:
  463. image_id = xywh_res['image_id']
  464. del xywh_res['image_id']
  465. xywh_res['category'] = labels[xywh_res['category_id']]
  466. preds[image_id].append(xywh_res)
  467. return preds
  468. def predict(self, img_file, transforms=None):
  469. """预测。
  470. Args:
  471. img_file (str|np.ndarray): 预测图像路径,或者是解码后的排列格式为(H, W, C)且类型为float32且为BGR格式的数组。
  472. transforms (paddlex.det.transforms): 数据预处理操作。
  473. Returns:
  474. list: 预测结果列表,每个预测结果由预测框类别标签、
  475. 预测框类别名称、预测框坐标(坐标格式为[xmin, ymin, w, h])、
  476. 预测框得分组成。
  477. """
  478. if transforms is None and not hasattr(self, 'test_transforms'):
  479. raise Exception("transforms need to be defined, now is None.")
  480. if isinstance(img_file, (str, np.ndarray)):
  481. images = [img_file]
  482. else:
  483. raise Exception("img_file must be str/np.ndarray")
  484. if transforms is None:
  485. transforms = self.test_transforms
  486. im, im_size = PPYOLO._preprocess(images, transforms, self.model_type,
  487. self.__class__.__name__)
  488. with fluid.scope_guard(self.scope):
  489. result = self.exe.run(self.test_prog,
  490. feed={'image': im,
  491. 'im_size': im_size},
  492. fetch_list=list(self.test_outputs.values()),
  493. return_numpy=False,
  494. use_program_cache=True)
  495. res = {
  496. k: (np.array(v), v.recursive_sequence_lengths())
  497. for k, v in zip(list(self.test_outputs.keys()), result)
  498. }
  499. res['im_id'] = (np.array(
  500. [[i] for i in range(len(images))]).astype('int32'), [[]])
  501. preds = PPYOLO._postprocess(res,
  502. len(images), self.num_classes, self.labels)
  503. return preds[0]
  504. def batch_predict(self, img_file_list, transforms=None):
  505. """预测。
  506. Args:
  507. img_file_list (list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径,也可以是解码后的排列格式为(H,W,C)
  508. 且类型为float32且为BGR格式的数组。
  509. transforms (paddlex.det.transforms): 数据预处理操作。
  510. Returns:
  511. list: 每个元素都为列表,表示各图像的预测结果。在各图像的预测结果列表中,每个预测结果由预测框类别标签、
  512. 预测框类别名称、预测框坐标(坐标格式为[xmin, ymin, w, h])、
  513. 预测框得分组成。
  514. """
  515. if transforms is None and not hasattr(self, 'test_transforms'):
  516. raise Exception("transforms need to be defined, now is None.")
  517. if not isinstance(img_file_list, (list, tuple)):
  518. raise Exception("im_file must be list/tuple")
  519. if transforms is None:
  520. transforms = self.test_transforms
  521. im, im_size = PPYOLO._preprocess(
  522. img_file_list, transforms, self.model_type,
  523. self.__class__.__name__, self.thread_pool)
  524. with fluid.scope_guard(self.scope):
  525. result = self.exe.run(self.test_prog,
  526. feed={'image': im,
  527. 'im_size': im_size},
  528. fetch_list=list(self.test_outputs.values()),
  529. return_numpy=False,
  530. use_program_cache=True)
  531. res = {
  532. k: (np.array(v), v.recursive_sequence_lengths())
  533. for k, v in zip(list(self.test_outputs.keys()), result)
  534. }
  535. res['im_id'] = (np.array(
  536. [[i] for i in range(len(img_file_list))]).astype('int32'), [[]])
  537. preds = PPYOLO._postprocess(res,
  538. len(img_file_list), self.num_classes,
  539. self.labels)
  540. return preds