소스 검색

add baidu10w pretrain weights for resetnet50_vd

FlyingQianMM 5 년 전
부모
커밋
c2f24a206d
4개의 변경된 파일110개의 추가작업 그리고 10개의 파일을 삭제
  1. 1 1
      docs/apis/models/classification.md
  2. 1 1
      paddlex/cv/models/base.py
  3. 67 8
      paddlex/cv/models/classifier.py
  4. 41 0
      paddlex/cv/models/utils/pretrain_weights.py

+ 1 - 1
docs/apis/models/classification.md

@@ -27,7 +27,7 @@ train(self, num_epochs, train_dataset, train_batch_size=64, eval_dataset=None, s
 > > - **save_interval_epochs** (int): 模型保存间隔(单位:迭代轮数)。默认为1。
 > > - **log_interval_steps** (int): 训练日志输出间隔(单位:迭代步数)。默认为2。
 > > - **save_dir** (str): 模型保存路径。
-> > - **pretrain_weights** (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',则自动下载在ImageNet图片数据上预训练的模型权重;若为None,则不使用预训练模型。默认为'IMAGENET'。
+> > - **pretrain_weights** (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',则自动下载在ImageNet图片数据上预训练的模型权重;若为None,则不使用预训练模型。默认为'IMAGENET'。若模型为'ResNet50_vd',则默认下载百度自研10万类预训练模型,即默认为'BAIDU10W'。
 > > - **optimizer** (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
 > > - **learning_rate** (float): 默认优化器的初始学习率。默认为0.025。
 > > - **warmup_steps** (int): 默认优化器的warmup步数,学习率将在设定的步数内,从warmup_start_lr线性增长至设定的learning_rate,默认为0。

+ 1 - 1
paddlex/cv/models/base.py

@@ -202,7 +202,7 @@ class BaseAPI:
             if pretrain_weights is not None and not os.path.exists(
                     pretrain_weights):
                 if self.model_type == 'classifier':
-                    if pretrain_weights not in ['IMAGENET']:
+                    if pretrain_weights not in ['IMAGENET', 'BAIDU10W']:
                         logging.warning(
                             "Path of pretrain_weights('{}') is not exists!".
                             format(pretrain_weights))

+ 67 - 8
paddlex/cv/models/classifier.py

@@ -279,7 +279,11 @@ class BaseClassifier(BaseAPI):
         return metrics
 
     @staticmethod
-    def _preprocess(images, transforms, model_type, class_name, thread_pool=None):
+    def _preprocess(images,
+                    transforms,
+                    model_type,
+                    class_name,
+                    thread_pool=None):
         arrange_transforms(
             model_type=model_type,
             class_name=class_name,
@@ -343,10 +347,7 @@ class BaseClassifier(BaseAPI):
 
         return preds[0]
 
-    def batch_predict(self,
-                      img_file_list,
-                      transforms=None,
-                      topk=1):
+    def batch_predict(self, img_file_list, transforms=None, topk=1):
         """预测。
         Args:
             img_file_list(list|tuple): 对列表(或元组)中的图像同时进行预测,列表中的元素可以是图像路径
@@ -365,9 +366,9 @@ class BaseClassifier(BaseAPI):
 
         if transforms is None:
             transforms = self.test_transforms
-        im = BaseClassifier._preprocess(img_file_list, transforms,
-                                        self.model_type,
-                                        self.__class__.__name__, self.thread_pool)
+        im = BaseClassifier._preprocess(
+            img_file_list, transforms, self.model_type,
+            self.__class__.__name__, self.thread_pool)
 
         with fluid.scope_guard(self.scope):
             result = self.exe.run(self.test_prog,
@@ -409,6 +410,64 @@ class ResNet50_vd(BaseClassifier):
         super(ResNet50_vd, self).__init__(
             model_name='ResNet50_vd', num_classes=num_classes)
 
+    def train(self,
+              num_epochs,
+              train_dataset,
+              train_batch_size=64,
+              eval_dataset=None,
+              save_interval_epochs=1,
+              log_interval_steps=2,
+              save_dir='output',
+              pretrain_weights='BAIDU10W',
+              optimizer=None,
+              learning_rate=0.025,
+              warmup_steps=0,
+              warmup_start_lr=0.0,
+              lr_decay_epochs=[30, 60, 90],
+              lr_decay_gamma=0.1,
+              use_vdl=False,
+              sensitivities_file=None,
+              eval_metric_loss=0.05,
+              early_stop=False,
+              early_stop_patience=5,
+              resume_checkpoint=None):
+        """训练。
+        Args:
+            num_epochs (int): 训练迭代轮数。
+            train_dataset (paddlex.datasets): 训练数据读取器。
+            train_batch_size (int): 训练数据batch大小。同时作为验证数据batch大小。默认值为64。
+            eval_dataset (paddlex.datasets: 验证数据读取器。
+            save_interval_epochs (int): 模型保存间隔(单位:迭代轮数)。默认为1。
+            log_interval_steps (int): 训练日志输出间隔(单位:迭代步数)。默认为2。
+            save_dir (str): 模型保存路径。
+            pretrain_weights (str): 若指定为路径时,则加载路径下预训练模型;若为字符串'IMAGENET',
+                则自动下载在ImageNet图片数据上预训练的模型权重;若为None,则不使用预训练模型。若为'BAIDU10W',则自动下载百度自研10万类预训练。默认为'BAIDU10W'。
+            optimizer (paddle.fluid.optimizer): 优化器。当该参数为None时,使用默认优化器:
+                fluid.layers.piecewise_decay衰减策略,fluid.optimizer.Momentum优化方法。
+            learning_rate (float): 默认优化器的初始学习率。默认为0.025。
+            warmup_steps(int): 学习率从warmup_start_lr上升至设定的learning_rate,所需的步数,默认为0
+            warmup_start_lr(float): 学习率在warmup阶段时的起始值,默认为0.0
+            lr_decay_epochs (list): 默认优化器的学习率衰减轮数。默认为[30, 60, 90]。
+            lr_decay_gamma (float): 默认优化器的学习率衰减率。默认为0.1。
+            use_vdl (bool): 是否使用VisualDL进行可视化。默认值为False。
+            sensitivities_file (str): 若指定为路径时,则加载路径下敏感度信息进行裁剪;若为字符串'DEFAULT',
+                则自动下载在ImageNet图片数据上获得的敏感度信息进行裁剪;若为None,则不进行裁剪。默认为None。
+            eval_metric_loss (float): 可容忍的精度损失。默认为0.05。
+            early_stop (bool): 是否使用提前终止训练策略。默认值为False。
+            early_stop_patience (int): 当使用提前终止训练策略时,如果验证集精度在`early_stop_patience`个epoch内
+                连续下降或持平,则终止训练。默认值为5。
+            resume_checkpoint (str): 恢复训练时指定上次训练保存的模型路径。若为None,则不会恢复训练。默认值为None。
+        Raises:
+            ValueError: 模型从inference model进行加载。
+        """
+        return super(ResNet50_vd, self).train(
+            num_epochs, train_dataset, train_batch_size, eval_dataset,
+            save_interval_epochs, log_interval_steps, save_dir,
+            pretrain_weights, optimizer, learning_rate, warmup_steps,
+            warmup_start_lr, lr_decay_epochs, lr_decay_gamma, use_vdl,
+            sensitivities_file, eval_metric_loss, early_stop,
+            early_stop_patience, resume_checkpoint)
+
 
 class ResNet101_vd(BaseClassifier):
     def __init__(self, num_classes=1000):

+ 41 - 0
paddlex/cv/models/utils/pretrain_weights.py

@@ -76,6 +76,11 @@ image_pretrain = {
     'http://paddle-imagenet-models-name.bj.bcebos.com/AlexNet_pretrained.tar'
 }
 
+baidu10w_pretrain = {
+    'ResNet50_vd_BAIDU10W':
+    'https://paddle-imagenet-models-name.bj.bcebos.com/ResNet50_vd_10w_pretrained.tar'
+}
+
 coco_pretrain = {
     'YOLOv3_DarkNet53_COCO':
     'https://paddlemodels.bj.bcebos.com/object_detection/yolov3_darknet.tar',
@@ -180,6 +185,11 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
         elif class_name == 'FastSCNN':
             logging.warning(warning_info.format(class_name, flag, 'CITYSCAPES'))
             flag = 'CITYSCAPES'
+    elif flag == 'BAIDU10W':
+        if class_name not in ['ResNet50_vd']:
+            raise Exception(
+                "Only the classifier ResNet50_vd supports BAIDU10W pretrained weights"
+            )
 
     if flag == 'IMAGENET':
         new_save_dir = save_dir
@@ -244,6 +254,37 @@ def get_pretrain_weights(flag, class_name, backbone, save_dir):
         if getattr(paddlex, 'gui_mode', False):
             paddlex.utils.download_and_decompress(url, path=new_save_dir)
             return osp.join(new_save_dir, fname)
+        try:
+            logging.info(
+                "Connecting PaddleHub server to get pretrain weights...")
+            hub.download(backbone, save_path=new_save_dir)
+        except Exception as e:
+            logging.error(
+                "Couldn't download pretrain weight, you can download it manualy from {} (decompress the file if it is a compressed file), and set pretrain weights by your self".
+                format(url),
+                exit=False)
+            if isinstance(hub.ResourceNotFoundError):
+                raise Exception("Resource for backbone {} not found".format(
+                    backbone))
+            elif isinstance(hub.ServerConnectionError):
+                raise Exception(
+                    "Cannot get reource for backbone {}, please check your internet connection"
+                    .format(backbone))
+            else:
+                raise Exception(
+                    "Unexpected error, please make sure paddlehub >= 1.6.2")
+        return osp.join(new_save_dir, backbone)
+    elif flag == 'BAIDU10W':
+        new_save_dir = save_dir
+        if hasattr(paddlex, 'pretrain_dir'):
+            new_save_dir = paddlex.pretrain_dir
+        backbone = backbone + '_BAIDU10W'
+        url = baidu10w_pretrain[backbone]
+        fname = osp.split(url)[-1].split('.')[0]
+
+        if getattr(paddlex, 'gui_mode', False):
+            paddlex.utils.download_and_decompress(url, path=new_save_dir)
+            return osp.join(new_save_dir, fname)
 
         import paddlehub as hub
         try: