瀏覽代碼

fix_uniform

zhangyubo0722 1 年之前
父節點
當前提交
b3e3f73baa

+ 4 - 1
paddlex/modules/base/evaluator.py

@@ -71,8 +71,11 @@ class BaseEvaluator(ABC, metaclass=AutoRegisterABCMetaClass):
 
         config_path = Path(weight_path).parent / "config.yaml"
         if not config_path.exists():
+            config_path = config_path.parent.parent / "config.yaml"
+
+        if not config_path.exists():
             warning(
-                f"The config file(`{config_path}`) related to weight file(`{weight_path}`) is not exist, use default instead."
+                f"The config file (`{config_path}`) related to weight file (`{weight_path}`) does not exist. Using default instead."
             )
             config_path = None
         return config_path

+ 3 - 3
paddlex/repo_apis/PaddleClas_api/cls/register.py

@@ -486,7 +486,7 @@ register_model_info(
         "config_path": osp.join(PDX_CONFIG_DIR, "MobileNetV1_x0_5.yaml"),
         "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
         "infer_config": "deploy/configs/inference_cls.yaml",
-        "hpi_config_path": HPI_CONFIG_DIR / ".yaml",
+        "hpi_config_path": HPI_CONFIG_DIR / "MobileNetV1_x0_5.yaml",
     }
 )
 
@@ -497,7 +497,7 @@ register_model_info(
         "config_path": osp.join(PDX_CONFIG_DIR, "MobileNetV1_x0_75.yaml"),
         "supported_apis": ["train", "evaluate", "predict", "export", "infer"],
         "infer_config": "deploy/configs/inference_cls.yaml",
-        "hpi_config_path": HPI_CONFIG_DIR / "MobileNetV1_x0_5.yaml",
+        "hpi_config_path": HPI_CONFIG_DIR / "MobileNetV1_x0_75.yaml",
     }
 )
 
@@ -608,7 +608,7 @@ register_model_info(
         "suite": "Cls",
         "config_path": osp.join(PDX_CONFIG_DIR, "MobileNetV3_large_x1_25.yaml"),
         "supported_apis": ["train", "evaluate", "predict", "export"],
-        "hpi_config_path": HPI_CONFIG_DIR / ".yaml",
+        "hpi_config_path": HPI_CONFIG_DIR / "MobileNetV3_large_x1_25.yaml",
     }
 )
 

+ 3 - 0
paddlex/repo_apis/PaddleTS_api/ts_base/model.py

@@ -108,6 +108,8 @@ class TSModel(BaseModel):
         else:
             if num_workers is not None:
                 cli_args.append(CLIArgument("--num_workers", num_workers))
+        config.update({"uniform_output_enabled": True})
+        config.update({"pdx_model_name": self.name})
 
         self._assert_empty_kwargs(kwargs)
 
@@ -236,6 +238,7 @@ class TSModel(BaseModel):
             # Update YAML config file
             config = self.config.copy()
             config.update_pretrained_weights(weight_path)
+            config.update({"pdx_model_name": self.name})
             config.dump(config_path)
 
             return self.runner.export(config_path, cli_args, device)