model_paths.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from os import PathLike
  15. from pathlib import Path
  16. from typing import Tuple, TypedDict, Union
  17. from ...constants import MODEL_FILE_PREFIX
  18. class ModelPaths(TypedDict, total=False):
  19. paddle: Tuple[Path, Path]
  20. onnx: Path
  21. om: Path
  22. def get_model_paths(
  23. model_dir: Union[str, PathLike],
  24. model_file_prefix: str = MODEL_FILE_PREFIX,
  25. ) -> ModelPaths:
  26. model_dir = Path(model_dir)
  27. model_paths: ModelPaths = {}
  28. pd_model_path = None
  29. if (model_dir / f"{model_file_prefix}.json").exists():
  30. pd_model_path = model_dir / f"{model_file_prefix}.json"
  31. elif (model_dir / f"{model_file_prefix}.pdmodel").exists():
  32. pd_model_path = model_dir / f"{model_file_prefix}.pdmodel"
  33. if pd_model_path and (model_dir / f"{model_file_prefix}.pdiparams").exists():
  34. model_paths["paddle"] = (
  35. pd_model_path,
  36. model_dir / f"{model_file_prefix}.pdiparams",
  37. )
  38. if (model_dir / f"{model_file_prefix}.onnx").exists():
  39. model_paths["onnx"] = model_dir / f"{model_file_prefix}.onnx"
  40. if (model_dir / f"{model_file_prefix}.om").exists():
  41. model_paths["om"] = model_dir / f"{model_file_prefix}.om"
  42. return model_paths