|
|
@@ -20,28 +20,22 @@ from ....utils import logging
|
|
|
from ....utils.download import download
|
|
|
from ....utils.cache import CACHE_DIR
|
|
|
from ...utils.io import PDFReader
|
|
|
-from .base_batch_sampler import BaseBatchSampler
|
|
|
+from .base_batch_sampler import BaseBatchSampler, Batch
|
|
|
|
|
|
|
|
|
-class ImgInstance:
|
|
|
+class ImgBatch(Batch):
|
|
|
def __init__(self):
|
|
|
- self.instances = []
|
|
|
- self.input_paths = []
|
|
|
+ super().__init__()
|
|
|
self.page_indexes = []
|
|
|
|
|
|
def append(self, instance, input_path, page_index):
|
|
|
- self.instances.append(instance)
|
|
|
- self.input_paths.append(input_path)
|
|
|
+ super().append(instance, input_path)
|
|
|
self.page_indexes.append(page_index)
|
|
|
|
|
|
def reset(self):
|
|
|
- self.instances = []
|
|
|
- self.input_paths = []
|
|
|
+ super().reset()
|
|
|
self.page_indexes = []
|
|
|
|
|
|
- def __len__(self):
|
|
|
- return len(self.instances)
|
|
|
-
|
|
|
|
|
|
class ImageBatchSampler(BaseBatchSampler):
|
|
|
|
|
|
@@ -79,13 +73,13 @@ class ImageBatchSampler(BaseBatchSampler):
|
|
|
if not isinstance(inputs, list):
|
|
|
inputs = [inputs]
|
|
|
|
|
|
- batch = ImgInstance()
|
|
|
+ batch = ImgBatch()
|
|
|
for input in inputs:
|
|
|
if isinstance(input, np.ndarray):
|
|
|
batch.append(input, None, None)
|
|
|
if len(batch) == self.batch_size:
|
|
|
yield batch
|
|
|
- batch = ImgInstance()
|
|
|
+ batch.reset()
|
|
|
elif isinstance(input, str) and input.split(".")[-1] in ("PDF", "pdf"):
|
|
|
file_path = (
|
|
|
self._download_from_url(input)
|
|
|
@@ -96,7 +90,7 @@ class ImageBatchSampler(BaseBatchSampler):
|
|
|
batch.append(page_img, file_path, page_idx)
|
|
|
if len(batch) == self.batch_size:
|
|
|
yield batch
|
|
|
- batch = ImgInstance()
|
|
|
+ batch.reset()
|
|
|
elif isinstance(input, str):
|
|
|
file_path = (
|
|
|
self._download_from_url(input)
|
|
|
@@ -108,7 +102,7 @@ class ImageBatchSampler(BaseBatchSampler):
|
|
|
batch.append(file_path, file_path, None)
|
|
|
if len(batch) == self.batch_size:
|
|
|
yield batch
|
|
|
- batch = ImgInstance()
|
|
|
+ batch.reset()
|
|
|
else:
|
|
|
logging.warning(
|
|
|
f"Not supported input data type! Only `numpy.ndarray` and `str` are supported! So has been ignored: {input}."
|