
fix bug in multi-label and add convert.py (#2029)

liuhongen1234567, 1 year ago
parent
commit f5c8ef9c1d

+ 9 - 0
paddlex/modules/__init__.py

@@ -29,6 +29,15 @@ from .image_classification import (
     ClsExportor,
     ClsPredictor,
 )
+
+from .multilabel_classification import (
+    MLClsDatasetChecker,
+    MLClsTrainer,
+    MLClsEvaluator,
+    MLClsExportor,
+    MLClsPredictor,
+)
+
 from .anomaly_detection import (
     UadDatasetChecker,
     UadTrainer,

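With this export in place, the five multilabel entry points become importable from the top-level paddlex.modules package, mirroring the single-label classification module:

    from paddlex.modules import (
        MLClsDatasetChecker,
        MLClsTrainer,
        MLClsEvaluator,
        MLClsExportor,
        MLClsPredictor,
    )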
+ 2 - 2
paddlex/modules/multilabel_classification/__init__.py

@@ -14,6 +14,6 @@
 
 from .trainer import MLClsTrainer
 from .dataset_checker import MLClsDatasetChecker
-from .evaluator import MlClsEvaluator
-from .exportor import MlClsExportor
+from .evaluator import MLClsEvaluator
+from .exportor import MLClsExportor
 from .predictor import MLClsPredictor, transforms

+ 4 - 2
paddlex/modules/multilabel_classification/dataset_checker/__init__.py

@@ -15,7 +15,7 @@
 from pathlib import Path
 
 from ...base import BaseDatasetChecker
-from .dataset_src import check, split_dataset, deep_analyse
+from .dataset_src import check, convert, split_dataset, deep_analyse
 from ..model_list import MODELS
 
 
@@ -48,7 +48,9 @@ class MLClsDatasetChecker(BaseDatasetChecker):
         Returns:
             str: the root directory of converted dataset.
         """
-        return src_dataset_dir
+        return convert(
+            self.check_dataset_config.convert.src_dataset_type, src_dataset_dir
+        )
 
     def split_dataset(self, src_dataset_dir: str) -> str:
         """repartition the train and validation dataset

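The convert_dataset override now converts the source dataset instead of passing the directory through unchanged, taking the source format from the dataset-check config. A minimal sketch of the resulting call, assuming the config exposes convert.src_dataset_type as accessed above (the variable names and path here are illustrative, not a documented API):

    # Illustrative only: mirrors the new body of convert_dataset().
    # "COCO" is the only source format the new converter accepts.
    src_dataset_type = check_dataset_config.convert.src_dataset_type  # e.g. "COCO"
    converted_dir = convert(src_dataset_type, "path/to/src_dataset_dir")  # converts in place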
+ 1 - 0
paddlex/modules/multilabel_classification/dataset_checker/dataset_src/__init__.py

@@ -14,5 +14,6 @@
 
 
 from .check_dataset import check
+from .convert_dataset import convert
 from .split_dataset import split_dataset
 from .analyse_dataset import deep_analyse

+ 114 - 0
paddlex/modules/multilabel_classification/dataset_checker/dataset_src/convert_dataset.py

@@ -0,0 +1,114 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+import json
+from tqdm import tqdm
+from pycocotools.coco import COCO
+from .....utils.errors import ConvertFailedError
+from .....utils.logging import info, warning
+
+
+def check_src_dataset(root_dir, dataset_type):
+    """check src dataset format validity"""
+    if dataset_type in ("COCO",):
+        anno_suffix = ".json"
+    else:
+        raise ConvertFailedError(
+            message=f"Dataset format conversion failed! The {dataset_type} dataset format is not supported; only the COCO format is currently supported."
+        )
+
+    err_msg_prefix = f"Dataset format conversion failed! Please check the dataset to be converted against the `{dataset_type} dataset example` above."
+
+    for anno in ["annotations/instance_train.json", "annotations/instance_val.json"]:
+        src_anno_path = os.path.join(root_dir, anno)
+        if not os.path.exists(src_anno_path):
+            raise ConvertFailedError(
+                message=f"{err_msg_prefix} Make sure the file {src_anno_path} exists."
+            )
+    return None
+
+
+def convert(dataset_type, input_dir):
+    """convert dataset to multilabel format"""
+    # check format validity
+    check_src_dataset(input_dir, dataset_type)
+
+    if dataset_type in ("COCO",):
+        convert_coco_dataset(input_dir)
+    else:
+        raise ConvertFailedError(
+            message=f"Dataset format conversion failed! The {dataset_type} dataset format is not supported; only the COCO format is currently supported."
+        )
+    # conversion happens in place; return the dataset root so callers such as
+    # MLClsDatasetChecker.convert_dataset receive the converted dataset directory
+    return input_dir
+
+
+def convert_coco_dataset(root_dir):
+    for anno in ["annotations/instance_train.json", "annotations/instance_val.json"]:
+        src_img_dir = root_dir
+        src_anno_path = os.path.join(root_dir, anno)
+        coco2multilabels(src_img_dir, src_anno_path, root_dir)
+
+
+def coco2multilabels(src_img_dir, src_anno_path, root_dir):
+    image_dir = os.path.join(root_dir, "images")
+    label_type = (
+        os.path.basename(src_anno_path).replace("instance_", "").replace(".json", "")
+    )
+    anno_save_path = os.path.join(root_dir, "{}.txt".format(label_type))
+    coco = COCO(src_anno_path)
+    cat_id_map = {
+        old_cat_id: new_cat_id
+        for new_cat_id, old_cat_id in enumerate(coco.getCatIds())
+    }
+    num_classes = len(cat_id_map)
+
+    with open(anno_save_path, "w") as fp:
+        for img_id in tqdm(sorted(coco.getImgIds())):
+            img_info = coco.loadImgs([img_id])[0]
+            img_filename = img_info['file_name']
+            img_w = img_info['width']
+            img_h = img_info['height']
+
+            img_filepath = os.path.join(image_dir, img_filename)
+            if not os.path.exists(img_filepath):
+                warning(
+                    "Illegal image file: {}, and it will be ignored".format(img_filepath)
+                )
+                continue
+
+            if img_w < 0 or img_h < 0:
+                warning(
+                    "Illegal width: {} or height: {} in annotation, "
+                    "and im_id: {} will be ignored".format(img_w, img_h, img_id)
+                )
+                continue
+
+            ins_anno_ids = coco.getAnnIds(imgIds=[img_id])
+            instances = coco.loadAnns(ins_anno_ids)
+
+            label = [0] * num_classes
+            for instance in instances:
+                label[cat_id_map[instance['category_id']]] = 1
+            img_filename = os.path.join("images", img_filename)
+            fp.write("{}\t{}\n".format(img_filename, ",".join(map(str, label))))
+    if label_type == "train":
+        label_txt_save_path = os.path.join(root_dir, "label.txt")
+        with open(label_txt_save_path, "w") as fp:
+            for cat in coco.cats.values():
+                fp.write("{} {}\n".format(cat["id"], cat["name"]))
+        info("Save label names to {}.".format(label_txt_save_path))
+

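Taken together, convert_dataset.py rewrites a COCO-style dataset in place into the multilabel layout: train.txt / val.txt with one "<relative image path>\t<comma-separated 0/1 labels>" line per image, plus a label.txt mapping category ids to names, written from the training annotations. A minimal usage sketch, assuming a local dataset that contains annotations/instance_train.json and annotations/instance_val.json (the path is illustrative):

    from paddlex.modules.multilabel_classification.dataset_checker.dataset_src import (
        convert,
    )

    # "COCO" is the only supported src_dataset_type; anything else raises
    # ConvertFailedError. The dataset root is converted in place.
    convert("COCO", "path/to/your_coco_dataset")
    # Expected outputs under path/to/your_coco_dataset:
    #   train.txt, val.txt  - one "<relative image path>\t0,1,..." line per image
    #   label.txt           - one "<category id> <category name>" line per class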
+ 1 - 1
paddlex/modules/multilabel_classification/evaluator.py

@@ -16,7 +16,7 @@ from ..base import BaseEvaluator
 from .model_list import MODELS
 
 
-class MlClsEvaluator(BaseEvaluator):
+class MLClsEvaluator(BaseEvaluator):
     """Image Classification Model Evaluator"""
 
     entities = MODELS

+ 1 - 1
paddlex/modules/multilabel_classification/exportor.py

@@ -16,7 +16,7 @@ from ..base import BaseExportor
 from .model_list import MODELS
 
 
-class MlClsExportor(BaseExportor):
+class MLClsExportor(BaseExportor):
     """Image Classification Model Exportor"""
 
     entities = MODELS

+ 1 - 2
paddlex/modules/multilabel_classification/predictor/__init__.py

@@ -12,6 +12,5 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .predictor import ClsPredictor
-from .predictor_ml import MLClsPredictor
+from .predictor import MLClsPredictor
 from . import transforms

+ 1 - 1
paddlex/modules/multilabel_classification/predictor/transforms.py

@@ -62,7 +62,7 @@ class PrintResult(BaseTransform):
 
 class SaveMLClsResults(BaseTransform):
     def __init__(self, save_dir, class_ids=None):
-        super().__init__(save_dir=save_dir)
+        super().__init__()
         self.save_dir = save_dir
         self.class_id_map = _parse_class_id_map(class_ids)
         self._writer = ImageWriter(backend="pillow")