hub_model_server.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134
  1. # copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import requests
  16. from typing import List
  17. from .hub_config import config
  18. class ServerConnectionError(Exception):
  19. def __init__(self, url: str):
  20. self.url = url
  21. def __str__(self):
  22. tips = "Can't connect to UltraInfer Model Server: {}".format(self.url)
  23. return tips
  24. class ModelServer(object):
  25. """
  26. UltraInfer server source
  27. Args:
  28. url(str) : Url of the server
  29. timeout(int) : Request timeout
  30. """
  31. def __init__(self, url: str, timeout: int = 10):
  32. self._url = url
  33. self._timeout = timeout
  34. def search_model(
  35. self, name: str, format: str = None, version: str = None
  36. ) -> List[dict]:
  37. """
  38. Search model from model server.
  39. Args:
  40. name(str) : UltraInfer model name
  41. format(str): UltraInfer model format
  42. version(str) : UltraInfer model version
  43. Return:
  44. result(list): search results
  45. """
  46. params = {}
  47. params["name"] = name
  48. if format:
  49. params["format"] = format
  50. if version:
  51. params["version"] = version
  52. result = self.request(path="ultra_infer_search", params=params)
  53. if result["status"] == 0 and len(result["data"]) > 0:
  54. return result["data"]
  55. return None
  56. def stat_model(self, name: str, format: str, version: str):
  57. """
  58. Note a record when download a model for statistics.
  59. Args:
  60. name(str) : UltraInfer model name
  61. format(str): UltraInfer model format
  62. version(str) : UltraInfer model version
  63. Return:
  64. is_successful(bool): True if successful, False otherwise
  65. """
  66. params = {}
  67. params["name"] = name
  68. params["format"] = format
  69. params["version"] = version
  70. params["from"] = "ultra_infer"
  71. try:
  72. result = self.request(path="stat", params=params)
  73. except Exception:
  74. return False
  75. if result["status"] == 0:
  76. return True
  77. else:
  78. return False
  79. def request(self, path: str, params: dict) -> dict:
  80. """Request server."""
  81. api = "{}/{}".format(self._url, path)
  82. try:
  83. result = requests.get(api, params, timeout=self._timeout)
  84. return result.json()
  85. except requests.exceptions.ConnectionError as e:
  86. raise ServerConnectionError(self._url)
  87. def get_model_list(self):
  88. """
  89. Get all pre-trained models information in dataset.
  90. Return:
  91. result(dict): key is category name, value is a list which contains models \
  92. information such as name, format and version.
  93. """
  94. api = "{}/{}".format(self._url, "ultra_infer_listmodels")
  95. try:
  96. result = requests.get(api, timeout=self._timeout)
  97. return result.json()
  98. except requests.exceptions.ConnectionError as e:
  99. raise ServerConnectionError(self._url)
  100. def is_connected(self):
  101. return self.check(self._url)
  102. @classmethod
  103. def check(cls, url: str) -> bool:
  104. """
  105. Check if the specified url is a valid model server
  106. Args:
  107. url(str) : Url to check
  108. """
  109. try:
  110. r = requests.get(url + "/search")
  111. return r.status_code == 200
  112. except:
  113. return False
  114. model_server = ModelServer(config.server)