|
|
@@ -40,7 +40,8 @@ class ImgBatch(Batch):
|
|
|
|
|
|
class ImageBatchSampler(BaseBatchSampler):
|
|
|
|
|
|
- SUFFIX = ["jpg", "png", "jpeg", "JPEG", "JPG", "bmp"]
|
|
|
+ IMG_SUFFIX = ["jpg", "png", "jpeg", "bmp"]
|
|
|
+ PDF_SUFFIX = ["pdf"]
|
|
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
|
super().__init__(*args, **kwargs)
|
|
|
@@ -54,16 +55,17 @@ class ImageBatchSampler(BaseBatchSampler):
|
|
|
return save_path.as_posix()
|
|
|
|
|
|
def _get_files_list(self, fp):
|
|
|
- file_list = []
|
|
|
if fp is None or not os.path.exists(fp):
|
|
|
- raise Exception(f"Not found any img file in path: {fp}")
|
|
|
+ raise Exception(f"Not found any image files or pdf files in path: {fp}")
|
|
|
|
|
|
- if os.path.isfile(fp) and fp.split(".")[-1] in self.SUFFIX:
|
|
|
- file_list.append(fp)
|
|
|
- elif os.path.isdir(fp):
|
|
|
+ file_list = []
|
|
|
+ if os.path.isdir(fp):
|
|
|
for root, dirs, files in os.walk(fp):
|
|
|
for single_file in files:
|
|
|
- if single_file.split(".")[-1] in self.SUFFIX:
|
|
|
+ if (
|
|
|
+ single_file.split(".")[-1].lower()
|
|
|
+ in self.IMG_SUFFIX + self.PDF_SUFFIX
|
|
|
+ ):
|
|
|
file_list.append(os.path.join(root, single_file))
|
|
|
if len(file_list) == 0:
|
|
|
raise Exception("Not found any file in {}".format(fp))
|
|
|
@@ -81,29 +83,34 @@ class ImageBatchSampler(BaseBatchSampler):
|
|
|
if len(batch) == self.batch_size:
|
|
|
yield batch
|
|
|
batch = ImgBatch()
|
|
|
- elif isinstance(input, str) and input.split(".")[-1] in ("PDF", "pdf"):
|
|
|
- file_path = (
|
|
|
- self._download_from_url(input)
|
|
|
- if input.startswith("http")
|
|
|
- else input
|
|
|
- )
|
|
|
- for page_idx, page_img in enumerate(self.pdf_reader.read(file_path)):
|
|
|
- batch.append(page_img, file_path, page_idx)
|
|
|
- if len(batch) == self.batch_size:
|
|
|
- yield batch
|
|
|
- batch = ImgBatch()
|
|
|
elif isinstance(input, str):
|
|
|
- file_path = (
|
|
|
- self._download_from_url(input)
|
|
|
- if input.startswith("http")
|
|
|
- else input
|
|
|
- )
|
|
|
- file_list = self._get_files_list(file_path)
|
|
|
- for file_path in file_list:
|
|
|
+ suffix = input.split(".")[-1].lower()
|
|
|
+ if suffix in self.PDF_SUFFIX:
|
|
|
+ file_path = (
|
|
|
+ self._download_from_url(input)
|
|
|
+ if input.startswith("http")
|
|
|
+ else input
|
|
|
+ )
|
|
|
+ for page_idx, page_img in enumerate(
|
|
|
+ self.pdf_reader.read(file_path)
|
|
|
+ ):
|
|
|
+ batch.append(page_img, file_path, page_idx)
|
|
|
+ if len(batch) == self.batch_size:
|
|
|
+ yield batch
|
|
|
+ batch = ImgBatch()
|
|
|
+ elif suffix in self.IMG_SUFFIX:
|
|
|
+ file_path = (
|
|
|
+ self._download_from_url(input)
|
|
|
+ if input.startswith("http")
|
|
|
+ else input
|
|
|
+ )
|
|
|
batch.append(file_path, file_path, None)
|
|
|
if len(batch) == self.batch_size:
|
|
|
yield batch
|
|
|
batch = ImgBatch()
|
|
|
+ else:
|
|
|
+ file_list = self._get_files_list(input)
|
|
|
+ yield from self.sample(file_list)
|
|
|
else:
|
|
|
logging.warning(
|
|
|
f"Not supported input data type! Only `numpy.ndarray` and `str` are supported! So has been ignored: {input}."
|