|
|
@@ -20,6 +20,7 @@ import numpy as np
|
|
|
from ....utils import logging
|
|
|
from ....utils.download import download
|
|
|
from ....utils.cache import CACHE_DIR
|
|
|
+from ...utils.io import PDFReader
|
|
|
from .base_batch_sampler import BaseBatchSampler
|
|
|
|
|
|
|
|
|
@@ -27,6 +28,10 @@ class ImageBatchSampler(BaseBatchSampler):
|
|
|
|
|
|
SUFFIX = ["jpg", "png", "jpeg", "JPEG", "JPG", "bmp"]
|
|
|
|
|
|
+ def __init__(self, *args, **kwargs):
|
|
|
+ # Forward every positional/keyword argument unchanged to the parent class constructor.
+ super().__init__(*args, **kwargs)
|
|
|
+ # Reader instance whose read() is iterated page by page when an input path ends in "pdf"/"PDF".
+ self.pdf_reader = PDFReader()
|
|
|
+
|
|
|
# XXX: auto download for url
|
|
|
def _download_from_url(self, in_path):
|
|
|
file_name = Path(in_path).name
|
|
|
@@ -62,6 +67,17 @@ class ImageBatchSampler(BaseBatchSampler):
|
|
|
if len(batch) == self.batch_size:
|
|
|
yield batch
|
|
|
batch = []
|
|
|
+ elif isinstance(input, str) and input.split(".")[-1] in ("PDF", "pdf"):
|
|
|
+ file_path = (
|
|
|
+ self._download_from_url(input)
|
|
|
+ if input.startswith("http")
|
|
|
+ else input
|
|
|
+ )
|
|
|
+ for page_img in self.pdf_reader.read(file_path):
|
|
|
+ batch.append(page_img)
|
|
|
+ if len(batch) == self.batch_size:
|
|
|
+ yield batch
|
|
|
+ batch = []
|
|
|
elif isinstance(input, str):
|
|
|
file_path = (
|
|
|
self._download_from_url(input)
|