Эх сурвалжийг харах

fix default export label for det (#2066)

* fix default export label for det

* fix default export label for det
Sunflower7788 1 жил өмнө
parent
commit
129672343e

+ 1 - 1
paddlex/inference/utils/official_models.py

@@ -195,7 +195,7 @@ openatom_rec_svtrv2_ch_infer.tar",
 PP-LCNet_x1_0_pedestrian_attribute_infer.tar",
     "PP-LCNet_x1_0_vehicle_attribute": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/\
 PP-LCNet_x1_0_vehicle_attribute_infer.tar",
-    "PicoDet_layout_1x": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PicoDet-L_layout_infer.tar",
+    "PicoDet_layout_1x": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1/PicoDet_layout_1x_infer.tar",
     "SLANet": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/SLANet_infer.tar",
     "LaTeX_OCR_rec": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/LaTeX_OCR_rec_infer.tar",
     "FasterRCNN-ResNet34-FPN": "https://paddle-model-ecology.bj.bcebos.com/paddlex/official_inference_model/paddle3.0b1_v2/FasterRCNN-ResNet34-FPN_infer.tar",

+ 1 - 0
paddlex/repo_apis/PaddleDetection_api/object_det/__init__.py

@@ -16,3 +16,4 @@
 from .model import DetModel
 from .runner import DetRunner
 from . import register
+from .official_categories import official_categories

+ 17 - 0
paddlex/repo_apis/PaddleDetection_api/object_det/model.py

@@ -13,13 +13,16 @@
 # limitations under the License.
 
 import os
+import json
 
 from ...base import BaseModel
 from ...base.utils.arg import CLIArgument
 from ...base.utils.subprocess import CompletedProcess
 from ....utils.misc import abspath
+from ....utils import logging
 
 from .config import DetConfig
+from .official_categories import official_categories
 
 
 class DetModel(BaseModel):
@@ -281,6 +284,20 @@ class DetModel(BaseModel):
             hpi_config_path = hpi_config_path.as_posix()
         config.update({"hpi_config_path": hpi_config_path})
 
+        if self.name in official_categories.keys():
+            anno_val_file = abspath(os.path.join(config.TestDataset['dataset_dir'], config.TestDataset['anno_path']))
+            if anno_val_file == None or (not os.path.isfile(anno_val_file)):
+                categories = official_categories[self.name]
+                temp_anno = {'images': [], 'annotations': [], 'categories': categories}
+                with self._create_new_val_json_file() as anno_file:
+                    json.dump(temp_anno, open(anno_file, 'w'))
+                    config.update({"TestDataset": {"dataset_dir": '', "anno_path": anno_file}})
+                    logging.warning(f"{self.name} does not have validate annotations, use {anno_file} default instead.")
+                    self._assert_empty_kwargs(kwargs)
+                    with self._create_new_config_file() as config_path:
+                        config.dump(config_path)
+                        return self.runner.export(config_path, cli_args, None)
+                
         self._assert_empty_kwargs(kwargs)
 
         with self._create_new_config_file() as config_path:

+ 14 - 0
paddlex/repo_apis/PaddleDetection_api/object_det/official_categories.py

@@ -0,0 +1,14 @@
+official_categories = {
+'PP-YOLOE-L_human': [{"name": "pedestrian", "id": 0}],
+'PP-YOLOE-S_human': [{"name": "pedestrian", "id": 0}],
+'PP-YOLOE-S_vehicle': [{"name": "vehicle", "id": 0}],
+'PP-YOLOE-L_vehicle': [{"name": "vehicle", "id": 0}],
+'PP-ShiTuV2_det': [{"name": "mainbody", "id": 0}],
+'PicoDet_layout_1x': [{"name": "Text", "id": 0}, {"name": "Title", "id": 1}, {"name": "List", "id": 2}, {"name": "Table", "id": 3}, {"name": "Figure", "id": 4}],
+'PicoDet-L_layout_3cls': [{"name": "image", "id": 0}, {"name": "table", "id": 1}, {"name": "seal", "id": 2}],
+'RT-DETR-H_layout_3cls': [{"name": "image", "id": 0}, {"name": "table", "id": 1}, {"name": "seal", "id": 2}],
+'RT-DETR-H_layout_17cls': [{"name": "paragraph_title", "id": 0}, {"name": "image", "id": 1}, {"name": "text", "id": 2}, {"name": "number", "id": 3}, {"name": "abstract", "id": 4}, {"name": "content", "id": 5}, {"name": "figure_title", "id": 6}, {"name": "formula", "id": 7}, {"name": "table", "id": 8}, {"name": "tabke_title", "id": 9}, {"name":"reference", "id": 10}, {"name": "doc_title", "id": 11}, {"name": "footnote", "id": 12}, {"name": "header", "id": 13}, {"name": "algorithm", "id": 14}, {"name": "footer", "id": 15}, {"name": "seal", "id": 16}],
+'PP-YOLOE_plus_SOD-S':  [{"name": "pedestrian", "id": 0}, {"name": "people", "id": 1}, {"name": "bicycle", "id": 2}, {"name": "car", "id": 3}, {"name": "van", "id": 4}, {"name": "truck", "id": 5}, {"name": "tricycle", "id": 6}, {"name": "awning-tricycle", "id": 7}, {"name": "bus", "id": 8}, {"name": "motorcycle", "id": 9}],
+'PP-YOLOE_plus_SOD-L': [{"name": "pedestrian", "id": 0}, {"name": "people", "id": 1}, {"name": "bicycle", "id": 2}, {"name": "car", "id": 3}, {"name": "van", "id": 4}, {"name": "truck", "id": 5}, {"name": "tricycle", "id": 6}, {"name": "awning-tricycle", "id": 7}, {"name": "bus", "id": 8}, {"name": "motorcycle", "id": 9}],
+'PP-YOLOE_plus_SOD-largesize-L': [{"name": "pedestrian", "id": 0}, {"name": "people", "id": 1}, {"name": "bicycle", "id": 2}, {"name": "car", "id": 3}, {"name": "van", "id": 4}, {"name": "truck", "id": 5}, {"name": "tricycle", "id": 6}, {"name": "awning-tricycle", "id": 7}, {"name": "bus", "id": 8}, {"name": "motorcycle", "id": 9}],
+}

+ 18 - 0
paddlex/repo_apis/base/model.py

@@ -295,6 +295,24 @@ configuration item, "
                 pass
             yield path
 
+    @contextlib.contextmanager
+    def _create_new_val_json_file(self):
+        cls = self.__class__
+        model_name = self.model_info["model_name"]
+        tag = "_".join([cls.__name__.lower(), model_name])
+        json_file_name = tag + "_test.json"
+        if not flags.DEBUG:
+            with tempfile.TemporaryDirectory(dir=get_cache_dir()) as td:
+                path = os.path.join(td, json_file_name)
+                with open(path, "w", encoding="utf-8"):
+                    pass
+                yield path
+        else:
+            path = os.path.join(get_cache_dir(), json_file_name)
+            with open(path, "w", encoding="utf-8"):
+                pass
+            yield path
+
     @cached_property
     def supported_apis(self):
         """supported apis"""