test_image_classification.py 3.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. from paddlex_hpi.models import ClasPredictor
  16. from tests.models.base import BaseTestPredictor
  17. from paddlex.inference.results import TopkResult
  18. MODEL_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/clas_model.zip"
  19. INPUT_DATA_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/clas_input.jpg"
  20. EXPECTED_RESULT_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/clas_result.json"
  21. EXPECTED_RESULT_WITH_ARGS_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/clas_result_with_args.json"
  22. class TestClasPredictor(BaseTestPredictor):
  23. @property
  24. def model_url(self):
  25. return MODEL_URL
  26. @property
  27. def input_data_url(self):
  28. return INPUT_DATA_URL
  29. @property
  30. def expected_result_url(self):
  31. return EXPECTED_RESULT_URL
  32. @property
  33. def expected_result_with_args_url(self):
  34. return EXPECTED_RESULT_WITH_ARGS_URL
  35. @property
  36. def should_test_with_args(self):
  37. return True
  38. @property
  39. def predictor_cls(self):
  40. return ClasPredictor
  41. def _predict_with_predictor_args(
  42. self, model_path, input_data_path, device, expected_result_with_args
  43. ):
  44. predictor = self.predictor_cls(model_path, device=device, topk=2)
  45. output = predictor(str(input_data_path))
  46. self._check_output(output, expected_result_with_args, 1)
  47. def _predict_with_predict_args(
  48. self,
  49. model_path,
  50. input_data_path,
  51. device,
  52. expected_result,
  53. expected_result_with_args,
  54. ):
  55. predictor = self.predictor_cls(model_path, device=device)
  56. output = predictor(str(input_data_path), topk=2)
  57. self._check_output(output, expected_result_with_args, 1)
  58. output = predictor(str(input_data_path))
  59. self._check_output(output, expected_result, 1)
  60. def _check_result(self, result, expected_result):
  61. assert isinstance(result, TopkResult)
  62. assert "input_img" in result
  63. result.pop("input_img")
  64. assert set(result) == set(expected_result)
  65. assert result["class_ids"] == expected_result["class_ids"]
  66. assert np.allclose(
  67. np.array(result["scores"]),
  68. np.array(expected_result["scores"]),
  69. rtol=1e-2,
  70. atol=1e-3,
  71. )
  72. assert result["label_names"] == expected_result["label_names"]