Kaynağa Gözat

split multilabel_classification and image_classification

zhouchangda 1 yıl önce
ebeveyn
işleme
94604e9bc8
29 değiştirilmiş dosya ile 1354 ekleme ve 353 silme
  1. 0 1
      paddlex/modules/image_classification/__init__.py
  2. 1 36
      paddlex/modules/image_classification/dataset_checker/__init__.py
  3. 2 9
      paddlex/modules/image_classification/dataset_checker/dataset_src/analyse_dataset.py
  4. 10 28
      paddlex/modules/image_classification/dataset_checker/dataset_src/check_dataset.py
  5. 0 44
      paddlex/modules/image_classification/dataset_checker/dataset_src/utils/visualizer.py
  6. 1 15
      paddlex/modules/image_classification/evaluator.py
  7. 1 7
      paddlex/modules/image_classification/exportor.py
  8. 0 9
      paddlex/modules/image_classification/model_list.py
  9. 0 1
      paddlex/modules/image_classification/predictor/__init__.py
  10. 0 40
      paddlex/modules/image_classification/predictor/predictor_ml.py
  11. 0 99
      paddlex/modules/image_classification/predictor/transforms.py
  12. 0 64
      paddlex/modules/image_classification/trainer_ml.py
  13. 19 0
      paddlex/modules/multilabel_classification/__init__.py
  14. 104 0
      paddlex/modules/multilabel_classification/dataset_checker/__init__.py
  15. 18 0
      paddlex/modules/multilabel_classification/dataset_checker/dataset_src/__init__.py
  16. 95 0
      paddlex/modules/multilabel_classification/dataset_checker/dataset_src/analyse_dataset.py
  17. 131 0
      paddlex/modules/multilabel_classification/dataset_checker/dataset_src/check_dataset.py
  18. 81 0
      paddlex/modules/multilabel_classification/dataset_checker/dataset_src/split_dataset.py
  19. 13 0
      paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/__init__.py
  20. 152 0
      paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/visualizer.py
  21. 43 0
      paddlex/modules/multilabel_classification/evaluator.py
  22. 22 0
      paddlex/modules/multilabel_classification/exportor.py
  23. 22 0
      paddlex/modules/multilabel_classification/model_list.py
  24. 17 0
      paddlex/modules/multilabel_classification/predictor/__init__.py
  25. 29 0
      paddlex/modules/multilabel_classification/predictor/keys.py
  26. 82 0
      paddlex/modules/multilabel_classification/predictor/predictor.py
  27. 259 0
      paddlex/modules/multilabel_classification/predictor/transforms.py
  28. 105 0
      paddlex/modules/multilabel_classification/predictor/utils.py
  29. 147 0
      paddlex/modules/multilabel_classification/trainer.py

+ 0 - 1
paddlex/modules/image_classification/__init__.py

@@ -13,7 +13,6 @@
 # limitations under the License.
 
 from .trainer import ClsTrainer
-from .trainer_ml import MLClsTrainer
 from .dataset_checker import ClsDatasetChecker
 from .evaluator import ClsEvaluator
 from .exportor import ClsExportor

+ 1 - 36
paddlex/modules/image_classification/dataset_checker/__init__.py

@@ -16,7 +16,7 @@ from pathlib import Path
 
 from ...base import BaseDatasetChecker
 from .dataset_src import check, split_dataset, deep_analyse
-from ..model_list import MODELS, ML_MODELS
+from ..model_list import MODELS
 
 
 class ClsDatasetChecker(BaseDatasetChecker):
@@ -102,38 +102,3 @@ class ClsDatasetChecker(BaseDatasetChecker):
             str: dataset type
         """
         return "ClsDataset"
-
-
-class MLClsDatasetChecker(ClsDatasetChecker):
-    entities = ML_MODELS
-    sample_num = 10
-
-    def check_dataset(self, dataset_dir: str, sample_num: int = sample_num) -> dict:
-        """check if the dataset meets the specifications and get dataset summary
-
-        Args:
-            dataset_dir (str): the root directory of dataset.
-            sample_num (int): the number to be sampled.
-        Returns:
-            dict: dataset summary.
-        """
-        return check(dataset_dir, self.output, dataset_type="MLCls")
-
-    def analyse(self, dataset_dir: str) -> dict:
-        """deep analyse dataset
-
-        Args:
-            dataset_dir (str): the root directory of dataset.
-
-        Returns:
-            dict: the deep analysis results.
-        """
-        return deep_analyse(dataset_dir, self.output, dataset_type="MLCls")
-
-    def get_dataset_type(self) -> str:
-        """return the dataset type
-
-        Returns:
-            str: dataset type
-        """
-        return "MLClsDataset"

+ 2 - 9
paddlex/modules/image_classification/dataset_checker/dataset_src/analyse_dataset.py

@@ -29,7 +29,7 @@ from .....utils.file_interface import custom_open
 from .....utils.fonts import PINGFANG_FONT_FILE_PATH
 
 
-def deep_analyse(dataset_path, output, dataset_type="Cls"):
+def deep_analyse(dataset_path, output):
     """class analysis for dataset"""
     tags = ["train", "val"]
     labels_cnt = defaultdict(str)
@@ -48,14 +48,7 @@ def deep_analyse(dataset_path, output, dataset_type="Cls"):
             lines = f.readlines()
         for line in lines:
             line = line.strip().split()
-            if dataset_type == "Cls":
-                classes_num[labels_cnt[line[1]]] += 1
-            elif dataset_type == "MLCls":
-                for i, label in enumerate(line[1].split(",")):
-                    if label == "1":
-                        classes_num[labels_cnt[str(i)]] += 1
-            else:
-                raise ValueError(f"dataset_type {dataset_type} is not supported")
+            classes_num[labels_cnt[line[1]]] += 1
         if tag == "train":
             cnts_train = [cat_ids for cat_name, cat_ids in classes_num.items()]
         elif tag == "val":

+ 10 - 28
paddlex/modules/image_classification/dataset_checker/dataset_src/check_dataset.py

@@ -19,10 +19,10 @@ from PIL import Image, ImageOps
 from collections import defaultdict
 
 from .....utils.errors import DatasetFileNotFoundError, CheckFailedError
-from .utils.visualizer import draw_label, draw_multi_label
+from .utils.visualizer import draw_label
 
 
-def check(dataset_dir, output, sample_num=10, dataset_type="Cls"):
+def check(dataset_dir, output, sample_num=10):
     """check dataset"""
     dataset_dir = osp.abspath(dataset_dir)
     # Custom dataset
@@ -30,10 +30,7 @@ def check(dataset_dir, output, sample_num=10, dataset_type="Cls"):
         raise DatasetFileNotFoundError(file_path=dataset_dir)
 
     tags = ["train", "val"]
-    if dataset_type == "MLCls":
-        delim = "\t"
-    else:
-        delim = " "
+    delim = " "
     valid_num_parts = 2
 
     sample_cnts = dict()
@@ -105,14 +102,7 @@ def check(dataset_dir, output, sample_num=10, dataset_type="Cls"):
                     if len(sample_paths[tag]) < sample_num:
                         img = Image.open(img_path)
                         img = ImageOps.exif_transpose(img)
-                        if dataset_type == "Cls":
-                            vis_im = draw_label(img, label, label_map_dict)
-                        elif dataset_type == "MLCls":
-                            vis_im = draw_multi_label(img, label, label_map_dict)
-                        else:
-                            raise CheckFailedError(
-                                f"Do not support dataset type '{dataset_type}', only support 'Cls' and 'MLCls'."
-                            )
+                        vis_im = draw_label(img, label, label_map_dict)
                         vis_path = osp.join(vis_save_dir, osp.basename(file_name))
                         vis_im.save(vis_path)
                         sample_path = osp.join(
@@ -120,20 +110,12 @@ def check(dataset_dir, output, sample_num=10, dataset_type="Cls"):
                         )
                         sample_paths[tag].append(sample_path)
 
-                    if dataset_type == "Cls":
-                        try:
-                            label = int(label)
-                        except (ValueError, TypeError) as e:
-                            raise CheckFailedError(
-                                f"Ensure that the second number in each line in {label_file} should be int."
-                            ) from e
-                    elif dataset_type == "MLCls":
-                        try:
-                            label = list(map(int, label.split(",")))
-                        except (ValueError, TypeError) as e:
-                            raise CheckFailedError(
-                                f"Ensure that the second number in each line in {label_file} should be int."
-                            ) from e
+                    try:
+                        label = int(label)
+                    except (ValueError, TypeError) as e:
+                        raise CheckFailedError(
+                            f"Ensure that the second number in each line in {label_file} should be int."
+                        ) from e
 
     num_classes = max(labels) + 1
 

+ 0 - 44
paddlex/modules/image_classification/dataset_checker/dataset_src/utils/visualizer.py

@@ -154,47 +154,3 @@ def draw_label(image, label, label_map_dict):
     draw.text((text_x, text_y), label_map_dict[int(label)], fill=font_color, font=font)
 
     return image
-
-
-def draw_multi_label(image, label, label_map_dict):
-    labels = label.split(",")
-    label_names = [
-        label_map_dict[i] for i, label in enumerate(labels) if int(label) == 1
-    ]
-    image = image.convert("RGB")
-    image_width, image_height = image.size
-    font_size = int(image_width * 0.06)
-
-    font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size)
-    text_lines = []
-    row_width = 0
-    row_height = 0
-    row_text = "\t"
-    for label_name in label_names:
-        text = f"{label_name}\t"
-        text_width, row_height = font.getsize(text)
-        if row_width + text_width <= image_width:
-            row_text += text
-            row_width += text_width
-        else:
-            text_lines.append(row_text)
-            row_text = "\t" + text
-            row_width = text_width
-    text_lines.append(row_text)
-    color_list = colormap(rgb=True)
-    color = tuple(color_list[0])
-    new_image_height = image_height + len(text_lines) * int(row_height * 1.2)
-    new_image = Image.new("RGB", (image_width, new_image_height), color)
-    new_image.paste(image, (0, 0))
-
-    draw = ImageDraw.Draw(new_image)
-    font_color = tuple(font_colormap(3))
-    for i, text in enumerate(text_lines):
-        text_width, _ = font.getsize(text)
-        draw.text(
-            (0, image_height + i * int(row_height * 1.2)),
-            text,
-            fill=font_color,
-            font=font,
-        )
-    return new_image

+ 1 - 15
paddlex/modules/image_classification/evaluator.py

@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from ..base import BaseEvaluator
-from .model_list import MODELS, ML_MODELS
+from .model_list import MODELS
 
 
 class ClsEvaluator(BaseEvaluator):
@@ -41,17 +41,3 @@ class ClsEvaluator(BaseEvaluator):
             "weight_path": self.eval_config.weight_path,
             "device": self.get_device(using_device_number=1),
         }
-
-
-class MlClsEvaluator(ClsEvaluator):
-    entities = ML_MODELS
-
-    def update_config(self):
-        """update evalution config"""
-        if self.eval_config.log_interval:
-            self.pdx_config.update_log_interval(self.eval_config.log_interval)
-        if self.pdx_config["Arch"]["name"] == "DistillationModel":
-            self.pdx_config.update_teacher_model(pretrained=False)
-            self.pdx_config.update_student_model(pretrained=False)
-        self.pdx_config.update_dataset(self.global_config.dataset_dir, "MLClsDataset")
-        self.pdx_config.update_pretrained_weights(self.eval_config.weight_path)

+ 1 - 7
paddlex/modules/image_classification/exportor.py

@@ -13,16 +13,10 @@
 # limitations under the License.
 
 from ..base import BaseExportor
-from .model_list import MODELS, ML_MODELS
+from .model_list import MODELS
 
 
 class ClsExportor(BaseExportor):
     """Image Classification Model Exportor"""
 
     entities = MODELS
-
-
-class MlClsExportor(BaseExportor):
-    """Image Classification Model Exportor"""
-
-    entities = ML_MODELS

+ 0 - 9
paddlex/modules/image_classification/model_list.py

@@ -79,12 +79,3 @@ MODELS = [
     "SwinTransformer_large_patch4_window7_224",
     "SwinTransformer_large_patch4_window12_384",
 ]
-
-ML_MODELS = [
-    "ResNet50_ML",
-    "PP-LCNet_x1_0_ML",
-    "PP-HGNetV2-B0_ML",
-    "PP-HGNetV2-B4_ML",
-    "PP-HGNetV2-B6_ML",
-    "CLIP_vit_base_patch16_448_ML",
-]

+ 0 - 1
paddlex/modules/image_classification/predictor/__init__.py

@@ -13,5 +13,4 @@
 # limitations under the License.
 
 from .predictor import ClsPredictor
-from .predictor_ml import MLClsPredictor
 from . import transforms

+ 0 - 40
paddlex/modules/image_classification/predictor/predictor_ml.py

@@ -1,40 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-# 
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import numpy as np
-from pathlib import Path
-
-from ...base import BasePredictor
-from ...base.predictor.transforms import image_common
-from .keys import ClsKeys as K
-from .utils import InnerConfig
-from ....utils import logging
-from . import transforms as T
-from .predictor import ClsPredictor
-from ..model_list import ML_MODELS
-
-
-class MLClsPredictor(ClsPredictor, BasePredictor):
-    """ MLClssification Predictor """
-    entities = ML_MODELS
-
-    def _get_post_transforms_from_config(self):
-        """ get postprocess transforms """
-        post_transforms = self.other_src.post_transforms
-        post_transforms.extend([
-            T.PrintResult(), T.SaveMLClsResults(self.output,
-                                                self.other_src.labels)
-        ])
-        return post_transforms

+ 0 - 99
paddlex/modules/image_classification/predictor/transforms.py

@@ -30,7 +30,6 @@ __all__ = [
     "NormalizeFeatures",
     "PrintResult",
     "SaveClsResults",
-    "MultiLabelThreshOutput",
 ]
 
 
@@ -288,101 +287,3 @@ class SaveClsResults(BaseTransform):
     def get_output_keys(cls):
         """get output keys"""
         return []
-
-
-class MultiLabelThreshOutput(BaseTransform):
-    def __init__(self, threshold=0.5, class_ids=None, delimiter=None):
-        super().__init__()
-        assert isinstance(threshold, (float,))
-        self.threshold = threshold
-        self.delimiter = delimiter if delimiter is not None else " "
-        self.class_id_map = _parse_class_id_map(class_ids)
-
-    def apply(self, data):
-        """apply"""
-        y = []
-        x = data[K.CLS_PRED]
-        pred_index = np.where(x >= self.threshold)[0].astype("int32")
-        index = pred_index[np.argsort(x[pred_index])][::-1]
-        clas_id_list = []
-        score_list = []
-        label_name_list = []
-        for i in index:
-            clas_id_list.append(i.item())
-            score_list.append(x[i].item())
-            if self.class_id_map is not None:
-                label_name_list.append(self.class_id_map[i.item()])
-        result = {
-            "class_ids": clas_id_list,
-            "scores": np.around(score_list, decimals=5).tolist(),
-            "label_names": label_name_list,
-        }
-        y.append(result)
-        data[K.CLS_RESULT] = y
-        return data
-
-    @classmethod
-    def get_input_keys(cls):
-        """get input keys"""
-        return [K.IM_PATH, K.CLS_PRED]
-
-    @classmethod
-    def get_output_keys(cls):
-        """get output keys"""
-        return [K.CLS_RESULT]
-
-
-class SaveMLClsResults(SaveClsResults, BaseTransform):
-    def __init__(self, save_dir, class_ids=None):
-        super().__init__(save_dir=save_dir)
-        self.save_dir = save_dir
-        self.class_id_map = _parse_class_id_map(class_ids)
-        self._writer = ImageWriter(backend="pillow")
-
-    def apply(self, data):
-        """Draw label on image"""
-        ori_path = data[K.IM_PATH]
-        results = data[K.CLS_RESULT]
-        scores = results[0]["scores"]
-        label_names = results[0]["label_names"]
-        file_name = os.path.basename(ori_path)
-        save_path = os.path.join(self.save_dir, file_name)
-        image = ImageReader(backend="pil").read(ori_path)
-        image = image.convert("RGB")
-        image_width, image_height = image.size
-        font_size = int(image_width * 0.06)
-
-        font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size)
-        text_lines = []
-        row_width = 0
-        row_height = 0
-        row_text = "\t"
-        for label_name, score in zip(label_names, scores):
-            text = f"{label_name}({score})\t"
-            text_width, row_height = font.getsize(text)
-            if row_width + text_width <= image_width:
-                row_text += text
-                row_width += text_width
-            else:
-                text_lines.append(row_text)
-                row_text = "\t" + text
-                row_width = text_width
-        text_lines.append(row_text)
-        color_list = self._get_colormap(rgb=True)
-        color = tuple(color_list[0])
-        new_image_height = image_height + len(text_lines) * int(row_height * 1.2)
-        new_image = Image.new("RGB", (image_width, new_image_height), color)
-        new_image.paste(image, (0, 0))
-
-        draw = ImageDraw.Draw(new_image)
-        font_color = tuple(self._get_font_colormap(3))
-        for i, text in enumerate(text_lines):
-            text_width, _ = font.getsize(text)
-            draw.text(
-                (0, image_height + i * int(row_height * 1.2)),
-                text,
-                fill=font_color,
-                font=font,
-            )
-        self._write_image(save_path, new_image)
-        return data

+ 0 - 64
paddlex/modules/image_classification/trainer_ml.py

@@ -1,64 +0,0 @@
-# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
-# 
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import json
-import shutil
-import paddle
-from pathlib import Path
-
-from ..base import BaseTrainer, BaseTrainDeamon
-from .trainer import ClsTrainer, ClsTrainDeamon
-from .model_list import ML_MODELS
-from ...utils.config import AttrDict
-
-
-class MLClsTrainer(ClsTrainer, BaseTrainer):
-    """ Multi Label Image Classification Model Trainer """
-    entities = ML_MODELS
-
-    def update_config(self):
-        """update training config
-        """
-        if self.train_config.log_interval:
-            self.pdx_config.update_log_interval(self.train_config.log_interval)
-        if self.train_config.eval_interval:
-            self.pdx_config.update_eval_interval(
-                self.train_config.eval_interval)
-        if self.train_config.save_interval:
-            self.pdx_config.update_save_interval(
-                self.train_config.save_interval)
-
-        self.pdx_config.update_dataset(self.global_config.dataset_dir,
-                                       "MLClsDataset")
-        if self.train_config.num_classes is not None:
-            self.pdx_config.update_num_classes(self.train_config.num_classes)
-        if self.train_config.pretrain_weight_path and self.train_config.pretrain_weight_path != "":
-            self.pdx_config.update_pretrained_weights(
-                self.train_config.pretrain_weight_path)
-
-        label_dict_path = Path(self.global_config.dataset_dir).joinpath(
-            "label.txt")
-        if label_dict_path.exists():
-            self.dump_label_dict(label_dict_path)
-        if self.train_config.batch_size is not None:
-            self.pdx_config.update_batch_size(self.train_config.batch_size)
-        if self.train_config.learning_rate is not None:
-            self.pdx_config.update_learning_rate(
-                self.train_config.learning_rate)
-        if self.train_config.epochs_iters is not None:
-            self.pdx_config._update_epochs(self.train_config.epochs_iters)
-        if self.train_config.warmup_steps is not None:
-            self.pdx_config.update_warmup_epochs(self.train_config.warmup_steps)
-        if self.global_config.output is not None:
-            self.pdx_config._update_output_dir(self.global_config.output)

+ 19 - 0
paddlex/modules/multilabel_classification/__init__.py

@@ -0,0 +1,19 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .trainer import MLClsTrainer
+from .dataset_checker import MLClsDatasetChecker
+from .evaluator import MlClsEvaluator
+from .exportor import MlClsExportor
+from .predictor import MLClsPredictor, transforms

+ 104 - 0
paddlex/modules/multilabel_classification/dataset_checker/__init__.py

@@ -0,0 +1,104 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from pathlib import Path
+
+from ...base import BaseDatasetChecker
+from .dataset_src import check, split_dataset, deep_analyse
+from ..model_list import MODELS
+
+
+class MLClsDatasetChecker(BaseDatasetChecker):
+    """Dataset Checker for Multi-Label Image Classification Model"""
+
+    entities = MODELS
+    sample_num = 10
+
+    def get_dataset_root(self, dataset_dir: str) -> str:
+        """find the dataset root dir
+
+        Args:
+            dataset_dir (str): the directory that contains the dataset.
+
+        Returns:
+            str: the root directory of dataset.
+        """
+        anno_dirs = list(Path(dataset_dir).glob("**/images"))
+        assert len(anno_dirs) == 1
+        dataset_dir = anno_dirs[0].parent.as_posix()
+        return dataset_dir
+
+    def convert_dataset(self, src_dataset_dir: str) -> str:
+        """convert the dataset from other type to specified type
+
+        Args:
+            src_dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            str: the root directory of converted dataset.
+        """
+        return src_dataset_dir
+
+    def split_dataset(self, src_dataset_dir: str) -> str:
+        """repartition the train and validation dataset
+
+        Args:
+            src_dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            str: the root directory of the split dataset.
+        """
+        return split_dataset(
+            src_dataset_dir,
+            self.check_dataset_config.split.train_percent,
+            self.check_dataset_config.split.val_percent,
+        )
+
+    def check_dataset(self, dataset_dir: str, sample_num: int = sample_num) -> dict:
+        """check if the dataset meets the specifications and get dataset summary
+
+        Args:
+            dataset_dir (str): the root directory of dataset.
+            sample_num (int): the number to be sampled.
+        Returns:
+            dict: dataset summary.
+        """
+        return check(dataset_dir, self.output)
+
+    def analyse(self, dataset_dir: str) -> dict:
+        """deep analyse dataset
+
+        Args:
+            dataset_dir (str): the root directory of dataset.
+
+        Returns:
+            dict: the deep analysis results.
+        """
+        return deep_analyse(dataset_dir, self.output)
+
+    def get_show_type(self) -> str:
+        """get the show type of dataset
+
+        Returns:
+            str: show type
+        """
+        return "image"
+
+    def get_dataset_type(self) -> str:
+        """return the dataset type
+
+        Returns:
+            str: dataset type
+        """
+        return "MLClsDataset"

+ 18 - 0
paddlex/modules/multilabel_classification/dataset_checker/dataset_src/__init__.py

@@ -0,0 +1,18 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .check_dataset import check
+from .split_dataset import split_dataset
+from .analyse_dataset import deep_analyse

+ 95 - 0
paddlex/modules/multilabel_classification/dataset_checker/dataset_src/analyse_dataset.py

@@ -0,0 +1,95 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+import math
+import platform
+from pathlib import Path
+
+from collections import defaultdict
+from PIL import Image
+import numpy as np
+import matplotlib.pyplot as plt
+from matplotlib import font_manager
+from matplotlib.backends.backend_agg import FigureCanvasAgg
+
+from .....utils.file_interface import custom_open
+from .....utils.fonts import PINGFANG_FONT_FILE_PATH
+
+
+def deep_analyse(dataset_path, output):
+    """Count how many train/val samples carry each class label and save a histogram image under `output`."""
+    tags = ["train", "val"]
+    labels_cnt = defaultdict(str)
+    label_path = os.path.join(dataset_path, "label.txt")
+    with custom_open(label_path, "r") as f:
+        lines = f.readlines()
+    for line in lines:
+        line = line.strip().split()
+        labels_cnt[line[0]] = " ".join(line[1:])
+    for tag in tags:
+        anno_path = os.path.join(dataset_path, f"{tag}.txt")
+        classes_num = defaultdict(int)
+        for i in range(len(labels_cnt)):
+            classes_num[labels_cnt[str(i)]] = 0
+        with custom_open(anno_path, "r") as f:
+            lines = f.readlines()
+        for line in lines:
+            line = line.strip().split()
+            for i, label in enumerate(line[1].split(",")):
+                if label == "1":
+                    classes_num[labels_cnt[str(i)]] += 1
+        if tag == "train":
+            cnts_train = [cat_ids for cat_name, cat_ids in classes_num.items()]
+        elif tag == "val":
+            cnts_val = [cat_ids for cat_name, cat_ids in classes_num.items()]
+
+    classes = [cat_name for cat_name, cat_ids in classes_num.items()]
+    sorted_id = sorted(
+        range(len(cnts_train)), key=lambda k: cnts_train[k], reverse=True
+    )
+    cnts_train_sorted = [cnts_train[index] for index in sorted_id]
+    cnts_val_sorted = [cnts_val[index] for index in sorted_id]
+    classes_sorted = [classes[index] for index in sorted_id]
+    x = np.arange(len(classes))
+    width = 0.5
+
+    # bar
+    os_system = platform.system().lower()
+    if os_system == "windows":
+        plt.rcParams["font.sans-serif"] = "FangSong"
+    else:
+        font = font_manager.FontProperties(fname=PINGFANG_FONT_FILE_PATH, size=10)
+    fig, ax = plt.subplots(figsize=(max(8, int(len(classes) / 5)), 5), dpi=300)
+    ax.bar(x, cnts_train_sorted, width=0.5, label="train")
+    ax.bar(x + width, cnts_val_sorted, width=0.5, label="val")
+    plt.xticks(
+        x + width / 2,
+        classes_sorted,
+        rotation=90,
+        fontproperties=None if os_system == "windows" else font,
+    )
+    ax.set_xlabel(
+        "类别名称", fontproperties=None if os_system == "windows" else font, fontsize=12
+    )
+    ax.set_ylabel(
+        "图片数量", fontproperties=None if os_system == "windows" else font, fontsize=12
+    )
+    plt.legend(loc=1)
+    fig.tight_layout()
+    file_path = os.path.join(output, "histogram.png")
+    fig.savefig(file_path, dpi=300)
+
+    return {"histogram": os.path.join("check_dataset", "histogram.png")}

+ 131 - 0
paddlex/modules/multilabel_classification/dataset_checker/dataset_src/check_dataset.py

@@ -0,0 +1,131 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import os.path as osp
+import random
+from PIL import Image, ImageOps
+from collections import defaultdict
+
+from .....utils.errors import DatasetFileNotFoundError, CheckFailedError
+from .utils.visualizer import draw_multi_label
+
+
+def check(dataset_dir, output, sample_num=10):
+    """Validate a multi-label classification dataset and save annotated samples.
+
+    The dataset directory must contain ``label.txt`` (one ``<index> <name>``
+    pair per line, separated by the first space, with indices starting at 0)
+    and the tab-separated file lists ``train.txt`` and ``val.txt`` (one
+    ``<image path>`` / ``<comma-separated integer flags>`` pair per line).
+
+    Args:
+        dataset_dir (str): Root directory of the dataset.
+        output (str): Directory that receives the ``demo_img`` previews.
+        sample_num (int): Maximum number of previews rendered per split.
+
+    Returns:
+        dict: ``label_file`` (path relative to ``output``), ``num_classes``,
+            and for each of train/val the sample count and preview paths.
+
+    Raises:
+        DatasetFileNotFoundError: If the dataset directory, ``label.txt``,
+            a file list, or a referenced image does not exist.
+        CheckFailedError: If ``label.txt`` or a file list is malformed.
+    """
+    dataset_dir = osp.abspath(dataset_dir)
+    # Custom dataset
+    if not osp.exists(dataset_dir) or not osp.isdir(dataset_dir):
+        raise DatasetFileNotFoundError(file_path=dataset_dir)
+
+    tags = ["train", "val"]
+    delim = "\t"
+    valid_num_parts = 2
+
+    sample_cnts = dict()
+    label_map_dict = dict()
+    sample_paths = defaultdict(list)
+    labels = []
+
+    label_file = osp.join(dataset_dir, "label.txt")
+    if not osp.exists(label_file):
+        raise DatasetFileNotFoundError(
+            file_path=label_file,
+            solution=f"Ensure that `label.txt` exist in {dataset_dir}",
+        )
+
+    with open(label_file, "r", encoding="utf-8") as f:
+        for line in f:
+            substr = line.strip("\n").split(" ", 1)
+            try:
+                label_idx = int(substr[0])
+                label_name = str(substr[1])
+            except (ValueError, IndexError) as e:
+                # ValueError: the index is not an int; IndexError: no name part.
+                raise CheckFailedError(
+                    f"Ensure that the first number in each line in {label_file} should be int."
+                ) from e
+            labels.append(label_idx)
+            label_map_dict[label_idx] = label_name
+    # An empty label file has no index 0 either; report it as a check failure
+    # instead of letting min() raise a bare ValueError.
+    if not labels or min(labels) != 0:
+        raise CheckFailedError(
+            f"Ensure that the index starts from 0 in `{label_file}`."
+        )
+
+    vis_save_dir = osp.join(output, "demo_img")
+
+    for tag in tags:
+        file_list = osp.join(dataset_dir, f"{tag}.txt")
+        if not osp.exists(file_list):
+            # train and val file lists must exist
+            raise DatasetFileNotFoundError(
+                file_path=file_list,
+                solution=f"Ensure that both `train.txt` and `val.txt` exist in {dataset_dir}",
+            )
+
+        with open(file_list, "r", encoding="utf-8") as f:
+            all_lines = f.readlines()
+        # Fixed seed so the same samples are previewed on every run.
+        random.seed(123)
+        random.shuffle(all_lines)
+        sample_cnts[tag] = len(all_lines)
+        if all_lines:
+            os.makedirs(vis_save_dir, exist_ok=True)
+
+        for line in all_lines:
+            substr = line.strip("\n").split(delim)
+            if len(substr) != valid_num_parts:
+                raise CheckFailedError(
+                    f"The number of delimiter-separated items in each row in {file_list} "
+                    f"should be {valid_num_parts} (current delimiter is '{delim}')."
+                )
+            file_name, label = substr
+
+            img_path = osp.join(dataset_dir, file_name)
+            if not osp.exists(img_path):
+                raise DatasetFileNotFoundError(file_path=img_path)
+
+            # Validate the flags before drawing: the visualizer calls int() on
+            # them, so a bad row must be reported here with the offending file.
+            try:
+                list(map(int, label.split(",")))
+            except (ValueError, TypeError) as e:
+                raise CheckFailedError(
+                    f"Ensure that the second number in each line in {file_list} should be int."
+                ) from e
+
+            if len(sample_paths[tag]) < sample_num:
+                # Close the source file handle once the preview is rendered.
+                with Image.open(img_path) as raw_img:
+                    img = ImageOps.exif_transpose(raw_img)
+                    vis_im = draw_multi_label(img, label, label_map_dict)
+                vis_path = osp.join(vis_save_dir, osp.basename(file_name))
+                vis_im.save(vis_path)
+                sample_path = osp.join(
+                    "check_dataset", os.path.relpath(vis_path, output)
+                )
+                sample_paths[tag].append(sample_path)
+
+    num_classes = max(labels) + 1
+
+    attrs = {}
+    attrs["label_file"] = osp.relpath(label_file, output)
+    attrs["num_classes"] = num_classes
+    attrs["train_samples"] = sample_cnts["train"]
+    attrs["train_sample_paths"] = sample_paths["train"]
+
+    attrs["val_samples"] = sample_cnts["val"]
+    attrs["val_sample_paths"] = sample_paths["val"]
+
+    return attrs

+ 81 - 0
paddlex/modules/multilabel_classification/dataset_checker/dataset_src/split_dataset.py

@@ -0,0 +1,81 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from random import shuffle
+
+from .....utils.file_interface import custom_open
+
+
+def split_dataset(root_dir, train_rate, val_rate):
+    """Re-split the dataset's annotation lists into train and val subsets.
+
+    The existing ``train.txt`` / ``val.txt`` files under ``root_dir`` are
+    pooled, each original is renamed to ``<tag>.txt.bak`` (unless such a
+    backup already exists), the pooled lines are shuffled, and new
+    ``train.txt`` / ``val.txt`` files are written according to the rates.
+
+    Args:
+        root_dir (str): Root directory of the dataset.
+        train_rate (int): Percentage of samples assigned to the training set.
+        val_rate (int): Percentage of samples assigned to the validation set.
+
+    Returns:
+        str: ``root_dir``, unchanged.
+
+    Raises:
+        AssertionError: If the rates do not sum to 100, either rate is not
+            positive, or neither ``train.txt`` nor ``val.txt`` exists.
+    """
+    sum_rate = train_rate + val_rate
+    assert (
+        sum_rate == 100
+    ), f"The sum of train_rate({train_rate}), val_rate({val_rate}) should equal 100!"
+    assert (
+        train_rate > 0 and val_rate > 0
+    ), f"The train_rate({train_rate}) and val_rate({val_rate}) should be greater than 0!"
+    tags = ["train", "val"]
+    valid_path = False
+    image_files = []
+    for tag in tags:
+        split_image_list = os.path.abspath(os.path.join(root_dir, f"{tag}.txt"))
+        rename_image_list = os.path.abspath(os.path.join(root_dir, f"{tag}.txt.bak"))
+        if os.path.exists(split_image_list):
+            with custom_open(split_image_list, "r") as f:
+                lines = f.readlines()
+            # Skip blank lines and make sure every record ends with a newline;
+            # otherwise a file whose last line lacks one would get that record
+            # fused with the next sample after pooling and shuffling.
+            image_files.extend(
+                line if line.endswith("\n") else line + "\n"
+                for line in lines
+                if line.strip()
+            )
+            valid_path = True
+            if not os.path.exists(rename_image_list):
+                os.rename(split_image_list, rename_image_list)
+
+    assert (
+        valid_path
+    ), f"The files to be divided{tags[0]}.txt, {tags[1]}.txt, do not exist in the dataset directory."
+
+    shuffle(image_files)
+    start = 0
+    image_num = len(image_files)
+    rate_list = [train_rate, val_rate]
+    for i, tag in enumerate(tags):
+
+        rate = rate_list[i]
+        if rate == 0:
+            continue
+
+        end = start + round(image_num * rate / 100)
+        # The last non-empty split absorbs whatever rounding left over.
+        if sum(rate_list[i + 1 :]) == 0:
+            end = image_num
+
+        txt_file = os.path.abspath(os.path.join(root_dir, tag + ".txt"))
+        with custom_open(txt_file, "w") as f:
+            f.writelines(image_files[start:end])
+        start = end
+
+    return root_dir

+ 13 - 0
paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/__init__.py

@@ -0,0 +1,13 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

+ 152 - 0
paddlex/modules/multilabel_classification/dataset_checker/dataset_src/utils/visualizer.py

@@ -0,0 +1,152 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+import json
+from pathlib import Path
+import PIL
+from PIL import Image, ImageDraw, ImageFont
+
+from ......utils.fonts import PINGFANG_FONT_FILE_PATH
+
+
+def colormap(rgb=False):
+    """
+    Build the fixed 20-colour palette as an ``int32`` array of shape (20, 3).
+
+    Args:
+        rgb (bool): When False (the default) each row's channel order is
+            reversed relative to the table below.
+
+    Returns:
+        numpy.ndarray: The palette, one colour per row.
+    """
+    palette_rows = [
+        [0xFF, 0x00, 0x00],
+        [0xCC, 0xFF, 0x00],
+        [0x00, 0xFF, 0x66],
+        [0x00, 0x66, 0xFF],
+        [0xCC, 0x00, 0xFF],
+        [0xFF, 0x4D, 0x00],
+        [0x80, 0xFF, 0x00],
+        [0x00, 0xFF, 0xB2],
+        [0x00, 0x1A, 0xFF],
+        [0xFF, 0x00, 0xE5],
+        [0xFF, 0x99, 0x00],
+        [0x33, 0xFF, 0x00],
+        [0x00, 0xFF, 0xFF],
+        [0x33, 0x00, 0xFF],
+        [0xFF, 0x00, 0x99],
+        [0xFF, 0xE5, 0x00],
+        [0x00, 0xFF, 0x1A],
+        [0x00, 0xB2, 0xFF],
+        [0x80, 0x00, 0xFF],
+        [0xFF, 0x00, 0x4D],
+    ]
+    palette = np.array(palette_rows).astype(np.float32).reshape((-1, 3))
+    # Flip the channel axis unless the caller asked for the table order.
+    ordered = palette if rgb else palette[:, ::-1]
+    return ordered.astype("int32")
+
+
+def font_colormap(color_index):
+    """
+    Pick the text colour for a palette index.
+
+    Returns the light colour (white) for the indices listed below and the
+    dark colour for every other index, as an ``int32`` array of 3 channels.
+    """
+    indexes_with_light_text = (0, 3, 4, 8, 9, 13, 14, 18, 19)
+    if color_index not in indexes_with_light_text:
+        return np.array([0x14, 0x0E, 0x35]).astype("int32")
+    return np.array([0xFF, 0xFF, 0xFF]).astype("int32")
+
+def _text_size(font, text):
+    """Return ``(width, height)`` of ``text`` rendered with ``font``.
+
+    ``FreeTypeFont.getsize`` was removed in Pillow 10, so ``getbbox`` is used
+    when available; ``getsize`` is kept as a fallback for older Pillow.
+    """
+    if hasattr(font, "getbbox"):
+        _, _, right, bottom = font.getbbox(text)
+        return right, bottom
+    return font.getsize(text)
+
+
+def draw_multi_label(image, label, label_map_dict):
+    """Append a banner listing the active label names below ``image``.
+
+    Args:
+        image (PIL.Image.Image): Source image; it is converted to RGB.
+        label (str): Comma-separated integer flags; position ``i`` set to 1
+            means the label ``label_map_dict[i]`` is present.
+        label_map_dict (dict): Mapping from flag position to label name.
+
+    Returns:
+        PIL.Image.Image: A new RGB image with the original on top and the
+            label names, wrapped to the image width, underneath.
+    """
+    flags = label.split(",")
+    label_names = [
+        label_map_dict[i] for i, flag in enumerate(flags) if int(flag) == 1
+    ]
+    image = image.convert("RGB")
+    image_width, image_height = image.size
+    # Clamp to 1 so very narrow images do not request a zero-sized font.
+    font_size = max(1, int(image_width * 0.06))
+
+    font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size)
+    text_lines = []
+    row_width = 0
+    row_height = 0
+    row_text = "\t"
+    for label_name in label_names:
+        text = f"{label_name}\t"
+        text_width, row_height = _text_size(font, text)
+        if row_width + text_width <= image_width:
+            row_text += text
+            row_width += text_width
+        else:
+            # Current row is full: flush it and start a new one.
+            text_lines.append(row_text)
+            row_text = "\t" + text
+            row_width = text_width
+    text_lines.append(row_text)
+    color_list = colormap(rgb=True)
+    color = tuple(color_list[0])
+    line_height = int(row_height * 1.2)
+    new_image_height = image_height + len(text_lines) * line_height
+    new_image = Image.new("RGB", (image_width, new_image_height), color)
+    new_image.paste(image, (0, 0))
+
+    draw = ImageDraw.Draw(new_image)
+    font_color = tuple(font_colormap(3))
+    for i, text in enumerate(text_lines):
+        draw.text(
+            (0, image_height + i * line_height),
+            text,
+            fill=font_color,
+            font=font,
+        )
+    return new_image

+ 43 - 0
paddlex/modules/multilabel_classification/evaluator.py

@@ -0,0 +1,43 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseEvaluator
+from .model_list import MODELS
+
+
+class MlClsEvaluator(BaseEvaluator):
+    """Multi-label Image Classification Model Evaluator"""
+
+    entities = MODELS
+
+    def update_config(self):
+        """Apply the evaluation settings to ``self.pdx_config``.
+
+        Sets the log interval when one is configured, disables pretrained
+        teacher/student weights for ``DistillationModel`` architectures, and
+        points the config at the dataset (as ``MLClsDataset``) and at the
+        weights to evaluate.
+        """
+        log_interval = self.eval_config.log_interval
+        if log_interval:
+            self.pdx_config.update_log_interval(log_interval)
+
+        arch_name = self.pdx_config["Arch"]["name"]
+        if arch_name == "DistillationModel":
+            self.pdx_config.update_teacher_model(pretrained=False)
+            self.pdx_config.update_student_model(pretrained=False)
+
+        dataset_dir = self.global_config.dataset_dir
+        self.pdx_config.update_dataset(dataset_dir, "MLClsDataset")
+        weight_path = self.eval_config.weight_path
+        self.pdx_config.update_pretrained_weights(weight_path)
+
+    def get_eval_kwargs(self) -> dict:
+        """Collect the keyword arguments for the model evaluation function.
+
+        Returns:
+            dict: ``weight_path`` and ``device`` for the evaluation call.
+        """
+        weight_path = self.eval_config.weight_path
+        device = self.get_device(using_device_number=1)
+        return dict(weight_path=weight_path, device=device)

+ 22 - 0
paddlex/modules/multilabel_classification/exportor.py

@@ -0,0 +1,22 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..base import BaseExportor
+from .model_list import MODELS
+
+
+class MlClsExportor(BaseExportor):
+    """Multi-label Image Classification Model Exportor"""
+
+    # Model names this exportor is registered for.
+    entities = MODELS

+ 22 - 0
paddlex/modules/multilabel_classification/model_list.py

@@ -0,0 +1,22 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Model names assigned to `entities` by this package's trainer, evaluator,
+# exportor and predictor classes.
+MODELS = [
+    "ResNet50_ML",
+    "PP-LCNet_x1_0_ML",
+    "PP-HGNetV2-B0_ML",
+    "PP-HGNetV2-B4_ML",
+    "PP-HGNetV2-B6_ML",
+    "CLIP_vit_base_patch16_448_ML",
+]

+ 17 - 0
paddlex/modules/multilabel_classification/predictor/__init__.py

@@ -0,0 +1,17 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# `predictor.py` in this package defines `MLClsPredictor`; the package has no
+# `predictor_ml` module and no `ClsPredictor` class.
+from .predictor import MLClsPredictor
+from . import transforms

+ 29 - 0
paddlex/modules/multilabel_classification/predictor/keys.py

@@ -0,0 +1,29 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class ClsKeys(object):
+    """
+    This class defines a set of keys used for communication of Cls predictors
+    and transforms. Both predictors and transforms accept a dict or a list of
+    dicts as input, and they get the objects of their interest from the dict, or
+    put the generated objects into the dict, all based on these keys.
+    """
+
+    # Common keys
+    # Image data of a sample.
+    IMAGE = "image"
+    # Path of the input image file.
+    IM_PATH = "input_path"
+    # Suite-specific keys
+    # Raw per-sample model output.
+    CLS_PRED = "cls_pred"
+    # Post-processed classification result.
+    CLS_RESULT = "cls_result"

+ 82 - 0
paddlex/modules/multilabel_classification/predictor/predictor.py

@@ -0,0 +1,82 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import numpy as np
+from pathlib import Path
+
+from ...base import BasePredictor
+from ...base.predictor.transforms import image_common
+from .keys import ClsKeys as K
+from .utils import InnerConfig
+from ....utils import logging
+from . import transforms as T
+from ..model_list import MODELS
+
+
+class MLClsPredictor(BasePredictor):
+    """Multi-label Classification Predictor"""
+
+    entities = MODELS
+
+    def load_other_src(self):
+        """Load ``inference.yml`` from the model directory as an InnerConfig.
+
+        Raises:
+            FileNotFoundError: If the config file does not exist.
+        """
+        infer_cfg_file_path = os.path.join(self.model_dir, "inference.yml")
+        if os.path.exists(infer_cfg_file_path):
+            return InnerConfig(infer_cfg_file_path)
+        raise FileNotFoundError(f"Cannot find config file: {infer_cfg_file_path}")
+
+    @classmethod
+    def get_input_keys(cls):
+        """Return the accepted input key sets: image data, or an image path."""
+        return [[K.IMAGE], [K.IM_PATH]]
+
+    @classmethod
+    def get_output_keys(cls):
+        """Return the keys this predictor adds to each sample."""
+        return [K.CLS_PRED]
+
+    def _run(self, batch_input):
+        """Run inference on a batch and attach each sample's prediction.
+
+        The samples' images are stacked along a new leading axis as float32,
+        passed to the underlying predictor, and the first output is split
+        back per sample under ``K.CLS_PRED``. The input dicts are updated in
+        place and the same list is returned.
+        """
+        images = [sample[K.IMAGE] for sample in batch_input]
+        batch = np.stack(images, axis=0).astype(dtype=np.float32, copy=False)
+        cls_outs = self._predictor.predict([batch])[0]
+        # In-place update
+        for sample, cls_out in zip(batch_input, cls_outs):
+            sample[K.CLS_PRED] = cls_out
+        return batch_input
+
+    def _get_pre_transforms_from_config(self):
+        """Build the preprocessing pipeline: image reading, then config ops."""
+        logging.info(
+            "Transformation operators for data preprocessing will be inferred from config file."
+        )
+        configured_ops = self.other_src.pre_transforms
+        # Reading the image as RGB always comes first.
+        configured_ops.insert(0, image_common.ReadImage(format="RGB"))
+        return configured_ops
+
+    def _get_post_transforms_from_config(self):
+        """Build the postprocessing pipeline: config ops, then print and save."""
+        configured_ops = self.other_src.post_transforms
+        printer = T.PrintResult()
+        saver = T.SaveMLClsResults(self.output, self.other_src.labels)
+        configured_ops.extend([printer, saver])
+        return configured_ops

+ 259 - 0
paddlex/modules/multilabel_classification/predictor/transforms.py

@@ -0,0 +1,259 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+from pathlib import Path
+import numpy as np
+import PIL
+from PIL import ImageDraw, ImageFont, Image
+
+from .keys import ClsKeys as K
+from ...base import BaseTransform
+from ...base.predictor.io import ImageWriter, ImageReader
+from ....utils.fonts import PINGFANG_FONT_FILE_PATH
+from ....utils import logging
+
+# Public transforms of this module. Every name listed here must be defined
+# below, otherwise `from .transforms import *` raises AttributeError.
+__all__ = [
+    "PrintResult",
+    "SaveMLClsResults",
+    "MultiLabelThreshOutput",
+]
+
+
+def _parse_class_id_map(class_ids):
+    """Map each position in ``class_ids`` to its label rendered as a string.
+
+    Returns None when ``class_ids`` is None.
+    """
+    if class_ids is not None:
+        return dict(enumerate(map(str, class_ids)))
+    return None
+
+
+class PrintResult(BaseTransform):
+    """Transform that logs the classification result of a sample."""
+
+    def apply(self, data):
+        """Log ``data[K.CLS_RESULT]`` at info level and return ``data`` unchanged."""
+        logging.info("The prediction result is:")
+        logging.info(data[K.CLS_RESULT])
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """Return the keys this transform reads from ``data``."""
+        return [K.CLS_RESULT]
+
+    @classmethod
+    def get_output_keys(cls):
+        """Return the keys this transform adds to ``data`` (none)."""
+        return []
+
+
+class SaveMLClsResults(BaseTransform):
+    """Draw the predicted labels and scores below the image and save it.
+
+    The output image is written to ``save_dir`` under the input file's
+    base name.
+    """
+
+    def __init__(self, save_dir, class_ids=None):
+        """
+        Args:
+            save_dir (str): Directory the rendered images are written to.
+            class_ids (list|None): Label names indexed by class id.
+        """
+        super().__init__(save_dir=save_dir)
+        self.save_dir = save_dir
+        self.class_id_map = _parse_class_id_map(class_ids)
+        self._writer = ImageWriter(backend="pillow")
+
+    def _get_colormap(self, rgb=False):
+        """
+        Get colormap: an ``int32`` array of shape (20, 3). Each row's channel
+        order is reversed relative to the table below unless ``rgb`` is True.
+        """
+        color_list = np.array(
+            [
+                [0xFF, 0x00, 0x00],
+                [0xCC, 0xFF, 0x00],
+                [0x00, 0xFF, 0x66],
+                [0x00, 0x66, 0xFF],
+                [0xCC, 0x00, 0xFF],
+                [0xFF, 0x4D, 0x00],
+                [0x80, 0xFF, 0x00],
+                [0x00, 0xFF, 0xB2],
+                [0x00, 0x1A, 0xFF],
+                [0xFF, 0x00, 0xE5],
+                [0xFF, 0x99, 0x00],
+                [0x33, 0xFF, 0x00],
+                [0x00, 0xFF, 0xFF],
+                [0x33, 0x00, 0xFF],
+                [0xFF, 0x00, 0x99],
+                [0xFF, 0xE5, 0x00],
+                [0x00, 0xFF, 0x1A],
+                [0x00, 0xB2, 0xFF],
+                [0x80, 0x00, 0xFF],
+                [0xFF, 0x00, 0x4D],
+            ]
+        ).astype(np.float32)
+        color_list = color_list.reshape((-1, 3))
+        if not rgb:
+            color_list = color_list[:, ::-1]
+        return color_list.astype("int32")
+
+    def _get_font_colormap(self, color_index):
+        """
+        Get font colormap: white for the listed palette indices, dark otherwise.
+        """
+        dark = np.array([0x14, 0x0E, 0x35])
+        light = np.array([0xFF, 0xFF, 0xFF])
+        light_indexs = [0, 3, 4, 8, 9, 13, 14, 18, 19]
+        if color_index in light_indexs:
+            return light.astype("int32")
+        else:
+            return dark.astype("int32")
+
+    @staticmethod
+    def _text_size(font, text):
+        """Return ``(width, height)`` of ``text`` rendered with ``font``.
+
+        ``FreeTypeFont.getsize`` was removed in Pillow 10, so ``getbbox`` is
+        used when available; ``getsize`` is kept as a fallback for older
+        Pillow.
+        """
+        if hasattr(font, "getbbox"):
+            _, _, right, bottom = font.getbbox(text)
+            return right, bottom
+        return font.getsize(text)
+
+    def apply(self, data):
+        """Draw label on image"""
+        ori_path = data[K.IM_PATH]
+        results = data[K.CLS_RESULT]
+        scores = results[0]["scores"]
+        label_names = results[0]["label_names"]
+        file_name = os.path.basename(ori_path)
+        save_path = os.path.join(self.save_dir, file_name)
+        image = ImageReader(backend="pil").read(ori_path)
+        image = image.convert("RGB")
+        image_width, image_height = image.size
+        # Clamp to 1 so very narrow images do not request a zero-sized font.
+        font_size = max(1, int(image_width * 0.06))
+
+        font = ImageFont.truetype(PINGFANG_FONT_FILE_PATH, font_size)
+        text_lines = []
+        row_width = 0
+        row_height = 0
+        row_text = "\t"
+        for label_name, score in zip(label_names, scores):
+            text = f"{label_name}({score})\t"
+            text_width, row_height = self._text_size(font, text)
+            if row_width + text_width <= image_width:
+                row_text += text
+                row_width += text_width
+            else:
+                # Current row is full: flush it and start a new one.
+                text_lines.append(row_text)
+                row_text = "\t" + text
+                row_width = text_width
+        text_lines.append(row_text)
+        color_list = self._get_colormap(rgb=True)
+        color = tuple(color_list[0])
+        line_height = int(row_height * 1.2)
+        new_image_height = image_height + len(text_lines) * line_height
+        new_image = Image.new("RGB", (image_width, new_image_height), color)
+        new_image.paste(image, (0, 0))
+
+        draw = ImageDraw.Draw(new_image)
+        font_color = tuple(self._get_font_colormap(3))
+        for i, text in enumerate(text_lines):
+            draw.text(
+                (0, image_height + i * line_height),
+                text,
+                fill=font_color,
+                font=font,
+            )
+        self._write_image(save_path, new_image)
+        return data
+
+    def _write_image(self, path, image):
+        """write image"""
+        if os.path.exists(path):
+            logging.warning(f"{path} already exists. Overwriting it.")
+        self._writer.write(path, image)
+
+    @classmethod
+    def get_input_keys(cls):
+        """get input keys"""
+        # NOTE(review): `apply` reads K.CLS_RESULT rather than K.CLS_PRED;
+        # confirm whether the declared keys should be updated to match.
+        return [K.IM_PATH, K.CLS_PRED]
+
+    @classmethod
+    def get_output_keys(cls):
+        """get output keys"""
+        return []
+
+
+class MultiLabelThreshOutput(BaseTransform):
+    """Keep the classes whose score reaches a threshold, best score first."""
+
+    def __init__(self, threshold=0.5, class_ids=None, delimiter=None):
+        """
+        Args:
+            threshold (float): Minimum score for a class to be reported.
+            class_ids (list|None): Label names indexed by class id.
+            delimiter (str|None): Stored on the instance; defaults to a space.
+        """
+        super().__init__()
+        assert isinstance(threshold, float)
+        self.threshold = threshold
+        self.delimiter = " " if delimiter is None else delimiter
+        self.class_id_map = _parse_class_id_map(class_ids)
+
+    def apply(self, data):
+        """Turn ``data[K.CLS_PRED]`` into a one-element result list.
+
+        The result dict holds ``class_ids``, ``scores`` (rounded to 5
+        decimals) and ``label_names`` (empty when no class id map is set),
+        ordered by descending score, and is stored under ``K.CLS_RESULT``.
+        """
+        scores = data[K.CLS_PRED]
+        kept = np.where(scores >= self.threshold)[0].astype("int32")
+        # Sort the surviving indices by score, highest first.
+        ranked = kept[np.argsort(scores[kept])][::-1]
+
+        class_id_list = [idx.item() for idx in ranked]
+        score_list = [scores[idx].item() for idx in ranked]
+        if self.class_id_map is not None:
+            label_name_list = [self.class_id_map[cid] for cid in class_id_list]
+        else:
+            label_name_list = []
+
+        data[K.CLS_RESULT] = [
+            {
+                "class_ids": class_id_list,
+                "scores": np.around(score_list, decimals=5).tolist(),
+                "label_names": label_name_list,
+            }
+        ]
+        return data
+
+    @classmethod
+    def get_input_keys(cls):
+        """Return the keys required in ``data``."""
+        return [K.IM_PATH, K.CLS_PRED]
+
+    @classmethod
+    def get_output_keys(cls):
+        """Return the keys this transform adds to ``data``."""
+        return [K.CLS_RESULT]

+ 105 - 0
paddlex/modules/multilabel_classification/predictor/utils.py

@@ -0,0 +1,105 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import codecs
+
+import yaml
+
+from ...base.predictor.transforms import image_common
+from . import transforms as T
+
+
+class InnerConfig(object):
+    """Reader for an inference config YAML file and the transforms it names."""
+
+    def __init__(self, config_path):
+        self.inner_cfg = self.load(config_path)
+
+    def load(self, config_path):
+        """Parse the YAML file at ``config_path`` and return its content."""
+        # NOTE(review): yaml.FullLoader is used here; consider yaml.safe_load
+        # if the config file may come from an untrusted source.
+        with codecs.open(config_path, "r", "utf-8") as file:
+            return yaml.load(file, Loader=yaml.FullLoader)
+
+    @property
+    def pre_transforms(self):
+        """Build the preprocessing transforms listed in the config.
+
+        Uses the ``RecPreProcess`` section when present, else ``PreProcess``.
+
+        Raises:
+            RuntimeError: If an operator name is not supported.
+        """
+        if "RecPreProcess" in self.inner_cfg.keys():
+            section = "RecPreProcess"
+        else:
+            section = "PreProcess"
+
+        transforms = []
+        for op_cfg in self.inner_cfg[section]["transform_ops"]:
+            # Each entry is a single-key mapping: operator name -> parameters.
+            op_name = list(op_cfg.keys())[0]
+            if op_name == "NormalizeImage":
+                params = op_cfg["NormalizeImage"]
+                transform = image_common.Normalize(
+                    mean=params.get("mean", [0.485, 0.456, 0.406]),
+                    std=params.get("std", [0.229, 0.224, 0.225]),
+                )
+            elif op_name == "ResizeImage":
+                params = op_cfg["ResizeImage"]
+                if "resize_short" in params.keys():
+                    transform = image_common.ResizeByShort(
+                        target_short_edge=params.get("resize_short", 224),
+                        size_divisor=None,
+                        interp="LINEAR",
+                    )
+                else:
+                    transform = image_common.Resize(
+                        target_size=params.get("size", (224, 224))
+                    )
+            elif op_name == "CropImage":
+                crop_size = op_cfg["CropImage"].get("size", 224)
+                transform = image_common.Crop(crop_size=crop_size)
+            elif op_name == "ToCHWImage":
+                transform = image_common.ToCHWImage()
+            else:
+                raise RuntimeError(f"Unsupported type: {op_name}")
+            transforms.append(transform)
+        return transforms
+
+    @property
+    def post_transforms(self):
+        """Build the postprocessing transforms listed under ``PostProcess``.
+
+        Raises:
+            RuntimeError: If an operator name is not supported.
+        """
+        ignored_ops = ("main_indicator", "SavePreLabel")
+        post_cfg = self.inner_cfg["PostProcess"]
+        transforms = []
+        for op_name in post_cfg:
+            if op_name == "Topk":
+                # NOTE(review): confirm the `transforms` module provides `Topk`.
+                topk_cfg = post_cfg["Topk"]
+                transform = T.Topk(
+                    topk=topk_cfg["topk"],
+                    class_ids=topk_cfg.get("label_list", None),
+                )
+            elif op_name == "MultiLabelThreshOutput":
+                # The threshold is fixed at 0.5 here, not read from the config.
+                transform = T.MultiLabelThreshOutput(
+                    threshold=0.5,
+                    class_ids=post_cfg["MultiLabelThreshOutput"].get("label_list", None),
+                )
+            elif op_name in ignored_ops:
+                continue
+            else:
+                raise RuntimeError(f"Unsupported type: {op_name}")
+            transforms.append(transform)
+        return transforms
+
+    @property
+    def labels(self):
+        """Return the ``label_list`` of the configured postprocess op, or None."""
+        post_cfg = self.inner_cfg["PostProcess"]
+        configured_ops = post_cfg.keys()
+        for op_name in ("Topk", "MultiLabelThreshOutput"):
+            if op_name in configured_ops:
+                return post_cfg[op_name].get("label_list", None)
+        return None

+ 147 - 0
paddlex/modules/multilabel_classification/trainer.py

@@ -0,0 +1,147 @@
+# copyright (c) 2024 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import shutil
+import paddle
+from pathlib import Path
+
+from ..base import BaseTrainer, BaseTrainDeamon
+from .model_list import MODELS
+from ...utils.config import AttrDict
+
+
+class MLClsTrainer(BaseTrainer):
+    """Image Classification Model Trainer"""
+
+    entities = MODELS
+
+    def dump_label_dict(self, src_label_dict_path: str):
+        """dump label dict config
+
+        Args:
+            src_label_dict_path (str): path to label dict file to be saved.
+        """
+        dst_label_dict_path = Path(self.global_config.output).joinpath("label_dict.txt")
+        shutil.copyfile(src_label_dict_path, dst_label_dict_path)
+
+    def build_deamon(self, config: AttrDict) -> "ClsTrainDeamon":
+        """build deamon thread for saving training outputs timely
+
+        Args:
+            config (AttrDict): PaddleX pipeline config, which is loaded from pipeline yaml file.
+
+        Returns:
+            ClsTrainDeamon: the training deamon thread object for saving training outputs timely.
+        """
+        return ClsTrainDeamon(config)
+
+    def update_config(self):
+        """update training config
+        """
+        if self.train_config.log_interval:
+            self.pdx_config.update_log_interval(self.train_config.log_interval)
+        if self.train_config.eval_interval:
+            self.pdx_config.update_eval_interval(
+                self.train_config.eval_interval)
+        if self.train_config.save_interval:
+            self.pdx_config.update_save_interval(
+                self.train_config.save_interval)
+
+        self.pdx_config.update_dataset(self.global_config.dataset_dir,
+                                       "MLClsDataset")
+        if self.train_config.num_classes is not None:
+            self.pdx_config.update_num_classes(self.train_config.num_classes)
+        if self.train_config.pretrain_weight_path and self.train_config.pretrain_weight_path != "":
+            self.pdx_config.update_pretrained_weights(
+                self.train_config.pretrain_weight_path)
+
+        label_dict_path = Path(self.global_config.dataset_dir).joinpath(
+            "label.txt")
+        if label_dict_path.exists():
+            self.dump_label_dict(label_dict_path)
+        if self.train_config.batch_size is not None:
+            self.pdx_config.update_batch_size(self.train_config.batch_size)
+        if self.train_config.learning_rate is not None:
+            self.pdx_config.update_learning_rate(
+                self.train_config.learning_rate)
+        if self.train_config.epochs_iters is not None:
+            self.pdx_config._update_epochs(self.train_config.epochs_iters)
+        if self.train_config.warmup_steps is not None:
+            self.pdx_config.update_warmup_epochs(self.train_config.warmup_steps)
+        if self.global_config.output is not None:
+            self.pdx_config._update_output_dir(self.global_config.output)
+
+    def get_train_kwargs(self) -> dict:
+        """get key-value arguments of model training function
+
+        Returns:
+            dict: the arguments of training function.
+        """
+        train_args = {"device": self.get_device()}
+        if (
+            self.train_config.resume_path is not None
+            and self.train_config.resume_path != ""
+        ):
+            train_args["resume_path"] = self.train_config.resume_path
+        train_args["dy2st"] = self.train_config.get("dy2st", False)
+        return train_args
+
+
+class ClsTrainDeamon(BaseTrainDeamon):
+    """ClsTrainResultDemon"""
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def get_the_pdparams_suffix(self):
+        """get the suffix of pdparams file"""
+        return "pdparams"
+
+    def get_the_pdema_suffix(self):
+        """get the suffix of pdema file"""
+        return "pdema"
+
+    def get_the_pdopt_suffix(self):
+        """get the suffix of pdopt file"""
+        return "pdopt"
+
+    def get_the_pdstates_suffix(self):
+        """get the suffix of pdstates file"""
+        return "pdstates"
+
+    def get_ith_ckp_prefix(self, epoch_id):
+        """get the prefix of the epoch_id checkpoint file"""
+        return f"epoch_{epoch_id}"
+
+    def get_best_ckp_prefix(self):
+        """get the prefix of the best checkpoint file"""
+        return "best_model"
+
+    def get_score(self, pdstates_path):
+        """get the score by pdstates file"""
+        if not Path(pdstates_path).exists():
+            return 0
+        return paddle.load(pdstates_path)["metric"]
+
+    def get_epoch_id_by_pdparams_prefix(self, pdparams_prefix):
+        """get the epoch_id by pdparams file"""
+        return int(pdparams_prefix.split("_")[-1])
+
+    def update_label_dict(self, train_output):
+        """update label dict"""
+        dict_path = train_output.joinpath("label_dict.txt")
+        if not dict_path.exists():
+            return ""
+        return dict_path