Pārlūkot izejas kodu

update ts_batch_sampler for benchmark (#3762)

zhang-prog 7 mēneši atpakaļ
vecāks
revīzija
3408bd9813

+ 2 - 2
paddlex/inference/common/batch_sampler/ts_batch_sampler.py

@@ -88,7 +88,7 @@ class TSBatchSampler(BaseBatchSampler):
                 batch.append(input, None)
                 if len(batch) == self.batch_size:
                     yield batch
-                    batch.reset()
+                    batch = Batch()
             elif isinstance(input, str):
                 file_path = (
                     self._download_from_url(input)
@@ -100,7 +100,7 @@ class TSBatchSampler(BaseBatchSampler):
                     batch.append(file_path, file_path)
                     if len(batch) == self.batch_size:
                         yield batch
-                        batch.reset()
+                        batch = Batch()
             else:
                 logging.warning(
                     f"Not supported input data type! Only `pd.DataFrame` and `str` are supported! So has been ignored: {input}."