Sfoglia il codice sorgente

strip '_ssld' while checking model_name

will-jl944 4 anni fa
parent
commit
de4d1d1799
1 ha cambiato i file con 1 aggiunte e 1 eliminazioni
  1. 1 1
      dygraph/paddlex/cv/models/classifier.py

+ 1 - 1
dygraph/paddlex/cv/models/classifier.py

@@ -54,7 +54,7 @@ class BaseClassifier(BaseModel):
         self.init_params.update(params)
         self.init_params.update(params)
         del self.init_params['params']
         del self.init_params['params']
         super(BaseClassifier, self).__init__('classifier')
         super(BaseClassifier, self).__init__('classifier')
-        if not hasattr(architectures, model_name):
+        if not hasattr(architectures, model_name.strip('_ssld')):
             raise Exception("ERROR: There's no model named {}.".format(
             raise Exception("ERROR: There's no model named {}.".format(
                 model_name))
                 model_name))