
[WIP] Adapt to aistudio (#2560)

* Fix config input bug

* Adapt to aistudio
zhangyubo0722 11 months ago
parent
commit
10cdbcda49

+ 8 - 1
paddlex/modules/base/dataset_checker/dataset_checker.py

@@ -32,6 +32,13 @@ def build_dataset_checker(config: AttrDict) -> "BaseDatasetChecker":
         BaseDatasetChecker: the dataset checker, which is subclass of BaseDatasetChecker.
     """
     model_name = config.Global.model
+    try:
+        import feature_line_modules
+    except ModuleNotFoundError:
+        info(
+            "The PaddleX FeaTure Line plugin is not installed, but continuing execution."
+        )
+
     return BaseDatasetChecker.get(model_name)(config)
 
 
@@ -77,7 +84,7 @@ class BaseDatasetChecker(ABC, metaclass=AutoRegisterABCMetaClass):
         check_result = build_res_dict(True)
         check_result["attributes"] = attrs
         check_result["analysis"] = analysis
-        check_result["dataset_path"] = self.global_config.dataset_dir
+        check_result["dataset_path"] = os.path.basename(dataset_dir)
         check_result["show_type"] = self.get_show_type()
         check_result["dataset_type"] = self.get_dataset_type()
         info("Check dataset passed !")
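The try/except guard added here is repeated verbatim in the evaluator, exporter, and trainer below; it makes the FeaTure Line plugin strictly optional. A minimal standalone sketch of the pattern, assuming the plugin only needs to be imported for its registration side effects (the helper name is illustrative, not part of this PR):

```python
import importlib
import logging


def try_import_plugin(module_name: str) -> bool:
    """Import an optional plugin; return False (and log) if it is absent.

    Importing is presumably sufficient because PaddleX components
    self-register at import time via AutoRegisterABCMetaClass.
    """
    try:
        importlib.import_module(module_name)
        return True
    except ModuleNotFoundError:
        logging.info(
            "The %s plugin is not installed, but continuing execution.",
            module_name,
        )
        return False


try_import_plugin("feature_line_modules")
```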

+ 9 - 3
paddlex/modules/base/evaluator.py

@@ -33,6 +33,12 @@ def build_evaluater(config: AttrDict) -> "BaseEvaluator":
         BaseEvaluator: the evaluater, which is subclass of BaseEvaluator.
     """
     model_name = config.Global.model
+    try:
+        import feature_line_modules
+    except ModuleNotFoundError:
+        info(
+            "The PaddleX FeaTure Line plugin is not installed, but continuing execution."
+        )
     return BaseEvaluator.get(model_name)(config)
 
 
@@ -51,9 +57,9 @@ class BaseEvaluator(ABC, metaclass=AutoRegisterABCMetaClass):
         self.global_config = config.Global
         self.eval_config = config.Evaluate
 
-        config_path = self.eval_config.get("basic_config_path", None)
-        if not config_path:
-            config_path = self.get_config_path(self.eval_config.weight_path)
+        config_path = self.get_config_path(self.eval_config.weight_path)
+        if self.eval_config.get("basic_config_path", None):
+            config_path = self.eval_config.get("basic_config_path", None)
 
         self.pdx_config, self.pdx_model = build_model(
             self.global_config.model, config_path=config_path
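The exporter below receives the identical change. Note that the new order keeps the effective precedence: an explicit `basic_config_path` still wins; the difference is that the default is now always derived from `weight_path` first. A one-function sketch of the resolution logic (names illustrative):

```python
def resolve_config_path(eval_config, get_config_path) -> str:
    # Derive the default from the weights, then let an explicit
    # basic_config_path override it -- same precedence as before.
    config_path = get_config_path(eval_config.weight_path)
    return eval_config.get("basic_config_path", None) or config_path
```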

+ 9 - 3
paddlex/modules/base/exportor.py

@@ -33,6 +33,12 @@ def build_exportor(config: AttrDict) -> "BaseExportor":
         BaseExportor: the exportor, which is subclass of BaseExportor.
     """
     model_name = config.Global.model
+    try:
+        import feature_line_modules
+    except ModuleNotFoundError:
+        logging.info(
+            "The PaddleX FeaTure Line plugin is not installed, but continuing execution."
+        )
     return BaseExportor.get(model_name)(config)
 
 
@@ -51,9 +57,9 @@ class BaseExportor(ABC, metaclass=AutoRegisterABCMetaClass):
         self.global_config = config.Global
         self.export_config = config.Export
 
-        config_path = self.export_config.get("basic_config_path", None)
-        if not config_path:
-            config_path = self.get_config_path(self.export_config.weight_path)
+        config_path = self.get_config_path(self.export_config.weight_path)
+        if self.export_config.get("basic_config_path", None):
+            config_path = self.export_config.get("basic_config_path", None)
 
         self.pdx_config, self.pdx_model = build_model(
             self.global_config.model, config_path=config_path

+ 7 - 0
paddlex/modules/base/trainer.py

@@ -19,6 +19,7 @@ from .build_model import build_model
 from ...utils.device import update_device_num, set_env_for_device
 from ...utils.misc import AutoRegisterABCMetaClass
 from ...utils.config import AttrDict
+from ...utils.logging import info
 
 
 def build_trainer(config: AttrDict) -> "BaseTrainer":
@@ -31,6 +32,12 @@ def build_trainer(config: AttrDict) -> "BaseTrainer":
         BaseTrainer: the trainer, which is subclass of BaseTrainer.
     """
     model_name = config.Global.model
+    try:
+        import feature_line_modules
+    except ModuleNotFoundError:
+        info(
+            "The PaddleX FeaTure Line plugin is not installed, but continuing execution."
+        )
     return BaseTrainer.get(model_name)(config)
 
 

+ 15 - 0
paddlex/modules/instance_segmentation/dataset_checker/__init__.py

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import os
+from pathlib import Path
 
 from .dataset_src import check, convert, split_dataset, deep_analyse
 from ...base import BaseDatasetChecker
@@ -26,6 +27,20 @@ class COCOInstSegDatasetChecker(BaseDatasetChecker):
     entities = MODELS
     sample_num = 10
 
+    def get_dataset_root(self, dataset_dir: str) -> str:
+        """find the dataset root dir
+
+        Args:
+            dataset_dir (str): the directory that contain dataset.
+
+        Returns:
+            str: the root directory of dataset.
+        """
+        anno_dirs = list(Path(dataset_dir).glob("**/images"))
+        assert len(anno_dirs) == 1
+        dataset_dir = anno_dirs[0].parent.as_posix()
+        return dataset_dir
+
     def convert_dataset(self, src_dataset_dir: str) -> str:
         """convert the dataset from other type to specified type
 

+ 15 - 0
paddlex/modules/semantic_segmentation/dataset_checker/__init__.py

@@ -15,6 +15,7 @@
 
 import os
 import os.path as osp
+from pathlib import Path
 
 from ...base import BaseDatasetChecker
 from .dataset_src import check_dataset, convert_dataset, split_dataset, anaylse_dataset
@@ -28,6 +29,20 @@ class SegDatasetChecker(BaseDatasetChecker):
     entities = MODELS
     sample_num = 10
 
+    def get_dataset_root(self, dataset_dir: str) -> str:
+        """find the dataset root dir
+
+        Args:
+            dataset_dir (str): the directory that contain dataset.
+
+        Returns:
+            str: the root directory of dataset.
+        """
+        anno_dirs = list(Path(dataset_dir).glob("**/images"))
+        assert len(anno_dirs) == 1
+        dataset_dir = anno_dirs[0].parent.as_posix()
+        return dataset_dir
+
     def convert_dataset(self, src_dataset_dir: str) -> str:
         """convert the dataset from other type to specified type
 

+ 16 - 1
paddlex/modules/table_recognition/dataset_checker/__init__.py

@@ -15,6 +15,7 @@
 
 import os
 import os.path as osp
+from pathlib import Path
 from collections import defaultdict, Counter
 from PIL import Image
 import json
@@ -30,6 +31,20 @@ class TableRecDatasetChecker(BaseDatasetChecker):
 
     entities = MODELS
 
+    def get_dataset_root(self, dataset_dir: str) -> str:
+        """find the dataset root dir
+
+        Args:
+            dataset_dir (str): the directory that contain dataset.
+
+        Returns:
+            str: the root directory of dataset.
+        """
+        anno_dirs = list(Path(dataset_dir).glob("**/train.txt"))
+        assert len(anno_dirs) == 1
+        dataset_dir = anno_dirs[0].parent.as_posix()
+        return dataset_dir
+
     def convert_dataset(self, src_dataset_dir: str) -> str:
         """convert the dataset from other type to specified type
 
@@ -64,7 +79,7 @@ class TableRecDatasetChecker(BaseDatasetChecker):
         Returns:
             dict: dataset summary.
         """
-        return check(dataset_dir, self.global_config.output, sample_num=10)
+        return check(dataset_dir, self.output, sample_num=10)
 
     def get_show_type(self) -> str:
         """get the show type of dataset

+ 13 - 2
paddlex/modules/table_recognition/dataset_checker/dataset_src/check_dataset.py

@@ -16,6 +16,7 @@
 import os
 import json
 import os.path as osp
+from PIL import Image, ImageOps
 from collections import defaultdict
 from .....utils.errors import DatasetFileNotFoundError, CheckFailedError
 
@@ -58,11 +59,21 @@ def check(dataset_dir, output, dataset_type="PubTabTableRecDataset", sample_num=
                         structure = info["html"]["structure"]["tokens"].copy()
 
                         img_path = osp.join(dataset_dir, file_name)
-                        if len(sample_paths[tag]) < max_recorded_sample_cnts:
-                            sample_paths[tag].append(os.path.relpath(img_path, output))
 
                         if not os.path.exists(img_path):
                             raise DatasetFileNotFoundError(file_path=img_path)
+                        vis_save_dir = osp.join(output, "demo_img")
+                        if not osp.exists(vis_save_dir):
+                            os.makedirs(vis_save_dir)
+                        if len(sample_paths[tag]) < sample_num:
+                            img = Image.open(img_path)
+                            img = ImageOps.exif_transpose(img)
+                            vis_path = osp.join(vis_save_dir, osp.basename(file_name))
+                            img.save(vis_path)
+                            sample_path = osp.join(
+                                "check_dataset", os.path.relpath(vis_path, output)
+                            )
+                            sample_paths[tag].append(sample_path)
 
                         boxes_num = len(cells)
                         tokens_num = sum(
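The text detection and text recognition check scripts below receive the same rework. Previously the checker recorded raw dataset-relative paths (and did so before verifying the file existed); now it verifies existence first, copies up to `sample_num` images into `<output>/demo_img` with EXIF orientation applied so rotated photos display upright, and records paths relative to the output under a `check_dataset` prefix. A condensed, self-contained sketch of that flow (the function name is illustrative):

```python
import os
import os.path as osp

from PIL import Image, ImageOps


def save_demo_image(img_path: str, output: str) -> str:
    """Copy one sample image into <output>/demo_img with its EXIF
    orientation applied, and return the path recorded for the report
    (relative to the output dir, under a check_dataset prefix)."""
    vis_save_dir = osp.join(output, "demo_img")
    os.makedirs(vis_save_dir, exist_ok=True)
    img = ImageOps.exif_transpose(Image.open(img_path))
    vis_path = osp.join(vis_save_dir, osp.basename(img_path))
    img.save(vis_path)
    return osp.join("check_dataset", os.path.relpath(vis_path, output))
```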

+ 16 - 1
paddlex/modules/text_detection/dataset_checker/__init__.py

@@ -15,6 +15,7 @@
 
 import os
 import os.path as osp
+from pathlib import Path
 from collections import defaultdict, Counter
 
 import json
@@ -30,6 +31,20 @@ class TextDetDatasetChecker(BaseDatasetChecker):
 
     entities = MODELS
 
+    def get_dataset_root(self, dataset_dir: str) -> str:
+        """find the dataset root dir
+
+        Args:
+            dataset_dir (str): the directory that contain dataset.
+
+        Returns:
+            str: the root directory of dataset.
+        """
+        anno_dirs = list(Path(dataset_dir).glob("**/images"))
+        assert len(anno_dirs) == 1
+        dataset_dir = anno_dirs[0].parent.as_posix()
+        return dataset_dir
+
     def convert_dataset(self, src_dataset_dir: str) -> str:
         """convert the dataset from other type to specified type
 
@@ -64,7 +79,7 @@ class TextDetDatasetChecker(BaseDatasetChecker):
         Returns:
             dict: dataset summary.
         """
-        return check(dataset_dir, self.global_config.output, sample_num=10)
+        return check(dataset_dir, self.output, sample_num=10)
 
     def analyse(self, dataset_dir: str) -> dict:
         """deep analyse dataset

+ 13 - 3
paddlex/modules/text_detection/dataset_checker/dataset_src/check_dataset.py

@@ -17,7 +17,7 @@ import os
 import os.path as osp
 from collections import defaultdict
 
-from PIL import Image
+from PIL import Image, ImageOps
 import json
 import numpy as np
 
@@ -66,10 +66,20 @@ def check(dataset_dir, output, sample_num=10):
                     file_name = substr[0]
                     label = substr[1]
                     img_path = osp.join(dataset_dir, file_name)
-                    if len(sample_paths[tag]) < sample_num:
-                        sample_paths[tag].append(os.path.relpath(img_path, output))
                     if not osp.exists(img_path):
                         raise DatasetFileNotFoundError(file_path=img_path)
+                    vis_save_dir = osp.join(output, "demo_img")
+                    if not osp.exists(vis_save_dir):
+                        os.makedirs(vis_save_dir)
+                    if len(sample_paths[tag]) < sample_num:
+                        img = Image.open(img_path)
+                        img = ImageOps.exif_transpose(img)
+                        vis_path = osp.join(vis_save_dir, osp.basename(file_name))
+                        img.save(vis_path)
+                        sample_path = osp.join(
+                            "check_dataset", os.path.relpath(vis_path, output)
+                        )
+                        sample_paths[tag].append(sample_path)
 
                     # check det label
                     label = json.loads(label)

+ 16 - 1
paddlex/modules/text_recognition/dataset_checker/__init__.py

@@ -15,6 +15,7 @@
 
 import os
 import os.path as osp
+from pathlib import Path
 from collections import defaultdict, Counter
 
 from PIL import Image
@@ -35,6 +36,20 @@ class TextRecDatasetChecker(BaseDatasetChecker):
     entities = MODELS
     sample_num = 10
 
+    def get_dataset_root(self, dataset_dir: str) -> str:
+        """find the dataset root dir
+
+        Args:
+            dataset_dir (str): the directory that contain dataset.
+
+        Returns:
+            str: the root directory of dataset.
+        """
+        anno_dirs = list(Path(dataset_dir).glob("**/train.txt"))
+        assert len(anno_dirs) == 1
+        dataset_dir = anno_dirs[0].parent.as_posix()
+        return dataset_dir
+
     def convert_dataset(self, src_dataset_dir: str) -> str:
         """convert the dataset from other type to specified type
 
@@ -74,7 +89,7 @@ class TextRecDatasetChecker(BaseDatasetChecker):
         """
         return check(
             dataset_dir,
-            self.global_config.output,
+            self.output,
             sample_num=10,
             dataset_type=self.get_dataset_type(),
         )

+ 13 - 3
paddlex/modules/text_recognition/dataset_checker/dataset_src/check_dataset.py

@@ -17,7 +17,7 @@ import os
 import os.path as osp
 from collections import defaultdict
 
-from PIL import Image
+from PIL import Image, ImageOps
 import json
 import numpy as np
 
@@ -79,11 +79,21 @@ def check(
                             )
                         file_name = substr[0]
                         img_path = osp.join(dataset_dir, file_name)
-                        if len(sample_paths[tag]) < max_recorded_sample_cnts:
-                            sample_paths[tag].append(os.path.relpath(img_path, output))
 
                         if not os.path.exists(img_path):
                             raise DatasetFileNotFoundError(file_path=img_path)
+                        vis_save_dir = osp.join(output, "demo_img")
+                        if not osp.exists(vis_save_dir):
+                            os.makedirs(vis_save_dir)
+                        if len(sample_paths[tag]) < sample_num:
+                            img = Image.open(img_path)
+                            img = ImageOps.exif_transpose(img)
+                            vis_path = osp.join(vis_save_dir, osp.basename(file_name))
+                            img.save(vis_path)
+                            sample_path = osp.join(
+                                "check_dataset", os.path.relpath(vis_path, output)
+                            )
+                            sample_paths[tag].append(sample_path)
 
         meta = {}
         meta["train_samples"] = sample_cnts["train"]

+ 15 - 0
paddlex/modules/ts_anomaly_detection/dataset_checker/__init__.py

@@ -15,6 +15,7 @@
 
 import os
 import os.path as osp
+from pathlib import Path
 from collections import defaultdict, Counter
 
 from PIL import Image
@@ -32,6 +33,20 @@ class TSADDatasetChecker(BaseDatasetChecker):
     entities = MODELS
     sample_num = 10
 
+    def get_dataset_root(self, dataset_dir: str) -> str:
+        """find the dataset root dir
+
+        Args:
+            dataset_dir (str): the directory that contain dataset.
+
+        Returns:
+            str: the root directory of dataset.
+        """
+        anno_dirs = list(Path(dataset_dir).glob("**/train.csv"))
+        assert len(anno_dirs) == 1
+        dataset_dir = anno_dirs[0].parent.as_posix()
+        return dataset_dir
+
     def convert_dataset(self, src_dataset_dir: str) -> str:
         """convert the dataset from other type to specified type
 

+ 15 - 0
paddlex/modules/ts_classification/dataset_checker/__init__.py

@@ -15,6 +15,7 @@
 
 import os
 import os.path as osp
+from pathlib import Path
 from collections import defaultdict, Counter
 
 from PIL import Image
@@ -32,6 +33,20 @@ class TSCLSDatasetChecker(BaseDatasetChecker):
     entities = MODELS
     sample_num = 10
 
+    def get_dataset_root(self, dataset_dir: str) -> str:
+        """find the dataset root dir
+
+        Args:
+            dataset_dir (str): the directory that contain dataset.
+
+        Returns:
+            str: the root directory of dataset.
+        """
+        anno_dirs = list(Path(dataset_dir).glob("**/train.csv"))
+        assert len(anno_dirs) == 1
+        dataset_dir = anno_dirs[0].parent.as_posix()
+        return dataset_dir
+
     def convert_dataset(self, src_dataset_dir: str) -> str:
         """convert the dataset from other type to specified type
 

+ 15 - 0
paddlex/modules/ts_forecast/dataset_checker/__init__.py

@@ -15,6 +15,7 @@
 
 import os
 import os.path as osp
+from pathlib import Path
 from collections import defaultdict, Counter
 
 from PIL import Image
@@ -32,6 +33,20 @@ class TSFCDatasetChecker(BaseDatasetChecker):
     entities = MODELS
     sample_num = 10
 
+    def get_dataset_root(self, dataset_dir: str) -> str:
+        """find the dataset root dir
+
+        Args:
+            dataset_dir (str): the directory that contain dataset.
+
+        Returns:
+            str: the root directory of dataset.
+        """
+        anno_dirs = list(Path(dataset_dir).glob("**/train.csv"))
+        assert len(anno_dirs) == 1
+        dataset_dir = anno_dirs[0].parent.as_posix()
+        return dataset_dir
+
     def convert_dataset(self, src_dataset_dir: str) -> str:
         """convert the dataset from other type to specified type
 

+ 8 - 0
paddlex/repo_apis/PaddleOCR_api/text_rec/config.py

@@ -118,6 +118,14 @@ class TextRecConfig(BaseConfig):
         else:
             raise ValueError(f"{repr(dataset_type)} is not supported.")
 
+    def update_dataset_by_list(self, label_file_list, ratio_list):
+        _cfg = {
+            "Train.dataset.name": "MSTextRecDataset",
+            "Train.dataset.label_file_list": label_file_list,
+            "Train.dataset.ratio_list": ratio_list,
+        }
+        self.update(_cfg)
+
     def update_batch_size(self, batch_size: int, mode: str = "train"):
         """update batch size setting
 

+ 3 - 0
paddlex/repo_apis/PaddleTS_api/ts_base/config.py

@@ -57,6 +57,9 @@ class BaseTSConfig(BaseConfig):
         Args:
             config_file_path (str): the path to save self as yaml file.
         """
+        output_dir = os.path.dirname(config_file_path)
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
         yaml = ruamel.yaml.YAML()
         with open(config_file_path, "w", encoding="utf-8") as f:
             dict_to_dump = self.dict
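A minimal sketch of why the added guard is needed: `open(path, "w")` raises FileNotFoundError when the parent directory is missing, so the directory is created before dumping (the function name is illustrative):

```python
import os

import ruamel.yaml


def dump_config(data: dict, config_file_path: str) -> None:
    # Create the parent directory first; writing to a path inside a
    # missing directory would raise FileNotFoundError.
    output_dir = os.path.dirname(config_file_path)
    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)
    yaml = ruamel.yaml.YAML()
    with open(config_file_path, "w", encoding="utf-8") as f:
        yaml.dump(data, f)
```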