base.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import json
  15. import tempfile
  16. from pathlib import Path
  17. from types import GeneratorType
  18. import pytest
  19. from tests.testing_utils.download import download, download_and_extract
  20. from tests.testing_utils.misc import get_filename
  21. NUM_INPUT_FILES = 10
  22. DEVICES = ["cpu", "gpu:0"]
  23. BATCH_SIZES = [1, 2, 4]
  24. class BaseTestPredictor(object):
  25. @property
  26. def model_dir(self):
  27. raise NotImplementedError
  28. @property
  29. def model_url(self):
  30. raise NotImplementedError
  31. @property
  32. def input_data_url(self):
  33. raise NotImplementedError
  34. @property
  35. def expected_result_url(self):
  36. raise NotImplementedError
  37. @property
  38. def expected_result_with_args_url(self):
  39. raise NotImplementedError
  40. @property
  41. def predictor_cls(self):
  42. raise NotImplementedError
  43. @property
  44. def should_test_with_args(self):
  45. return False
  46. @pytest.fixture(scope="class")
  47. def data_dir(self):
  48. with tempfile.TemporaryDirectory() as td:
  49. yield Path(td)
  50. @pytest.fixture(scope="class")
  51. def model_path(self, data_dir):
  52. download_and_extract(self.model_url, data_dir, "model")
  53. yield data_dir / "model"
  54. @pytest.fixture(scope="class")
  55. def input_data_path(self, data_dir):
  56. input_data_path = (data_dir / get_filename(self.input_data_url)).with_stem(
  57. "test"
  58. )
  59. download(self.input_data_url, input_data_path)
  60. yield input_data_path
  61. @pytest.fixture(scope="class")
  62. def expected_result(self, data_dir):
  63. expected_result_path = data_dir / "expected.json"
  64. download(self.expected_result_url, expected_result_path)
  65. with open(expected_result_path, "r", encoding="utf-8") as f:
  66. expected_result = json.load(f)
  67. yield expected_result
  68. @pytest.fixture(scope="class")
  69. def expected_result_with_args(self, data_dir):
  70. expected_result_with_args_path = data_dir / "expected_with_args.json"
  71. download(self.expected_result_with_args_url, expected_result_with_args_path)
  72. with open(expected_result_with_args_path, "r", encoding="utf-8") as f:
  73. expected_result = json.load(f)
  74. yield expected_result
  75. @pytest.mark.parametrize("device", DEVICES)
  76. def test___call__single_input_data(
  77. self, model_path, input_data_path, device, expected_result
  78. ):
  79. predictor = self.predictor_cls(model_path, device=device)
  80. output = predictor(str(input_data_path))
  81. self._check_output(output, expected_result, 1)
  82. output = predictor([str(input_data_path), str(input_data_path)])
  83. self._check_output(output, expected_result, 2)
  84. @pytest.mark.parametrize("device", DEVICES)
  85. @pytest.mark.parametrize("batch_size", BATCH_SIZES)
  86. def test___call__input_batch_data(
  87. self, model_path, input_data_path, device, batch_size, expected_result
  88. ):
  89. predictor = self.predictor_cls(model_path, device=device)
  90. predictor.set_predictor(batch_size=batch_size)
  91. output = predictor([str(input_data_path)] * NUM_INPUT_FILES)
  92. self._check_output(output, expected_result, NUM_INPUT_FILES)
  93. @pytest.mark.parametrize("device", DEVICES)
  94. def test__call__with_predictor_args(
  95. self, model_path, input_data_path, device, request
  96. ):
  97. if self.should_test_with_args:
  98. self._predict_with_predictor_args(
  99. model_path,
  100. input_data_path,
  101. device,
  102. request.getfixturevalue("expected_result_with_args"),
  103. )
  104. else:
  105. pytest.skip("Skipping test__call__with_predictor_args for this predictor")
  106. @pytest.mark.parametrize("device", DEVICES)
  107. def test__call__with_predict_args(
  108. self,
  109. model_path,
  110. input_data_path,
  111. device,
  112. expected_result,
  113. request,
  114. ):
  115. if self.should_test_with_args:
  116. self._predict_with_predict_args(
  117. model_path,
  118. input_data_path,
  119. device,
  120. expected_result,
  121. request.getfixturevalue("expected_result_with_args"),
  122. )
  123. else:
  124. pytest.skip("Skipping test__call__with_predict_args for this predictor")
  125. def _check_output(self, output, expected_result, expected_num_results):
  126. assert isinstance(output, GeneratorType)
  127. # Note that this exhausts the generator
  128. output = list(output)
  129. assert len(output) == expected_num_results
  130. for result in output:
  131. self._check_result(result, expected_result)
  132. def _check_result(self, result, expected_result):
  133. raise NotImplementedError
  134. def _predict_with_predictor_args(
  135. self, model_path, input_data_path, device, expected_result_with_args
  136. ):
  137. raise NotImplementedError
  138. def _predict_with_predict_args(
  139. self,
  140. model_path,
  141. input_data_path,
  142. device,
  143. expected_result,
  144. expected_result_with_args,
  145. ):
  146. raise NotImplementedError