|
@@ -279,7 +279,11 @@ class BaseClassifier(BaseAPI):
|
|
|
return metrics
|
|
return metrics
|
|
|
|
|
|
|
|
@staticmethod
|
|
@staticmethod
|
|
|
- def _preprocess(images, transforms, model_type, class_name, thread_pool=None):
|
|
|
|
|
|
|
+ def _preprocess(images,
|
|
|
|
|
+ transforms,
|
|
|
|
|
+ model_type,
|
|
|
|
|
+ class_name,
|
|
|
|
|
+ thread_pool=None):
|
|
|
arrange_transforms(
|
|
arrange_transforms(
|
|
|
model_type=model_type,
|
|
model_type=model_type,
|
|
|
class_name=class_name,
|
|
class_name=class_name,
|
|
@@ -343,10 +347,7 @@ class BaseClassifier(BaseAPI):
|
|
|
|
|
|
|
|
return preds[0]
|
|
return preds[0]
|
|
|
|
|
|
|
|
- def batch_predict(self,
|
|
|
|
|
- img_file_list,
|
|
|
|
|
- transforms=None,
|
|
|
|
|
- topk=1):
|
|
|
|
|
|
|
+ def batch_predict(self, img_file_list, transforms=None, topk=1):
|
|
|
"""预测。
|
|
"""预测。
|
|
|
Args:
|
|
Args:
|
|
|
img_file_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
|
|
img_file_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
|
|
@@ -365,9 +366,9 @@ class BaseClassifier(BaseAPI):
|
|
|
|
|
|
|
|
if transforms is None:
|
|
if transforms is None:
|
|
|
transforms = self.test_transforms
|
|
transforms = self.test_transforms
|
|
|
- im = BaseClassifier._preprocess(img_file_list, transforms,
|
|
|
|
|
- self.model_type,
|
|
|
|
|
- self.__class__.__name__, self.thread_pool)
|
|
|
|
|
|
|
+ im = BaseClassifier._preprocess(
|
|
|
|
|
+ img_file_list, transforms, self.model_type,
|
|
|
|
|
+ self.__class__.__name__, self.thread_pool)
|
|
|
|
|
|
|
|
with fluid.scope_guard(self.scope):
|
|
with fluid.scope_guard(self.scope):
|
|
|
result = self.exe.run(self.test_prog,
|
|
result = self.exe.run(self.test_prog,
|
|
@@ -409,6 +410,64 @@ class ResNet50_vd(BaseClassifier):
|
|
|
super(ResNet50_vd, self).__init__(
|
|
super(ResNet50_vd, self).__init__(
|
|
|
model_name='ResNet50_vd', num_classes=num_classes)
|
|
model_name='ResNet50_vd', num_classes=num_classes)
|
|
|
|
|
|
|
|
|
|
+ def train(self,
|
|
|
|
|
+ num_epochs,
|
|
|
|
|
+ train_dataset,
|
|
|
|
|
+ train_batch_size=64,
|
|
|
|
|
+ eval_dataset=None,
|
|
|
|
|
+ save_interval_epochs=1,
|
|
|
|
|
+ log_interval_steps=2,
|
|
|
|
|
+ save_dir='output',
|
|
|
|
|
+ pretrain_weights='BAIDU10W',
|
|
|
|
|
+ optimizer=None,
|
|
|
|
|
+ learning_rate=0.025,
|
|
|
|
|
+ warmup_steps=0,
|
|
|
|
|
+ warmup_start_lr=0.0,
|
|
|
|
|
+ lr_decay_epochs=[30, 60, 90],
|
|
|
|
|
+ lr_decay_gamma=0.1,
|
|
|
|
|
+ use_vdl=False,
|
|
|
|
|
+ sensitivities_file=None,
|
|
|
|
|
+ eval_metric_loss=0.05,
|
|
|
|
|
+ early_stop=False,
|
|
|
|
|
+ early_stop_patience=5,
|
|
|
|
|
+ resume_checkpoint=None):
|
|
|
|
|
+ """训练。
|
|
|
|
|
+ Args:
|
|
|
|
|
+ num_epochs (int): 训练迭代轮数。
|
|
|
|
|
+ train_dataset (paddlex.datasets): 训练数据读取器。
|
|
|
|
|
+ train_batch_size (int): 训练数据batch大小。同时作为验证数据batch大小。默认值为64。
|
|
|
|
|
+ eval_dataset (paddlex.datasets: 验证数据读取器。
|
|
|
|
|
+ save_interval_epochs (int): 模型保存间隔(单位:迭代轮数)。默认为1。
|
|
|
|
|
+ log_interval_steps (int): 训练日志输出间隔(单位:迭代步数)。默认为2。
|
|
|
|
|
+ save_dir (str): 模型保存路径。
|
|
|
|
|
+ pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',
|
|
|
|
|
+ 则自动下载在ImageNet图片数据上预训练的模型权重;若为None,则不使用预训练模型。若为'BAIDU10W',则自动下载百度自研10万类预训练。默认为'BAIDU10W'。
|
|
|
|
|
+ optimizer (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:
|
|
|
|
|
+ fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
|
|
|
|
|
+ learning_rate (float): 默认优化器的初始学习率。默认为0.025。
|
|
|
|
|
+ warmup_steps(int): 学习率从warmup_start_lr上升至设定的learning_rate,所需的步数,默认为0
|
|
|
|
|
+ warmup_start_lr(float): 学习率在warmup阶段时的起始值,默认为0.0
|
|
|
|
|
+ lr_decay_epochs (list): 默认优化器的学习率衰减轮数。默认为[30, 60, 90]。
|
|
|
|
|
+ lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。
|
|
|
|
|
+ use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。
|
|
|
|
|
+ sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
|
|
|
|
|
+ 则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
|
|
|
|
|
+ eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
|
|
|
|
|
+ early_stop (bool): 是否使用提前终止训练策略。默认值为False。
|
|
|
|
|
+ early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
|
|
|
|
|
+ 连续下降或持平,则终止训练。默认值为5。
|
|
|
|
|
+ resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
|
|
|
|
|
+ Raises:
|
|
|
|
|
+ ValueError: 模型从inference model进行加载。
|
|
|
|
|
+ """
|
|
|
|
|
+ return super(ResNet50_vd, self).train(
|
|
|
|
|
+ num_epochs, train_dataset, train_batch_size, eval_dataset,
|
|
|
|
|
+ save_interval_epochs, log_interval_steps, save_dir,
|
|
|
|
|
+ pretrain_weights, optimizer, learning_rate, warmup_steps,
|
|
|
|
|
+ warmup_start_lr, lr_decay_epochs, lr_decay_gamma, use_vdl,
|
|
|
|
|
+ sensitivities_file, eval_metric_loss, early_stop,
|
|
|
|
|
+ early_stop_patience, resume_checkpoint)
|
|
|
|
|
+
|
|
|
|
|
|
|
|
class ResNet101_vd(BaseClassifier):
|
|
class ResNet101_vd(BaseClassifier):
|
|
|
def __init__(self, num_classes=1000):
|
|
def __init__(self, num_classes=1000):
|