build_model.py 1.0 KB

12345678910111213141516171819202122232425262728293031323334
  1. # !/usr/bin/env python3
  2. # -*- coding: UTF-8 -*-
  3. ################################################################################
  4. #
  5. # Copyright (c) 2024 Baidu.com, Inc. All Rights Reserved
  6. #
  7. ################################################################################
  8. """
  9. Author: PaddlePaddle Authors
  10. """
import os
from typing import Optional

from ...repo_apis.base import Config, PaddleModel
from ...utils.device import get_device
  14. def build_model(model_name: str, device: str=None,
  15. config_path: str=None) -> tuple:
  16. """build Config and PaddleModel
  17. Args:
  18. model_name (str): model name
  19. device (str): device, such as gpu, cpu, npu, xpu, mlu
  20. config_path (str, optional): path to the PaddleX config yaml file.
  21. Defaults to None, i.e. using the default config file.
  22. Returns:
  23. tuple(Config, PaddleModel): the Config and PaddleModel
  24. """
  25. config = Config(model_name, config_path)
  26. if device:
  27. config.update_device(get_device(device))
  28. model = PaddleModel(config=config)
  29. return config, model