utils.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import contextlib
  15. import os
  16. from typing import Optional, Union
  17. import paddle
  18. ASYMMETRY_QUANT_SCALE_MIN = "@min_scales"
  19. ASYMMETRY_QUANT_SCALE_MAX = "@max_scales"
  20. SYMMETRY_QUANT_SCALE = "@scales"
  21. CONFIG_NAME = "config.json"
  22. LEGACY_CONFIG_NAME = "model_config.json"
  23. PADDLE_WEIGHTS_NAME = "model_state.pdparams"
  24. PADDLE_WEIGHTS_INDEX_NAME = "model_state.pdparams.index.json"
  25. PYTORCH_WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
  26. PYTORCH_WEIGHTS_NAME = "pytorch_model.bin"
  27. SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
  28. SAFE_WEIGHTS_NAME = "model.safetensors"
  29. GENERATION_CONFIG_NAME = "generation_config.json"
  30. def resolve_file_path(
  31. pretrained_model_name_or_path: str = None,
  32. filenames: Union[str, list] = None,
  33. subfolder: Optional[str] = None,
  34. **kwargs,
  35. ):
  36. """
  37. This is a load function, mainly called by the from_pretrained function.
  38. Adapt for PaddleX inference.
  39. """
  40. assert (
  41. pretrained_model_name_or_path is not None
  42. ), "pretrained_model_name_or_path cannot be None"
  43. assert filenames is not None, "filenames cannot be None"
  44. subfolder = subfolder if subfolder is not None else ""
  45. if isinstance(filenames, str):
  46. filenames = [filenames]
  47. if os.path.isfile(pretrained_model_name_or_path):
  48. return pretrained_model_name_or_path
  49. elif os.path.isdir(pretrained_model_name_or_path):
  50. for index, filename in enumerate(filenames):
  51. if os.path.exists(
  52. os.path.join(pretrained_model_name_or_path, subfolder, filename)
  53. ):
  54. if not os.path.isfile(
  55. os.path.join(pretrained_model_name_or_path, subfolder, filename)
  56. ):
  57. raise EnvironmentError(
  58. f"{pretrained_model_name_or_path} does not appear to have file named {filename}."
  59. )
  60. return os.path.join(pretrained_model_name_or_path, subfolder, filename)
  61. elif index < len(filenames) - 1:
  62. continue
  63. else:
  64. raise FileNotFoundError(
  65. f"please make sure one of the {filenames} under the dir {pretrained_model_name_or_path}"
  66. )
  67. else:
  68. raise ValueError(
  69. "please make sure `pretrained_model_name_or_path` is either a file or a directory."
  70. )
  71. @contextlib.contextmanager
  72. def device_guard(device="cpu", dev_id=0):
  73. origin_device = paddle.device.get_device()
  74. if device == "cpu":
  75. paddle.set_device(device)
  76. elif device in ["gpu", "xpu", "npu"]:
  77. paddle.set_device("{}:{}".format(device, dev_id))
  78. try:
  79. yield
  80. finally:
  81. paddle.set_device(origin_device)
  82. def get_env_device():
  83. """
  84. Return the device name of running environment.
  85. """
  86. if paddle.is_compiled_with_cuda():
  87. return "gpu"
  88. elif "npu" in paddle.device.get_all_custom_device_type():
  89. return "npu"
  90. elif "gcu" in paddle.device.get_all_custom_device_type():
  91. return "gcu"
  92. elif paddle.is_compiled_with_rocm():
  93. return "rocm"
  94. elif paddle.is_compiled_with_xpu():
  95. return "xpu"
  96. return "cpu"