@@ -75,7 +75,7 @@ class BaseDetector(BaseModel):

     def _check_image_shape(self, image_shape):
         if len(image_shape) == 2:
-            image_shape = [None, 3] + image_shape
+            image_shape = [1, 3] + image_shape
         if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
             raise Exception(
                 "Height and width in fixed_input_shape must be a multiple of 32, but received {}.".
@@ -88,6 +88,7 @@ class BaseDetector(BaseModel):
             self._fix_transforms_shape(image_shape[-2:])
         else:
             image_shape = [None, 3, -1, -1]
+        self.fixed_input_shape = image_shape

         return self._define_input_spec(image_shape)

@@ -158,7 +159,8 @@ class BaseDetector(BaseModel):
               use_ema=False,
               early_stop=False,
               early_stop_patience=5,
-              use_vdl=True):
+              use_vdl=True,
+              resume_checkpoint=None):
         """
         Train the model.
         Args:
@@ -185,8 +187,15 @@ class BaseDetector(BaseModel):
             early_stop(bool, optional): Whether to adopt early stop strategy. Defaults to False.
             early_stop_patience(int, optional): Early stop patience. Defaults to 5.
             use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume training from.
+                If None, no training checkpoint will be resumed. `pretrain_weights` and `resume_checkpoint`
+                cannot be set simultaneously. Defaults to None.

         """
+        if pretrain_weights is not None and resume_checkpoint is not None:
+            logging.error(
+                "pretrain_weights and resume_checkpoint cannot be set simultaneously.",
+                exit=True)
         if train_dataset.__class__.__name__ == 'VOCDetection':
             train_dataset.data_fields = {
                 'im_id', 'image_shape', 'image', 'gt_bbox', 'gt_class',
@@ -253,7 +262,9 @@ class BaseDetector(BaseModel):
                     exit=True)
         pretrained_dir = osp.join(save_dir, 'pretrain')
         self.net_initialize(
-            pretrain_weights=pretrain_weights, save_dir=pretrained_dir)
+            pretrain_weights=pretrain_weights,
+            save_dir=pretrained_dir,
+            resume_checkpoint=resume_checkpoint)

         if use_ema:
             ema = ExponentialMovingAverage(
@@ -293,6 +304,7 @@ class BaseDetector(BaseModel):
                           early_stop=False,
                           early_stop_patience=5,
                           use_vdl=True,
+                          resume_checkpoint=None,
                           quant_config=None):
         """
         Quantization-aware training.
@@ -320,6 +332,8 @@ class BaseDetector(BaseModel):
             use_vdl(bool, optional): Whether to use VisualDL to monitor the training process. Defaults to True.
+            resume_checkpoint(str or None, optional): The path of the checkpoint to resume quantization-aware
+                training from. If None, no training checkpoint will be resumed. Defaults to None.
             quant_config(dict or None, optional): Quantization configuration. If None, a default rule of thumb
                 configuration will be used. Defaults to None.

         """
         self._prepare_qat(quant_config)
@@ -342,7 +356,8 @@ class BaseDetector(BaseModel):
             use_ema=use_ema,
             early_stop=early_stop,
             early_stop_patience=early_stop_patience,
-            use_vdl=use_vdl)
+            use_vdl=use_vdl,
+            resume_checkpoint=resume_checkpoint)

     def evaluate(self,
                  eval_dataset,
@@ -1020,6 +1035,7 @@ class FasterRCNN(BaseDetector):
                 self.test_transforms.transforms.append(
                     Padding(im_padding_value=[0., 0., 0.]))

+        self.fixed_input_shape = image_shape
         return self._define_input_spec(image_shape)


@@ -1414,14 +1430,10 @@ class PPYOLOv2(YOLOv3):

     def _get_test_inputs(self, image_shape):
         if image_shape is not None:
-            if len(image_shape) == 2:
-                image_shape = [None, 3] + image_shape
-            if image_shape[-2] % 32 > 0 or image_shape[-1] % 32 > 0:
-                raise Exception(
-                    "Height and width in fixed_input_shape must be a multiple of 32, but recieved is {}.".
-                    format(image_shape[-2:]))
+            image_shape = self._check_image_shape(image_shape)
             self._fix_transforms_shape(image_shape[-2:])
         else:
+            image_shape = [None, 3, 608, 608]
             logging.warning(
                 '[Important!!!] When exporting inference model for {},'.format(
                     self.__class__.__name__) +
@@ -1429,20 +1441,9 @@ class PPYOLOv2(YOLOv3):
                 +
                 'Please check image shape after transforms is [3, 608, 608], if not, fixed_input_shape '
                 + 'should be specified manually.')
-            image_shape = [None, 3, 608, 608]
-
-        input_spec = [{
-            "image": InputSpec(
-                shape=image_shape, name='image', dtype='float32'),
-            "im_shape": InputSpec(
-                shape=[image_shape[0], 2], name='im_shape', dtype='float32'),
-            "scale_factor": InputSpec(
-                shape=[image_shape[0], 2],
-                name='scale_factor',
-                dtype='float32')
-        }]

-        return input_spec
+        self.fixed_input_shape = image_shape
+        return self._define_input_spec(image_shape)


 class MaskRCNN(BaseDetector):
@@ -1741,5 +1742,6 @@ class MaskRCNN(BaseDetector):
         if self.with_fpn:
             self.test_transforms.transforms.append(
                 Padding(im_padding_value=[0., 0., 0.]))
+        self.fixed_input_shape = image_shape

         return self._define_input_spec(image_shape)
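Usage note: a minimal sketch of resuming detector training with the new `resume_checkpoint` argument added by this patch. The dataset, transforms, and `pdx.det.PPYOLOv2` entry point follow the public PaddleX 2.x API but are not part of this diff, and the checkpoint path is hypothetical.

# Minimal resume-training sketch (assumes the PaddleX 2.x API; paths are hypothetical).
import paddlex as pdx
from paddlex import transforms as T

# Basic training transforms; any valid pipeline works here.
train_transforms = T.Compose([
    T.RandomHorizontalFlip(),
    T.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

train_dataset = pdx.datasets.VOCDetection(
    data_dir='dataset',
    file_list='dataset/train_list.txt',
    label_list='dataset/labels.txt',
    transforms=train_transforms)

model = pdx.det.PPYOLOv2(num_classes=len(train_dataset.labels))

# Resume from a previously saved epoch checkpoint. pretrain_weights must be
# None here: with this patch, setting both pretrain_weights and
# resume_checkpoint aborts with an error.
model.train(
    num_epochs=100,
    train_dataset=train_dataset,
    pretrain_weights=None,
    resume_checkpoint='output/ppyolov2/epoch_10',
    save_dir='output/ppyolov2')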