base.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import json
  15. import shutil
  16. import tempfile
  17. from pathlib import Path
  18. from types import GeneratorType
  19. import pytest
  20. from tests.testing_utils.download import download, download_and_extract
  21. from tests.testing_utils.misc import get_filename
  22. NUM_INPUT_FILES = 10
  23. DEVICES = ["cpu", "gpu:0"]
  24. BATCH_SIZES = [1, 2, 4]
  25. class BaseTestPredictor(object):
  26. @property
  27. def model_dir(self):
  28. raise NotImplementedError
  29. @property
  30. def model_url(self):
  31. raise NotImplementedError
  32. @property
  33. def input_data_url(self):
  34. raise NotImplementedError
  35. @property
  36. def expected_result_url(self):
  37. raise NotImplementedError
  38. @property
  39. def predictor_cls(self):
  40. raise NotImplementedError
  41. @pytest.fixture(scope="class")
  42. def data_dir(self):
  43. with tempfile.TemporaryDirectory() as td:
  44. yield Path(td)
  45. @pytest.fixture(scope="class")
  46. def model_path(self, data_dir):
  47. download_and_extract(self.model_url, data_dir, "model")
  48. yield data_dir / "model"
  49. @pytest.fixture(scope="class")
  50. def input_data_path(self, data_dir):
  51. input_data_path = (data_dir / get_filename(self.input_data_url)).with_stem(
  52. "test"
  53. )
  54. download(self.input_data_url, input_data_path)
  55. yield input_data_path
  56. @pytest.fixture(scope="class")
  57. def input_data_dir(self, data_dir, input_data_path):
  58. input_data_dir = data_dir / "input_data"
  59. input_data_dir.mkdir()
  60. for i in range(NUM_INPUT_FILES):
  61. shutil.copy(
  62. input_data_path,
  63. (input_data_dir / f"test_{i}").with_suffix(input_data_path.suffix),
  64. )
  65. yield input_data_dir
  66. @pytest.fixture(scope="class")
  67. def expected_result(self, data_dir):
  68. expected_result_path = data_dir / "expected.json"
  69. download(self.expected_result_url, expected_result_path)
  70. with open(expected_result_path, "r", encoding="utf-8") as f:
  71. expected_result = json.load(f)
  72. yield expected_result
  73. @pytest.mark.parametrize("device", DEVICES)
  74. def test___call__single_input_data(
  75. self, model_path, input_data_path, device, expected_result
  76. ):
  77. predictor = self.predictor_cls(model_path, device=device)
  78. output = predictor(str(input_data_path))
  79. self._check_output(output, expected_result, 1)
  80. output = predictor([str(input_data_path), str(input_data_path)])
  81. self._check_output(output, expected_result, 2)
  82. @pytest.mark.parametrize("device", DEVICES)
  83. @pytest.mark.parametrize("batch_size", BATCH_SIZES)
  84. def test___call__input_data_dir(
  85. self, model_path, input_data_dir, device, batch_size, expected_result
  86. ):
  87. predictor = self.predictor_cls(model_path, device=device)
  88. predictor.set_predictor(batch_size=batch_size)
  89. output = predictor(str(input_data_dir))
  90. self._check_output(output, expected_result, NUM_INPUT_FILES)
  91. def _check_output(self, output, expected_result, expected_num_results):
  92. assert isinstance(output, GeneratorType)
  93. # Note that this exhausts the generator
  94. output = list(output)
  95. assert len(output) == expected_num_results
  96. for result in output:
  97. self._check_result(result, expected_result)
  98. def _check_result(self, result, expected_result):
  99. raise NotImplementedError