Browse Source

fix bug & support log_interval for TS

zhouchangda 1 year ago
parent
commit
e896b0c8e9

+ 1 - 0
paddlex/configs/ts_anomaly_detection/AutoEncoder_ad.yaml

@@ -23,6 +23,7 @@ Train:
   epochs_iters: 20
   batch_size: 16
   learning_rate: 0.0005
+  log_interval: 10
 
 Evaluate:
   weight_path: "output/best_model/model.pdparams"

+ 1 - 0
paddlex/configs/ts_anomaly_detection/DLinear_ad.yaml

@@ -23,6 +23,7 @@ Train:
   epochs_iters: 20
   batch_size: 16
   learning_rate: 0.0005
+  log_interval: 10
 
 Evaluate:
   weight_path: "output/best_model/model.pdparams"

+ 1 - 0
paddlex/configs/ts_anomaly_detection/Nonstationary_ad.yaml

@@ -23,6 +23,7 @@ Train:
   epochs_iters: 20
   batch_size: 16
   learning_rate: 0.0005
+  log_interval: 10
 
 Evaluate:
   weight_path: "output/best_model/model.pdparams"

+ 1 - 0
paddlex/configs/ts_anomaly_detection/PatchTST_ad.yaml

@@ -23,6 +23,7 @@ Train:
   epochs_iters: 20
   batch_size: 16
   learning_rate: 0.0005
+  log_interval: 10
 
 Evaluate:
   weight_path: "output/best_model/model.pdparams"

+ 1 - 0
paddlex/configs/ts_anomaly_detection/TimesNet_ad.yaml

@@ -23,6 +23,7 @@ Train:
   epochs_iters: 20
   batch_size: 16
   learning_rate: 0.0005
+  log_interval: 10
 
 Evaluate:
   weight_path: "output/best_model/model.pdparams"

+ 1 - 0
paddlex/configs/ts_classification/TimesNet_cls.yaml

@@ -23,6 +23,7 @@ Train:
   epochs_iters: 40
   batch_size: 16
   learning_rate: 0.0001
+  log_interval: 10
 
 Evaluate:
   weight_path: "output/best_model/model.pdparams"

+ 1 - 0
paddlex/configs/ts_forecast/DLinear.yaml

@@ -24,6 +24,7 @@ Train:
   batch_size: 16
   learning_rate: 0.0001
   patience: 10
+  log_interval: 10
 
 Evaluate:
   weight_path: "output/best_model/model.pdparams"

+ 1 - 0
paddlex/configs/ts_forecast/NLinear.yaml

@@ -24,6 +24,7 @@ Train:
   batch_size: 16
   learning_rate: 0.0001
   patience: 10
+  log_interval: 10
 
 Evaluate:
   weight_path: "output/best_model/model.pdparams"

+ 1 - 0
paddlex/configs/ts_forecast/Nonstationary.yaml

@@ -24,6 +24,7 @@ Train:
   batch_size: 16
   learning_rate: 0.0001
   patience: 10
+  log_interval: 10
 
 Evaluate:
   weight_path: "output/best_model/model.pdparams"

+ 1 - 0
paddlex/configs/ts_forecast/PatchTST.yaml

@@ -24,6 +24,7 @@ Train:
   batch_size: 16
   learning_rate: 0.0001
   patience: 10
+  log_interval: 10
 
 Evaluate:
   weight_path: "output/best_model/model.pdparams"

+ 1 - 0
paddlex/configs/ts_forecast/RLinear.yaml

@@ -24,6 +24,7 @@ Train:
   batch_size: 16
   learning_rate: 0.0001
   patience: 10
+  log_interval: 10
 
 Evaluate:
   weight_path: "output/best_model/model.pdparams"

+ 1 - 0
paddlex/configs/ts_forecast/TiDE.yaml

@@ -24,6 +24,7 @@ Train:
   batch_size: 16
   learning_rate: 0.0001
   patience: 10
+  log_interval: 10
 
 Evaluate:
   weight_path: "output/best_model/model.pdparams"

+ 1 - 0
paddlex/configs/ts_forecast/TimesNet.yaml

@@ -24,6 +24,7 @@ Train:
   batch_size: 16
   learning_rate: 0.0001
   patience: 10
+  log_interval: 10
 
 Evaluate:
   weight_path: "output/best_model/model.pdparams"

+ 1 - 1
paddlex/inference/pipelines/layout_parsing/layout_parsing.py

@@ -185,7 +185,7 @@ class LayoutParsingPipeline(_TableRecPipeline):
                 "ocr_result": OCRResult({}),
                 "table_ocr_result": [],
                 "table_result": StructureTableResult([]),
-                "layout_parsing_result": [],
+                "layout_parsing_result": {},
                 "oricls_result": TopkResult({}),
                 "formula_result": TextRecResult({}),
                 "unwarp_result": DocTrResult({}),

+ 4 - 4
paddlex/inference/pipelines/ppchatocrv3/ppchatocrv3.py

@@ -257,7 +257,7 @@ class PPChatOCRPipeline(_TableRecPipeline):
                 "ocr_result": OCRResult({}),
                 "table_ocr_result": [],
                 "table_result": StructureTableResult([]),
-                "structure_result": [],
+                "layout_parsing_result": {},
                 "oricls_result": TopkResult({}),
                 "unwarp_result": DocTrResult({}),
                 "curve_result": [],
@@ -364,14 +364,14 @@ class PPChatOCRPipeline(_TableRecPipeline):
             structure_res = LayoutParsingResult(
                 {
                     "input_path": layout_pred["input_path"],
-                    "layout_parsing_result": structure_res,
+                    "parsing_result": structure_res,
                 }
             )
 
             single_img_res["table_result"] = all_table_res
             single_img_res["ocr_result"] = ocr_res
             single_img_res["table_ocr_result"] = all_table_ocr_res
-            single_img_res["structure_result"] = structure_res
+            single_img_res["layout_parsing_result"] = structure_res
 
             yield VisualResult(single_img_res)
 
@@ -380,7 +380,7 @@ class PPChatOCRPipeline(_TableRecPipeline):
         table_text_list = []
         table_html = []
         for single_img_pred in visual_result:
-            layout_res = single_img_pred["structure_result"]["layout_parsing_result"]
+            layout_res = single_img_pred["layout_parsing_result"]["parsing_result"]
             layout_res_copy = deepcopy(layout_res)
             # layout_res is [{"layout_bbox": [x1, y1, x2, y2], "layout": "single","words in text block":"xxx"}, {"layout_bbox": [x1, y1, x2, y2], "layout": "double","印章":"xxx"}
             ocr_res = {}

+ 2 - 0
paddlex/modules/ts_anomaly_detection/trainer.py

@@ -81,6 +81,8 @@ training!"
             self.pdx_config.update_learning_rate(self.train_config.learning_rate)
         if self.train_config.epochs_iters is not None:
             self.pdx_config.update_epochs(self.train_config.epochs_iters)
+        if self.train_config.log_interval is not None:
+            self.pdx_config.update_log_interval(self.train_config.log_interval)
         if self.global_config.output is not None:
             self.pdx_config.update_save_dir(self.global_config.output)
 

+ 2 - 0
paddlex/modules/ts_classification/trainer.py

@@ -76,6 +76,8 @@ training!"
             self.pdx_config.update_learning_rate(self.train_config.learning_rate)
         if self.train_config.epochs_iters is not None:
             self.pdx_config.update_epochs(self.train_config.epochs_iters)
+        if self.train_config.log_interval is not None:
+            self.pdx_config.update_log_interval(self.train_config.log_interval)
         if self.global_config.output is not None:
             self.pdx_config.update_save_dir(self.global_config.output)
 

+ 2 - 0
paddlex/modules/ts_forecast/trainer.py

@@ -76,6 +76,8 @@ training!"
             self.pdx_config.update_learning_rate(self.train_config.learning_rate)
         if self.train_config.epochs_iters is not None:
             self.pdx_config.update_epochs(self.train_config.epochs_iters)
+        if self.train_config.log_interval is not None:
+            self.pdx_config.update_log_interval(self.train_config.log_interval)
         if self.global_config.output is not None:
             self.pdx_config.update_save_dir(self.global_config.output)
 

+ 8 - 0
paddlex/repo_apis/PaddleTS_api/ts_base/config.py

@@ -153,6 +153,14 @@ class BaseTSConfig(BaseConfig):
         assert isinstance(print_mem_info, bool), "print_mem_info should be a bool"
         self.update({"print_mem_info": print_mem_info})
 
+    def update_log_interval(self, log_interval: int):
+        """update log interval(steps)
+
+        Args:
+            log_interval (int): the log interval value to set.
+        """
+        self.update({"log_interval": log_interval})
+
     def update_dataset(self, dataset_dir: str, dataset_type: str = None):
         """update dataset settings"""
         raise NotImplementedError