base.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import shutil
  16. import tempfile
  17. from pathlib import Path
  18. from types import GeneratorType
  19. import pytest
  20. from tests.testing_utils.download import download, download_and_extract
  21. from tests.testing_utils.misc import get_filename
  # How many copies of the test input are passed to the predictor in the
  # batched-inference test (and therefore how many results it must yield).
  NUM_INPUT_FILES = 10
  # Devices that every device-parametrized test is run on.
  DEVICES = ["cpu", "gpu:0"]
  # Batch sizes exercised by the batched-inference test: 1, 2 and 4.
  BATCH_SIZES = [2**power for power in range(3)]
  class BaseTestPredictor:
      """Shared pytest suite for predictor classes.

      Concrete test classes subclass this and override the members that raise
      ``NotImplementedError``: the ``model_url``, ``input_data_url``,
      ``expected_result_url`` and ``predictor_cls`` properties, plus
      ``_check_result``. Overriding ``should_test_with_args`` to return
      ``True`` also enables the two argument-passing tests. Those tests then
      additionally require ``expected_result_with_args_url``,
      ``_predict_with_predictor_args`` and ``_predict_with_predict_args``.

      All downloaded artifacts are stored in a class-scoped temporary
      directory. They are therefore fetched once per test class and deleted
      when that class's tests are finished.
      """

      @property
      def model_dir(self):
          """Model directory for the predictor under test (must be overridden).

          NOTE(review): not referenced by the tests in this base class;
          presumably consumed by subclasses.
          """
          raise NotImplementedError

      @property
      def model_url(self):
          """URL of the model archive to download and extract (must be overridden)."""
          raise NotImplementedError

      @property
      def input_data_url(self):
          """URL of the single input file fed to the predictor (must be overridden)."""
          raise NotImplementedError

      @property
      def expected_result_url(self):
          """URL of the JSON reference result for default runs (must be overridden)."""
          raise NotImplementedError

      @property
      def expected_result_with_args_url(self):
          """URL of the JSON reference result for runs with custom arguments.

          Only needed when ``should_test_with_args`` is true.
          """
          raise NotImplementedError

      @property
      def predictor_cls(self):
          """Predictor class under test.

          It is called as ``predictor_cls(model_path, device=...)``. Must be
          overridden.
          """
          raise NotImplementedError

      @property
      def should_test_with_args(self):
          """Whether the argument-passing tests run (default: skipped)."""
          return False

      @staticmethod
      def _download_json(url, path):
          """Download the JSON document at *url* to *path* and return it parsed."""
          download(url, path)
          with open(path, "r", encoding="utf-8") as f:
              return json.load(f)

      @pytest.fixture(scope="class")
      def data_dir(self):
          """Yield a temporary directory shared by all tests of the class.

          The directory, and everything downloaded into it, is removed when
          the class-scoped fixture is torn down.
          """
          with tempfile.TemporaryDirectory() as td:
              yield Path(td)

      @pytest.fixture(scope="class")
      def model_path(self, data_dir):
          """Download and extract the model archive, then yield ``data_dir / "model"``."""
          download_and_extract(self.model_url, data_dir, "model")
          yield data_dir / "model"

      @pytest.fixture(scope="class")
      def input_data_path(self, data_dir):
          """Download the test input file and yield its local path.

          The local file keeps the suffix of the remote file name, but its
          stem is replaced with ``"test"``.
          """
          remote_name = get_filename(self.input_data_url)
          input_data_path = (data_dir / remote_name).with_stem("test")
          download(self.input_data_url, input_data_path)
          yield input_data_path

      @pytest.fixture(scope="class")
      def expected_result(self, data_dir):
          """Yield the parsed JSON reference result for default-argument runs."""
          yield self._download_json(
              self.expected_result_url, data_dir / "expected.json"
          )

      @pytest.fixture(scope="class")
      def expected_result_with_args(self, data_dir):
          """Yield the parsed JSON reference result for runs with custom arguments.

          The tests in this class only request this fixture lazily, through
          ``request.getfixturevalue``, when ``should_test_with_args`` is true.
          Predictors that do not opt in therefore never download it.
          """
          yield self._download_json(
              self.expected_result_with_args_url,
              data_dir / "expected_with_args.json",
          )

      @pytest.mark.parametrize("device", DEVICES)
      def test___call__single_input_data(
          self, model_path, input_data_path, device, expected_result
      ):
          """A single path yields one result; a list of two paths yields two."""
          predictor = self.predictor_cls(model_path, device=device)
          output = predictor(str(input_data_path))
          self._check_output(output, expected_result, 1)
          output = predictor([str(input_data_path), str(input_data_path)])
          self._check_output(output, expected_result, 2)

      @pytest.mark.parametrize("device", DEVICES)
      @pytest.mark.parametrize("batch_size", BATCH_SIZES)
      def test___call__input_batch_data(
          self, model_path, input_data_path, device, batch_size, expected_result
      ):
          """Feed NUM_INPUT_FILES inputs after configuring *batch_size*.

          Expects exactly one result per input.
          """
          predictor = self.predictor_cls(model_path, device=device)
          predictor.set_predictor(batch_size=batch_size)
          output = predictor([str(input_data_path)] * NUM_INPUT_FILES)
          self._check_output(output, expected_result, NUM_INPUT_FILES)

      @pytest.mark.parametrize("device", DEVICES)
      def test__call__with_predictor_args(
          self, model_path, input_data_path, device, request
      ):
          """Delegate to ``_predict_with_predictor_args`` if enabled; otherwise skip."""
          if not self.should_test_with_args:
              # pytest.skip raises, so nothing below runs for opted-out predictors.
              pytest.skip("Skipping test__call__with_predictor_args for this predictor")
          # Resolved lazily so the with-args reference is only downloaded
          # for predictors that actually run this test.
          self._predict_with_predictor_args(
              model_path,
              input_data_path,
              device,
              request.getfixturevalue("expected_result_with_args"),
          )

      @pytest.mark.parametrize("device", DEVICES)
      def test__call__with_predict_args(
          self,
          model_path,
          input_data_path,
          device,
          expected_result,
          request,
      ):
          """Delegate to ``_predict_with_predict_args`` if enabled; otherwise skip."""
          if not self.should_test_with_args:
              # pytest.skip raises, so nothing below runs for opted-out predictors.
              pytest.skip("Skipping test__call__with_predict_args for this predictor")
          # Resolved lazily for the same reason as in the predictor-args test.
          self._predict_with_predict_args(
              model_path,
              input_data_path,
              device,
              expected_result,
              request.getfixturevalue("expected_result_with_args"),
          )

      def _check_output(self, output, expected_result, expected_num_results):
          """Assert that *output* is a generator of correct results.

          Args:
              output: Value returned by calling the predictor.
              expected_result: Parsed reference result that every item must match.
              expected_num_results: Exact number of items the generator must yield.
          """
          assert isinstance(output, GeneratorType), (
              f"predictor output should be a generator, got {type(output).__name__}"
          )
          # Materializing the generator consumes it; it cannot be iterated again.
          results = list(output)
          assert len(results) == expected_num_results, (
              f"expected {expected_num_results} results, got {len(results)}"
          )
          for result in results:
              self._check_result(result, expected_result)

      def _check_result(self, result, expected_result):
          """Assert that one prediction *result* matches *expected_result*.

          Must be overridden.
          """
          raise NotImplementedError

      def _predict_with_predictor_args(
          self, model_path, input_data_path, device, expected_result_with_args
      ):
          """Run and check a prediction with predictor-level arguments.

          NOTE(review): presumably these are arguments given when building or
          configuring the predictor. The outcome is compared against
          *expected_result_with_args*. Only called when
          ``should_test_with_args`` is true. Must be overridden in that case.
          """
          raise NotImplementedError

      def _predict_with_predict_args(
          self,
          model_path,
          input_data_path,
          device,
          expected_result,
          expected_result_with_args,
      ):
          """Run and check predictions with per-call arguments.

          NOTE(review): presumably the defaults are compared against
          *expected_result* and the per-call arguments against
          *expected_result_with_args*. Only called when
          ``should_test_with_args`` is true. Must be overridden in that case.
          """
          raise NotImplementedError