Przeglądaj źródła

TSBatchSampler support input_path

gaotingquan 8 miesięcy temu
rodzic
commit
79a41c445d

+ 17 - 0
paddlex/inference/common/batch_sampler/base_batch_sampler.py

@@ -16,6 +16,23 @@ from typing import Union, Tuple, List, Dict, Any, Iterator
 from abc import ABC, abstractmethod
 
 
+class Batch:
+    def __init__(self):
+        self.instances = []
+        self.input_paths = []
+
+    def append(self, instance, input_path):
+        self.instances.append(instance)
+        self.input_paths.append(input_path)
+
+    def reset(self):
+        self.instances = []
+        self.input_paths = []
+
+    def __len__(self):
+        return len(self.instances)
+
+
 class BaseBatchSampler:
     """BaseBatchSampler"""
 

+ 9 - 15
paddlex/inference/common/batch_sampler/image_batch_sampler.py

@@ -20,28 +20,22 @@ from ....utils import logging
 from ....utils.download import download
 from ....utils.cache import CACHE_DIR
 from ...utils.io import PDFReader
-from .base_batch_sampler import BaseBatchSampler
+from .base_batch_sampler import BaseBatchSampler, Batch
 
 
-class ImgInstance:
+class ImgBatch(Batch):
     def __init__(self):
-        self.instances = []
-        self.input_paths = []
+        super().__init__()
         self.page_indexes = []
 
     def append(self, instance, input_path, page_index):
-        self.instances.append(instance)
-        self.input_paths.append(input_path)
+        super().append(instance, input_path)
         self.page_indexes.append(page_index)
 
     def reset(self):
-        self.instances = []
-        self.input_paths = []
+        super().reset()
         self.page_indexes = []
 
-    def __len__(self):
-        return len(self.instances)
-
 
 class ImageBatchSampler(BaseBatchSampler):
 
@@ -79,13 +73,13 @@ class ImageBatchSampler(BaseBatchSampler):
         if not isinstance(inputs, list):
             inputs = [inputs]
 
-        batch = ImgInstance()
+        batch = ImgBatch()
         for input in inputs:
             if isinstance(input, np.ndarray):
                 batch.append(input, None, None)
                 if len(batch) == self.batch_size:
                     yield batch
-                    batch = ImgInstance()
+                    batch.reset()
             elif isinstance(input, str) and input.split(".")[-1] in ("PDF", "pdf"):
                 file_path = (
                     self._download_from_url(input)
@@ -96,7 +90,7 @@ class ImageBatchSampler(BaseBatchSampler):
                     batch.append(page_img, file_path, page_idx)
                     if len(batch) == self.batch_size:
                         yield batch
-                        batch = ImgInstance()
+                        batch.reset()
             elif isinstance(input, str):
                 file_path = (
                     self._download_from_url(input)
@@ -108,7 +102,7 @@ class ImageBatchSampler(BaseBatchSampler):
                     batch.append(file_path, file_path, None)
                     if len(batch) == self.batch_size:
                         yield batch
-                        batch = ImgInstance()
+                        batch.reset()
             else:
                 logging.warning(
                     f"Not supported input data type! Only `numpy.ndarray` and `str` are supported! So has been ignored: {input}."

+ 6 - 6
paddlex/inference/common/batch_sampler/ts_batch_sampler.py

@@ -21,7 +21,7 @@ import pandas as pd
 from ....utils import logging
 from ....utils.download import download
 from ....utils.cache import CACHE_DIR
-from .base_batch_sampler import BaseBatchSampler
+from .base_batch_sampler import BaseBatchSampler, Batch
 
 
 class TSBatchSampler(BaseBatchSampler):
@@ -83,13 +83,13 @@ class TSBatchSampler(BaseBatchSampler):
         if not isinstance(inputs, list):
             inputs = [inputs]
 
-        batch = []
+        batch = Batch()
         for input in inputs:
             if isinstance(input, pd.DataFrame):
-                batch.append(input)
+                batch.append(input, None)
                 if len(batch) == self.batch_size:
                     yield batch
-                    batch = []
+                    batch.reset()
             elif isinstance(input, str):
                 file_path = (
                     self._download_from_url(input)
@@ -98,10 +98,10 @@ class TSBatchSampler(BaseBatchSampler):
                 )
                 file_list = self._get_files_list(file_path)
                 for file_path in file_list:
-                    batch.append(file_path)
+                    batch.append(file_path, file_path)
                     if len(batch) == self.batch_size:
                         yield batch
-                        batch = []
+                        batch.reset()
             else:
                 logging.warning(
                     f"Not supported input data type! Only `pd.DataFrame` and `str` are supported! So has been ignored: {input}."

+ 2 - 2
paddlex/inference/models/ts_anomaly_detection/predictor.py

@@ -116,7 +116,7 @@ class TSAdPredictor(BasicPredictor):
             dict: A dictionary containing the input path, raw image, class IDs, scores, and label names for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
         """
 
-        batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data)
+        batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data.instances)
         batch_cutoff_ts = self.preprocessors["TSCutOff"](ts_list=batch_raw_ts)
 
         if "TSNormalize" in self.preprocessors:
@@ -140,7 +140,7 @@ class TSAdPredictor(BasicPredictor):
             ori_ts_list=batch_input_ts, pred_list=batch_preds
         )
         return {
-            "input_path": batch_data,
+            "input_path": batch_data.input_paths,
             "input_ts": batch_raw_ts,
             "anomaly": batch_ts_preds,
         }

+ 3 - 3
paddlex/inference/models/ts_classification/predictor.py

@@ -109,7 +109,7 @@ class TSClsPredictor(BasicPredictor):
         Returns:
             Dict[str, Any]: A dictionary containing the paths to the input data, the raw input time series, and the classification results.
         """
-        batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data)
+        batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data.instances)
         batch_raw_ts_ori = copy.deepcopy(batch_raw_ts)
 
         if "TSNormalize" in self.preprocessors:
@@ -127,9 +127,9 @@ class TSClsPredictor(BasicPredictor):
         batch_ts_preds = self.postprocessors["GetCls"](pred_list=batch_preds)
 
         return {
-            "input_path": batch_data,
+            "input_path": batch_data.input_paths,
             "input_ts": batch_raw_ts,
             "input_ts_data": batch_raw_ts_ori,
             "classification": batch_ts_preds,
-            "target_cols": [self.config["info_params"]["target_cols"]]
+            "target_cols": [self.config["info_params"]["target_cols"]],
         }

+ 2 - 2
paddlex/inference/models/ts_forecasting/predictor.py

@@ -122,7 +122,7 @@ class TSFcPredictor(BasicPredictor):
             dict: A dictionary containing the input path, raw image, class IDs, scores, and label names for every instance of the batch. Keys include 'input_path', 'input_img', 'class_ids', 'scores', and 'label_names'.
         """
 
-        batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data)
+        batch_raw_ts = self.preprocessors["ReadTS"](ts_list=batch_data.instances)
         batch_raw_ts_ori = copy.deepcopy(batch_raw_ts)
         batch_cutoff_ts = self.preprocessors["TSCutOff"](ts_list=batch_raw_ts)
 
@@ -152,7 +152,7 @@ class TSFcPredictor(BasicPredictor):
             )
 
         return {
-            "input_path": batch_data,
+            "input_path": batch_data.input_paths,
             "input_ts": batch_raw_ts,
             "cutoff_ts": batch_raw_ts_ori,
             "forecast": batch_ts_preds,