test_semantic_segmentation.py 2.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import pytest
  16. from paddlex_hpi.models import SegPredictor
  17. from tests.models.base import BaseTestPredictor
  18. from paddlex.inference.results import SegResult
  19. MODEL_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/seg_model.zip"
  20. INPUT_DATA_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/seg_input.png"
  21. EXPECTED_RESULT_URL = "https://paddle-model-ecology.bj.bcebos.com/paddlex/PaddleX3.0/deploy/paddlex_hpi/tests/models/seg_result.json"
  22. class TestSegPredictor(BaseTestPredictor):
  23. @property
  24. def model_url(self):
  25. return MODEL_URL
  26. @property
  27. def input_data_url(self):
  28. return INPUT_DATA_URL
  29. @property
  30. def expected_result_url(self):
  31. return EXPECTED_RESULT_URL
  32. @property
  33. def expected_result_with_args_url(self):
  34. return EXPECTED_RESULT_URL
  35. @property
  36. def predictor_cls(self):
  37. return SegPredictor
  38. @property
  39. def should_test_with_args(self):
  40. return True
  41. def _predict_with_predictor_args(
  42. self, model_path, input_data_path, device, expected_result_with_args
  43. ):
  44. with pytest.raises(TypeError):
  45. predictor = self.predictor_cls(model_path, device=device, target_size=400)
  46. def _predict_with_predict_args(
  47. self,
  48. model_path,
  49. input_data_path,
  50. device,
  51. expected_result,
  52. expected_result_with_args,
  53. ):
  54. predictor = self.predictor_cls(model_path, device=device)
  55. with pytest.raises(TypeError):
  56. output = predictor(str(input_data_path), target_size=400)
  57. output = list(output)
  58. def _check_result(self, result, expected_result):
  59. assert isinstance(result, SegResult)
  60. assert "input_img" in result
  61. result.pop("input_img")
  62. assert set(result) == set(expected_result)
  63. pred = result["pred"]
  64. expected_pred = np.array(expected_result["pred"], dtype=np.int32)
  65. assert pred.shape == expected_pred.shape
  66. assert (pred != expected_pred).sum() / pred.size < 0.01