# File: hub_model_server.py
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional

import requests

from .hub_config import config
  17. class ServerConnectionError(Exception):
  18. def __init__(self, url: str):
  19. self.url = url
  20. def __str__(self):
  21. tips = "Can't connect to UltraInfer Model Server: {}".format(self.url)
  22. return tips
  23. class ModelServer(object):
  24. """
  25. UltraInfer server source
  26. Args:
  27. url(str) : Url of the server
  28. timeout(int) : Request timeout
  29. """
  30. def __init__(self, url: str, timeout: int = 10):
  31. self._url = url
  32. self._timeout = timeout
  33. def search_model(
  34. self, name: str, format: str = None, version: str = None
  35. ) -> List[dict]:
  36. """
  37. Search model from model server.
  38. Args:
  39. name(str) : UltraInfer model name
  40. format(str): UltraInfer model format
  41. version(str) : UltraInfer model version
  42. Return:
  43. result(list): search results
  44. """
  45. params = {}
  46. params["name"] = name
  47. if format:
  48. params["format"] = format
  49. if version:
  50. params["version"] = version
  51. result = self.request(path="ultra_infer_search", params=params)
  52. if result["status"] == 0 and len(result["data"]) > 0:
  53. return result["data"]
  54. return None
  55. def stat_model(self, name: str, format: str, version: str):
  56. """
  57. Note a record when download a model for statistics.
  58. Args:
  59. name(str) : UltraInfer model name
  60. format(str): UltraInfer model format
  61. version(str) : UltraInfer model version
  62. Return:
  63. is_successful(bool): True if successful, False otherwise
  64. """
  65. params = {}
  66. params["name"] = name
  67. params["format"] = format
  68. params["version"] = version
  69. params["from"] = "ultra_infer"
  70. try:
  71. result = self.request(path="stat", params=params)
  72. except Exception:
  73. return False
  74. if result["status"] == 0:
  75. return True
  76. else:
  77. return False
  78. def request(self, path: str, params: dict) -> dict:
  79. """Request server."""
  80. api = "{}/{}".format(self._url, path)
  81. try:
  82. result = requests.get(api, params, timeout=self._timeout)
  83. return result.json()
  84. except requests.exceptions.ConnectionError as e:
  85. raise ServerConnectionError(self._url)
  86. def get_model_list(self):
  87. """
  88. Get all pre-trained models information in dataset.
  89. Return:
  90. result(dict): key is category name, value is a list which contains models \
  91. information such as name, format and version.
  92. """
  93. api = "{}/{}".format(self._url, "ultra_infer_listmodels")
  94. try:
  95. result = requests.get(api, timeout=self._timeout)
  96. return result.json()
  97. except requests.exceptions.ConnectionError as e:
  98. raise ServerConnectionError(self._url)
  99. def is_connected(self):
  100. return self.check(self._url)
  101. @classmethod
  102. def check(cls, url: str) -> bool:
  103. """
  104. Check if the specified url is a valid model server
  105. Args:
  106. url(str) : Url to check
  107. """
  108. try:
  109. r = requests.get(url + "/search")
  110. return r.status_code == 200
  111. except:
  112. return False
  113. model_server = ModelServer(config.server)