multilabel_classification.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. from typing import Any, Dict, List, Optional, Union
  16. import ultra_infer as ui
  17. import numpy as np
  18. from pathlib import Path
  19. import tempfile
  20. import yaml
  21. from paddlex.inference.common.batch_sampler import ImageBatchSampler
  22. from paddlex.inference.results import MLClassResult
  23. from paddlex.modules.multilabel_classification.model_list import MODELS
  24. from paddlex_hpi.models.base import CVPredictor, HPIParams
  25. class MLClasPredictor(CVPredictor):
  26. entities = MODELS
  27. def __init__(
  28. self,
  29. model_dir: Union[str, os.PathLike],
  30. config: Optional[Dict[str, Any]] = None,
  31. device: Optional[str] = None,
  32. hpi_params: Optional[HPIParams] = None,
  33. threshold: Union[float, dict, list, None] = None,
  34. ) -> None:
  35. self._threshold = threshold
  36. super().__init__(
  37. model_dir=model_dir,
  38. config=config,
  39. device=device,
  40. hpi_params=hpi_params,
  41. )
  42. self._label_list = self._get_label_list()
  43. def _build_ui_model(
  44. self, option: ui.RuntimeOption
  45. ) -> ui.vision.classification.PyOnlyMultilabelClassificationModel:
  46. if self._threshold:
  47. if isinstance(self._threshold, (dict, list)):
  48. raise TypeError("`threshold` must be float or None in PaddleX HPI")
  49. with open(self.config_path, "r") as file:
  50. config = yaml.safe_load(file)
  51. config["PostProcess"]["MultiLabelThreshOutput"][
  52. "threshold"
  53. ] = self._threshold
  54. temp_dir = os.path.dirname(self.config_path)
  55. with tempfile.NamedTemporaryFile(
  56. delete=False, dir=temp_dir, suffix=".yml", mode="w", encoding="utf-8"
  57. ) as temp_file:
  58. temp_file_path = temp_file.name
  59. yaml.safe_dump(config, temp_file, default_flow_style=False)
  60. model = ui.vision.classification.PyOnlyMultilabelClassificationModel(
  61. str(self.model_path),
  62. str(self.params_path),
  63. str(Path(temp_file_path)),
  64. runtime_option=option,
  65. )
  66. else:
  67. model = ui.vision.classification.PyOnlyMultilabelClassificationModel(
  68. str(self.model_path),
  69. str(self.params_path),
  70. str(self.config_path),
  71. runtime_option=option,
  72. )
  73. return model
  74. def _build_batch_sampler(self) -> ImageBatchSampler:
  75. return ImageBatchSampler()
  76. def _get_result_class(self) -> type:
  77. return MLClassResult
  78. def process(
  79. self,
  80. batch_data: List[Any],
  81. threshold: Union[float, dict, list, None] = None,
  82. ) -> Dict[str, List[Any]]:
  83. if threshold:
  84. raise TypeError(
  85. "`threshold` is not supported for multilabel classification in PaddleX HPI"
  86. )
  87. batch_raw_imgs = self._data_reader(imgs=batch_data)
  88. imgs = [np.ascontiguousarray(img) for img in batch_raw_imgs]
  89. ui_results = self._ui_model.batch_predict(imgs)
  90. class_ids_list = []
  91. scores_list = []
  92. label_names_list = []
  93. for ui_result in ui_results:
  94. class_ids_list.append(ui_result.label_ids)
  95. scores_list.append(np.around(ui_result.scores, decimals=5).tolist())
  96. if self._label_list is not None:
  97. label_names_list.append(
  98. [self._label_list[i] for i in ui_result.label_ids]
  99. )
  100. return {
  101. "input_path": batch_data,
  102. "input_img": batch_raw_imgs,
  103. "class_ids": class_ids_list,
  104. "scores": scores_list,
  105. "label_names": label_names_list,
  106. }
  107. def _get_label_list(self) -> Optional[List[str]]:
  108. pp_config = self.config["PostProcess"]
  109. if "MultiLabelThreshOutput" not in pp_config:
  110. raise RuntimeError("`MultiLabelThreshOutput` config not found")
  111. label_list = pp_config["MultiLabelThreshOutput"].get("label_list", None)
  112. return label_list