test_multilabel_classification.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import pytest
  16. from paddlex_hpi.models import MLClasPredictor
  17. from tests.models.base import BaseTestPredictor
  18. from paddlex.inference.results import MLClassResult
  # Every fixture used by this test lives in the same directory of the
  # PaddleX model-ecology bucket; only the file names differ.
  _FIXTURE_BASE_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/"
  # Zipped multi-label classification model under test.
  MODEL_URL = _FIXTURE_BASE_URL + "ml_clas_model.zip"
  # Input image fed to the predictor.
  INPUT_DATA_URL = _FIXTURE_BASE_URL + "ml_clas_input.jpg"
  # Expected prediction when the predictor is built with default settings.
  EXPECTED_RESULT_URL = _FIXTURE_BASE_URL + "ml_clas_result.json"
  # Expected prediction when the predictor is built with extra arguments.
  EXPECTED_RESULT_WITH_ARGS_URL = _FIXTURE_BASE_URL + "ml_clas_result_with_args.json"
  class TestMLClasPredictor(BaseTestPredictor):
      """Tests for the high-performance-inference ``MLClasPredictor``.

      The properties supply the fixture URLs and the predictor class; they are
      presumably consumed by ``BaseTestPredictor`` (not shown here). The
      underscore-prefixed methods customize how predictions are run and checked.
      """

      # Score threshold used by the "with args" scenarios. It is expected to be
      # accepted by the predictor constructor and rejected as a per-call argument.
      _WITH_ARGS_THRESHOLD = 0.85

      @property
      def model_url(self):
          """URL of the zipped model under test."""
          return MODEL_URL

      @property
      def input_data_url(self):
          """URL of the input image."""
          return INPUT_DATA_URL

      @property
      def expected_result_url(self):
          """URL of the expected result for a default-configured predictor."""
          return EXPECTED_RESULT_URL

      @property
      def expected_result_with_args_url(self):
          """URL of the expected result when the predictor is given extra args."""
          return EXPECTED_RESULT_WITH_ARGS_URL

      @property
      def should_test_with_args(self):
          """Opt in to the ``_predict_with_*_args`` scenarios (presumably read by the base class)."""
          return True

      @property
      def predictor_cls(self):
          """Predictor class instantiated by the tests."""
          return MLClasPredictor

      def _predict_with_predictor_args(
          self, model_path, input_data_path, device, expected_result_with_args
      ):
          """Predict with a threshold passed to the predictor constructor.

          The output is checked against ``expected_result_with_args``.
          """
          predictor = self.predictor_cls(
              model_path, device=device, threshold=self._WITH_ARGS_THRESHOLD
          )
          output = predictor(str(input_data_path))
          # NOTE(review): ``1`` is presumably the expected number of results for
          # the single input image; confirm against BaseTestPredictor._check_output.
          self._check_output(output, expected_result_with_args, 1)

      def _predict_with_predict_args(
          self,
          model_path,
          input_data_path,
          device,
          expected_result,
          expected_result_with_args,
      ):
          """Assert that ``threshold`` is rejected as a per-call argument.

          ``expected_result`` and ``expected_result_with_args`` are unused here;
          presumably they are kept to match the signature the caller uses.
          """
          predictor = self.predictor_cls(model_path, device=device)
          with pytest.raises(TypeError):
              # Consume the output in case the predictor returns a lazy iterable,
              # so that a TypeError raised during iteration is caught as well.
              list(
                  predictor(
                      str(input_data_path), threshold=self._WITH_ARGS_THRESHOLD
                  )
              )

      def _check_result(self, result, expected_result):
          """Assert that one prediction matches its expected record.

          ``result`` must contain ``input_img``, but that key is excluded from
          the comparison. ``result`` itself is not modified.
          """
          assert isinstance(result, MLClassResult)
          assert "input_img" in result
          # Exclude the input image from the key comparison without popping it,
          # so the caller's result object is left intact.
          assert set(result) - {"input_img"} == set(expected_result)
          assert result["class_ids"] == expected_result["class_ids"]
          # Same tolerance rule as np.allclose(actual, expected), but this also
          # rejects shape mismatches (np.allclose would broadcast) and reports
          # which elements differ.
          np.testing.assert_allclose(
              np.asarray(result["scores"]),
              np.asarray(expected_result["scores"]),
              rtol=1e-2,
              atol=1e-3,
              equal_nan=False,
          )
          assert result["label_names"] == expected_result["label_names"]