|
|
@@ -147,15 +147,22 @@ class ClsConfig(BaseConfig):
|
|
|
pretrained_model (str): the local path or url of pretrained weight file to set.
|
|
|
"""
|
|
|
assert isinstance(
|
|
|
- pretrained_model, (str, None)
|
|
|
+ pretrained_model, (str, type(None))
|
|
|
), "The 'pretrained_model' should be a string, indicating the path to the '*.pdparams' file, or 'None', \
|
|
|
indicating that no pretrained model to be used."
|
|
|
|
|
|
- if pretrained_model and not pretrained_model.startswith(
|
|
|
- ('http://', 'https://')):
|
|
|
- pretrained_model = abspath(
|
|
|
- pretrained_model.replace(".pdparams", ""))
|
|
|
- self.update([f'Global.pretrained_model={pretrained_model}'])
|
|
|
+ if pretrained_model is None:
|
|
|
+ self.update(['Global.pretrained_model=None'])
|
|
|
+ self.update(['Arch.pretrained=False'])
|
|
|
+ else:
|
|
|
+ if pretrained_model.lower() == "default":
|
|
|
+ self.update(['Global.pretrained_model=None'])
|
|
|
+ self.update(['Arch.pretrained=True'])
|
|
|
+ else:
|
|
|
+ if not pretrained_model.startswith(('http://', 'https://')):
|
|
|
+ pretrained_model = abspath(
|
|
|
+ pretrained_model.replace(".pdparams", ""))
|
|
|
+ self.update([f'Global.pretrained_model={pretrained_model}'])
|
|
|
|
|
|
def update_num_classes(self, num_classes: int):
|
|
|
"""update classes number
|