base.py
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import os.path as osp
from functools import partial
import time
import copy
import math
import yaml
import json

import paddle
from paddle.io import DataLoader, DistributedBatchSampler
from paddleslim import QAT
from paddleslim.analysis import flops
from paddleslim import L1NormFilterPruner, FPGMFilterPruner

import paddlex
from paddlex.cv.transforms import arrange_transforms
from paddlex.utils import (seconds_to_hms, get_single_card_bs, dict2str,
                           get_pretrain_weights, load_pretrain_weights,
                           SmoothedValue, TrainingStats,
                           _get_shared_memory_size_in_M, EarlyStop)
import paddlex.utils.logging as logging
from .slim.prune import _pruner_eval_fn, _pruner_template_input, sensitive_prune


class BaseModel:
    def __init__(self, model_type):
        self.model_type = model_type
        self.num_classes = None
        self.labels = None
        self.version = paddlex.__version__
        self.net = None
        self.optimizer = None
        self.test_inputs = None
        self.train_data_loader = None
        self.eval_data_loader = None
        self.eval_metrics = None
        # Whether to synchronize BatchNorm mean and variance across multiple cards
        self.sync_bn = False
        self.status = 'Normal'
        # Number of completed epochs; used as the starting epoch when resuming training
        self.completed_epochs = 0
        self.pruner = None
        self.pruning_ratios = None
        self.quantizer = None
        self.quant_config = None

    def net_initialize(self, pretrain_weights=None, save_dir='.'):
        if pretrain_weights is not None and \
                not os.path.exists(pretrain_weights):
            if not os.path.isdir(save_dir):
                if os.path.exists(save_dir):
                    os.remove(save_dir)
                os.makedirs(save_dir)
            if self.model_type == 'classifier':
                pretrain_weights = get_pretrain_weights(
                    pretrain_weights, self.model_name, save_dir)
            else:
                backbone_name = getattr(self, 'backbone_name', None)
                pretrain_weights = get_pretrain_weights(
                    pretrain_weights,
                    self.__class__.__name__,
                    save_dir,
                    backbone_name=backbone_name)
        if pretrain_weights is not None:
            load_pretrain_weights(
                self.net, pretrain_weights, model_name=self.model_name)
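
    # Illustrative call (a sketch with hypothetical names; assumes self.net
    # has already been built by the subclass constructor). Subclasses
    # typically initialize weights right before training, e.g.:
    #
    #   model.net_initialize(pretrain_weights='IMAGENET',
    #                        save_dir='output/pretrain')
    #
    # where a name like 'IMAGENET' is resolved to downloadable weights by
    # get_pretrain_weights().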

    def get_model_info(self):
        info = dict()
        info['version'] = paddlex.__version__
        info['Model'] = self.__class__.__name__
        info['_Attributes'] = {'model_type': self.model_type}
        if 'self' in self.init_params:
            del self.init_params['self']
        if '__class__' in self.init_params:
            del self.init_params['__class__']
        if 'model_name' in self.init_params:
            del self.init_params['model_name']
        if 'params' in self.init_params:
            del self.init_params['params']
        info['_init_params'] = self.init_params

        info['_Attributes']['num_classes'] = self.num_classes
        info['_Attributes']['labels'] = self.labels
        try:
            primary_metric_key = list(self.eval_metrics.keys())[0]
            primary_metric_value = float(self.eval_metrics[primary_metric_key])
            info['_Attributes']['eval_metrics'] = {
                primary_metric_key: primary_metric_value
            }
        except Exception:
            # eval_metrics may be unset before the first evaluation
            pass

        if hasattr(self, 'test_transforms'):
            if self.test_transforms is not None:
                info['Transforms'] = list()
                for op in self.test_transforms.transforms:
                    name = op.__class__.__name__
                    if name.startswith('Arrange'):
                        continue
                    attr = op.__dict__
                    info['Transforms'].append({name: attr})
        info['completed_epochs'] = self.completed_epochs
        return info

    def get_pruning_info(self):
        info = dict()
        info['pruner'] = self.pruner.__class__.__name__
        info['pruning_ratios'] = self.pruning_ratios
        pruner_inputs = self.pruner.inputs
        if self.model_type == 'detector':
            pruner_inputs = {
                k: v.tolist()
                for k, v in pruner_inputs[0].items()
            }
        info['pruner_inputs'] = pruner_inputs
        return info

    def get_quant_info(self):
        info = dict()
        info['quant_config'] = self.quant_config
        return info

    def save_model(self, save_dir):
        if not osp.isdir(save_dir):
            if osp.exists(save_dir):
                os.remove(save_dir)
            os.makedirs(save_dir)
        model_info = self.get_model_info()
        model_info['status'] = self.status

        paddle.save(self.net.state_dict(),
                    osp.join(save_dir, 'model.pdparams'))
        paddle.save(self.optimizer.state_dict(),
                    osp.join(save_dir, 'model.pdopt'))
        with open(
                osp.join(save_dir, 'model.yml'), encoding='utf-8',
                mode='w') as f:
            yaml.dump(model_info, f)

        # Save the evaluation details
        if hasattr(self, 'eval_details'):
            with open(osp.join(save_dir, 'eval_details.json'), 'w') as f:
                json.dump(self.eval_details, f)

        if self.status == 'Pruned' and self.pruner is not None:
            pruning_info = self.get_pruning_info()
            with open(
                    osp.join(save_dir, 'prune.yml'), encoding='utf-8',
                    mode='w') as f:
                yaml.dump(pruning_info, f)

        if self.status == 'Quantized' and self.quantizer is not None:
            quant_info = self.get_quant_info()
            with open(
                    osp.join(save_dir, 'quant.yml'), encoding='utf-8',
                    mode='w') as f:
                yaml.dump(quant_info, f)

        # Marker file indicating the model was saved successfully
        open(osp.join(save_dir, '.success'), 'w').close()
        logging.info("Model saved in {}.".format(save_dir))

    def build_data_loader(self, dataset, batch_size, mode='train'):
        if dataset.num_samples < batch_size:
            raise Exception(
                'The volume of dataset({}) must be no less than batch size({}).'
                .format(dataset.num_samples, batch_size))
        batch_size_each_card = get_single_card_bs(batch_size=batch_size)
        # TODO: an extra check is needed for detection in the eval phase
        batch_sampler = DistributedBatchSampler(
            dataset,
            batch_size=batch_size_each_card,
            shuffle=dataset.shuffle,
            drop_last=mode == 'train')

        shm_size = _get_shared_memory_size_in_M()
        if shm_size is None or shm_size < 1024.:
            use_shared_memory = False
        else:
            use_shared_memory = True

        loader = DataLoader(
            dataset,
            batch_sampler=batch_sampler,
            collate_fn=dataset.batch_transforms,
            num_workers=dataset.num_workers,
            return_list=True,
            use_shared_memory=use_shared_memory)
        return loader
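
    # Minimal usage sketch (hypothetical dataset object; assumes it carries
    # the attributes the loader relies on: num_samples, shuffle,
    # batch_transforms and num_workers):
    #
    #   loader = model.build_data_loader(train_dataset, batch_size=16,
    #                                    mode='train')
    #   for step, data in enumerate(loader()):
    #       outputs = model.run(model.net, data, mode='train')
    #
    # Note that batch_size is the total across cards; each card receives
    # get_single_card_bs(batch_size) samples per step.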

    def train_loop(self,
                   num_epochs,
                   train_dataset,
                   train_batch_size,
                   eval_dataset=None,
                   save_interval_epochs=1,
                   log_interval_steps=10,
                   save_dir='output',
                   ema=None,
                   early_stop=False,
                   early_stop_patience=5,
                   use_vdl=True):
        arrange_transforms(
            model_type=self.model_type,
            transforms=train_dataset.transforms,
            mode='train')

        nranks = paddle.distributed.get_world_size()
        local_rank = paddle.distributed.get_rank()
        if nranks > 1:
            find_unused_parameters = getattr(self, 'find_unused_parameters',
                                             False)
            # Initialize parallel environment if not done.
            if not paddle.distributed.parallel.parallel_helper._is_parallel_ctx_initialized(
            ):
                paddle.distributed.init_parallel_env()
                ddp_net = paddle.DataParallel(
                    self.net, find_unused_parameters=find_unused_parameters)
            else:
                ddp_net = paddle.DataParallel(
                    self.net, find_unused_parameters=find_unused_parameters)

        if use_vdl:
            from visualdl import LogWriter
            vdl_logdir = osp.join(save_dir, 'vdl_log')
            log_writer = LogWriter(vdl_logdir)
        # task_id is currently assigned by the PaddleX GUI; it is used to
        # tag the owning task id in the VisualDL logs.
        task_id = getattr(paddlex, "task_id", "")

        thresh = .0001
        if early_stop:
            earlystop = EarlyStop(early_stop_patience, thresh)

        self.train_data_loader = self.build_data_loader(
            train_dataset, batch_size=train_batch_size, mode='train')
        if eval_dataset is not None:
            self.test_transforms = copy.deepcopy(eval_dataset.transforms)

        start_epoch = self.completed_epochs
        train_step_time = SmoothedValue(log_interval_steps)
        train_step_each_epoch = math.floor(train_dataset.num_samples /
                                           train_batch_size)
        train_total_step = train_step_each_epoch * (num_epochs - start_epoch)
        if eval_dataset is not None:
            eval_batch_size = train_batch_size
            eval_epoch_time = 0
            best_accuracy_key = ""
            best_accuracy = -1.0
            best_model_epoch = -1

        current_step = 0
        for i in range(start_epoch, num_epochs):
            self.net.train()
            if callable(
                    getattr(self.train_data_loader.dataset, 'set_epoch',
                            None)):
                self.train_data_loader.dataset.set_epoch(i)
            train_avg_metrics = TrainingStats()
            step_time_tic = time.time()

            for step, data in enumerate(self.train_data_loader()):
                if nranks > 1:
                    outputs = self.run(ddp_net, data, mode='train')
                else:
                    outputs = self.run(self.net, data, mode='train')
                loss = outputs['loss']
                loss.backward()
                self.optimizer.step()
                self.optimizer.clear_grad()
                lr = self.optimizer.get_lr()
                if isinstance(self.optimizer._learning_rate,
                              paddle.optimizer.lr.LRScheduler):
                    self.optimizer._learning_rate.step()

                train_avg_metrics.update(outputs)
                outputs['lr'] = lr
                if ema is not None:
                    ema.update(self.net)
                step_time_toc = time.time()
                train_step_time.update(step_time_toc - step_time_tic)
                step_time_tic = step_time_toc
                current_step += 1

                # Output loss information every log_interval_steps
                if current_step % log_interval_steps == 0 and local_rank == 0:
                    if use_vdl:
                        for k, v in outputs.items():
                            log_writer.add_scalar(
                                '{}-Metrics/Training(Step): {}'.format(
                                    task_id, k), v, current_step)

                    # Estimate the remaining time
                    avg_step_time = train_step_time.avg()
                    eta = avg_step_time * (train_total_step - current_step)
                    if eval_dataset is not None:
                        eval_num_epochs = math.ceil(
                            (num_epochs - i - 1) / save_interval_epochs)
                        if eval_epoch_time == 0:
                            eta += avg_step_time * math.ceil(
                                eval_dataset.num_samples / eval_batch_size)
                        else:
                            eta += eval_epoch_time * eval_num_epochs
                    logging.info(
                        "[TRAIN] Epoch={}/{}, Step={}/{}, {}, time_each_step={}s, eta={}"
                        .format(i + 1, num_epochs, step + 1,
                                train_step_each_epoch,
                                dict2str(outputs),
                                round(avg_step_time, 2), seconds_to_hms(eta)))

            logging.info('[TRAIN] Epoch {} finished, {} .'
                         .format(i + 1, train_avg_metrics.log()))
            self.completed_epochs += 1

            # Every save_interval_epochs, evaluate on the validation set
            # and save the model.
            if ema is not None:
                weight = self.net.state_dict()
                self.net.set_dict(ema.apply())
            eval_epoch_tic = time.time()
            if (i + 1) % save_interval_epochs == 0 or i == num_epochs - 1:
                if eval_dataset is not None and eval_dataset.num_samples > 0:
                    eval_result = self.evaluate(
                        eval_dataset,
                        batch_size=eval_batch_size,
                        return_details=True)
                    # Save the best model
                    if local_rank == 0:
                        self.eval_metrics, self.eval_details = eval_result
                        logging.info('[EVAL] Finished, Epoch={}, {} .'.format(
                            i + 1, dict2str(self.eval_metrics)))
                        best_accuracy_key = list(self.eval_metrics.keys())[0]
                        current_accuracy = self.eval_metrics[best_accuracy_key]
                        if current_accuracy > best_accuracy:
                            best_accuracy = current_accuracy
                            best_model_epoch = i + 1
                            best_model_dir = osp.join(save_dir, "best_model")
                            self.save_model(save_dir=best_model_dir)
                        if best_model_epoch > 0:
                            logging.info(
                                'Current evaluated best model in eval_dataset is epoch_{}, {}={}'
                                .format(best_model_epoch, best_accuracy_key,
                                        best_accuracy))
                    eval_epoch_time = time.time() - eval_epoch_tic
                current_save_dir = osp.join(save_dir, "epoch_{}".format(i + 1))
                if local_rank == 0:
                    self.save_model(save_dir=current_save_dir)
                    if eval_dataset is not None and early_stop:
                        if earlystop(current_accuracy):
                            break
            if ema is not None:
                self.net.set_dict(weight)
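
    # Illustrative sketch of how subclasses drive this loop (hypothetical
    # names): a public train() builds self.net, self.optimizer and the
    # datasets, then delegates here, e.g.:
    #
    #   model = paddlex.cls.MobileNetV3_small(num_classes=10)
    #   model.train(num_epochs=20,
    #               train_dataset=train_ds,
    #               train_batch_size=32,
    #               eval_dataset=eval_ds,
    #               save_dir='output',
    #               use_vdl=True)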

    def analyze_sensitivity(self,
                            dataset,
                            batch_size=8,
                            criterion='l1_norm',
                            save_dir='output'):
        """
        Args:
            dataset (paddlex.dataset): Dataset used for evaluation during sensitivity analysis.
            batch_size (int, optional): Batch size used in evaluation. Defaults to 8.
            criterion ({'l1_norm', 'fpgm'}, optional): Pruning criterion. Defaults to 'l1_norm'.
            save_dir (str, optional): Directory to save the sensitivity file of the model.
                Defaults to 'output'.
        """
        if self.__class__.__name__ in ['FasterRCNN', 'MaskRCNN']:
            raise Exception("{} does not support pruning currently!".format(
                self.__class__.__name__))

        assert criterion in ['l1_norm', 'fpgm'], \
            "Pruning criterion {} is not supported. " \
            "Please choose from ['l1_norm', 'fpgm']".format(criterion)
        arrange_transforms(
            model_type=self.model_type,
            transforms=dataset.transforms,
            mode='eval')
        if self.model_type == 'detector':
            self.net.eval()
        else:
            self.net.train()
        inputs = _pruner_template_input(
            sample=dataset[0], model_type=self.model_type)
        if criterion == 'l1_norm':
            self.pruner = L1NormFilterPruner(self.net, inputs=inputs)
        else:
            self.pruner = FPGMFilterPruner(self.net, inputs=inputs)

        if not osp.isdir(save_dir):
            os.makedirs(save_dir)
        sen_file = osp.join(save_dir, 'model.sensi.data')
        logging.info('Sensitivity analysis of model parameters starts...')
        self.pruner.sensitive(
            eval_func=partial(_pruner_eval_fn, self, dataset, batch_size),
            sen_file=sen_file)
        logging.info(
            'Sensitivity analysis is complete. The result is saved at {}.'.
            format(sen_file))
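
    # Illustrative call (hypothetical names; assumes a trained, prunable
    # model, i.e. not FasterRCNN/MaskRCNN):
    #
    #   model.analyze_sensitivity(dataset=eval_ds, batch_size=8,
    #                             criterion='fpgm', save_dir='output')
    #   # writes output/model.sensi.data for the subsequent prune() call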

    def prune(self, pruned_flops, save_dir=None):
        """
        Args:
            pruned_flops (float): Ratio of FLOPs to be pruned.
            save_dir (None or str, optional): If None, the pruned model will not be saved.
                Otherwise, the pruned model will be saved at save_dir. Defaults to None.
        """
        if self.status == "Pruned":
            raise Exception(
                "A pruned model cannot be pruned again!")
        pre_pruning_flops = flops(self.net, self.pruner.inputs)
        logging.info("Pre-pruning FLOPs: {}. Pruning starts...".format(
            pre_pruning_flops))
        _, self.pruning_ratios = sensitive_prune(self.pruner, pruned_flops)
        post_pruning_flops = flops(self.net, self.pruner.inputs)
        logging.info("Pruning is complete. Post-pruning FLOPs: {}".format(
            post_pruning_flops))
        logging.warning("Pruning the model may hurt its performance; "
                        "retraining is highly recommended.")
        self.status = 'Pruned'

        if save_dir is not None:
            self.save_model(save_dir)
            logging.info("Pruned model is saved at {}".format(save_dir))

    def _prepare_qat(self, quant_config):
        if quant_config is None:
            # Default quantization configuration
            quant_config = {
                # {None, 'PACT'}. Weight preprocess type. If None, no preprocessing is performed.
                'weight_preprocess_type': None,
                # {None, 'PACT'}. Activation preprocess type. If None, no preprocessing is performed.
                'activation_preprocess_type': None,
                # {'abs_max', 'channel_wise_abs_max', 'range_abs_max', 'moving_average_abs_max'}.
                # Weight quantization type.
                'weight_quantize_type': 'channel_wise_abs_max',
                # {'abs_max', 'range_abs_max', 'moving_average_abs_max'}. Activation quantization type.
                'activation_quantize_type': 'moving_average_abs_max',
                # The number of bits of weights after quantization.
                'weight_bits': 8,
                # The number of bits of activation after quantization.
                'activation_bits': 8,
                # Data type after quantization, such as 'uint8', 'int8', etc.
                'dtype': 'int8',
                # Window size for 'range_abs_max' quantization.
                'window_size': 10000,
                # Decay coefficient of moving average.
                'moving_rate': .9,
                # Types of layers that will be quantized.
                'quantizable_layer_type': ['Conv2D', 'Linear']
            }
        self.quant_config = quant_config
        self.quantizer = QAT(config=self.quant_config)
        logging.info("Preparing the model for quantization-aware training...")
        self.quantizer.quantize(self.net)
        logging.info("Model is ready for quantization-aware training.")
        self.status = 'Quantized'
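
    # Illustrative override of the defaults above (a sketch; assumes
    # paddleslim's QAT fills any unspecified keys with its own defaults).
    # Here PACT preprocessing is enabled for activations:
    #
    #   custom_config = {
    #       'activation_preprocess_type': 'PACT',
    #       'weight_quantize_type': 'channel_wise_abs_max',
    #       'activation_quantize_type': 'moving_average_abs_max',
    #       'weight_bits': 8,
    #       'activation_bits': 8,
    #   }
    #   model._prepare_qat(quant_config=custom_config)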

    def _export_inference_model(self, save_dir, image_shape=None):
        save_dir = osp.join(save_dir, 'inference_model')
        self.net.eval()
        self.test_inputs = self._get_test_inputs(image_shape)

        if self.status == 'Quantized':
            self.quantizer.save_quantized_model(self.net,
                                                osp.join(save_dir, 'model'),
                                                self.test_inputs)
            quant_info = self.get_quant_info()
            with open(
                    osp.join(save_dir, 'quant.yml'), encoding='utf-8',
                    mode='w') as f:
                yaml.dump(quant_info, f)
        else:
            static_net = paddle.jit.to_static(
                self.net, input_spec=self.test_inputs)
            paddle.jit.save(static_net, osp.join(save_dir, 'model'))

        if self.status == 'Pruned':
            pruning_info = self.get_pruning_info()
            with open(
                    osp.join(save_dir, 'prune.yml'), encoding='utf-8',
                    mode='w') as f:
                yaml.dump(pruning_info, f)

        model_info = self.get_model_info()
        model_info['status'] = 'Infer'
        with open(
                osp.join(save_dir, 'model.yml'), encoding='utf-8',
                mode='w') as f:
            yaml.dump(model_info, f)

        # Marker file indicating the model was saved successfully
        open(osp.join(save_dir, '.success'), 'w').close()
        logging.info("The model for the inference deployment is saved in {}.".
                     format(save_dir))
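
    # Illustrative export call (a sketch; the image_shape format is defined
    # by each subclass's _get_test_inputs(), so None is the safe default):
    #
    #   model._export_inference_model(save_dir='output')
    #
    # For the non-quantized path, paddle.jit.save writes the static graph
    # under output/inference_model/ (model.pdmodel plus its parameter
    # files), alongside model.yml and the .success marker written above.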