test_object_detection.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. from paddlex_hpi.models import DetPredictor
  15. from tests.models.base import BaseTestPredictor
  16. from tests.testing_utils.cv import compare_det_results
  17. from paddlex.inference.results import DetResult
  18. MODEL_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/det_model.zip"
  19. INPUT_DATA_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/det_input.jpg"
  20. EXPECTED_RESULT_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/det_result.json"
  21. EXPECTED_RESULT_WITH_ARGS_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/det_result_with_args.json"
  22. class TestDetPredictor(BaseTestPredictor):
  23. @property
  24. def model_url(self):
  25. return MODEL_URL
  26. @property
  27. def input_data_url(self):
  28. return INPUT_DATA_URL
  29. @property
  30. def expected_result_url(self):
  31. return EXPECTED_RESULT_URL
  32. @property
  33. def expected_result_with_args_url(self):
  34. return EXPECTED_RESULT_WITH_ARGS_URL
  35. @property
  36. def should_test_with_args(self):
  37. return True
  38. @property
  39. def predictor_cls(self):
  40. return DetPredictor
  41. def _predict_with_predictor_args(
  42. self, model_path, input_data_path, device, expected_result_with_args
  43. ):
  44. predictor = self.predictor_cls(model_path, device=device, threshold=0.7)
  45. output = predictor(str(input_data_path))
  46. self._check_output(output, expected_result_with_args, 1)
  47. def _predict_with_predict_args(
  48. self,
  49. model_path,
  50. input_data_path,
  51. device,
  52. expected_result,
  53. expected_result_with_args,
  54. ):
  55. predictor = self.predictor_cls(model_path, device=device)
  56. output = predictor(str(input_data_path), threshold=0.7)
  57. self._check_output(output, expected_result_with_args, 1)
  58. output = predictor(str(input_data_path))
  59. self._check_output(output, expected_result, 1)
  60. def _check_result(self, result, expected_result):
  61. assert isinstance(result, DetResult)
  62. assert "input_img" in result
  63. result.pop("input_img")
  64. assert set(result) == set(expected_result)
  65. compare_det_results(
  66. [obj["coordinate"] for obj in result["boxes"]],
  67. [obj["coordinate"] for obj in expected_result["boxes"]],
  68. labels1=[obj["cls_id"] for obj in result["boxes"]],
  69. labels2=[obj["cls_id"] for obj in expected_result["boxes"]],
  70. scores1=[obj["score"] for obj in result["boxes"]],
  71. scores2=[obj["score"] for obj in expected_result["boxes"]],
  72. )