semantic_segmentation.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Any, Dict, List
  15. import ultra_infer as ui
  16. import numpy as np
  17. from paddlex.inference.common.batch_sampler import ImageBatchSampler
  18. from paddlex.inference.results import SegResult
  19. from paddlex.modules.semantic_segmentation.model_list import MODELS
  20. from paddlex_hpi.models.base import CVPredictor
  21. class SegPredictor(CVPredictor):
  22. entities = MODELS
  23. def _build_ui_model(
  24. self, option: ui.RuntimeOption
  25. ) -> ui.vision.segmentation.PaddleSegModel:
  26. model = ui.vision.segmentation.PaddleSegModel(
  27. str(self.model_path),
  28. str(self.params_path),
  29. str(self.config_path),
  30. runtime_option=option,
  31. )
  32. return model
  33. def _build_batch_sampler(self) -> ImageBatchSampler:
  34. return ImageBatchSampler()
  35. def _get_result_class(self) -> type:
  36. return SegResult
  37. def process(self, batch_data: List[Any]) -> Dict[str, List[Any]]:
  38. batch_raw_imgs = self._data_reader(imgs=batch_data)
  39. imgs = [np.ascontiguousarray(img) for img in batch_raw_imgs]
  40. ui_results = self._ui_model.batch_predict(imgs)
  41. batch_preds = []
  42. for ui_result in ui_results:
  43. pred = np.array(ui_result.label_map, dtype=np.int32).reshape(
  44. ui_result.shape
  45. )
  46. pred = pred[np.newaxis]
  47. batch_preds.append(pred)
  48. return {
  49. "input_path": batch_data,
  50. "input_img": batch_raw_imgs,
  51. "pred": batch_preds,
  52. }