Răsfoiți Sursa

Support passing a directory that contains PDF files as input

gaotingquan 6 luni în urmă
părinte
comite
ebc715c0ab
1 fișier modificat, cu 32 de adăugiri și 25 de ștergeri
  1. 32 25
      paddlex/inference/common/batch_sampler/image_batch_sampler.py

+ 32 - 25
paddlex/inference/common/batch_sampler/image_batch_sampler.py

@@ -40,7 +40,8 @@ class ImgBatch(Batch):
 
 class ImageBatchSampler(BaseBatchSampler):
 
-    SUFFIX = ["jpg", "png", "jpeg", "JPEG", "JPG", "bmp"]
+    IMG_SUFFIX = ["jpg", "png", "jpeg", "bmp"]
+    PDF_SUFFIX = ["pdf"]
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -54,16 +55,17 @@ class ImageBatchSampler(BaseBatchSampler):
         return save_path.as_posix()
 
     def _get_files_list(self, fp):
-        file_list = []
         if fp is None or not os.path.exists(fp):
-            raise Exception(f"Not found any img file in path: {fp}")
+            raise Exception(f"Not found any image files or pdf files in path: {fp}")
 
-        if os.path.isfile(fp) and fp.split(".")[-1] in self.SUFFIX:
-            file_list.append(fp)
-        elif os.path.isdir(fp):
+        file_list = []
+        if os.path.isdir(fp):
             for root, dirs, files in os.walk(fp):
                 for single_file in files:
-                    if single_file.split(".")[-1] in self.SUFFIX:
+                    if (
+                        single_file.split(".")[-1].lower()
+                        in self.IMG_SUFFIX + self.PDF_SUFFIX
+                    ):
                         file_list.append(os.path.join(root, single_file))
         if len(file_list) == 0:
             raise Exception("Not found any file in {}".format(fp))
@@ -81,29 +83,34 @@ class ImageBatchSampler(BaseBatchSampler):
                 if len(batch) == self.batch_size:
                     yield batch
                     batch = ImgBatch()
-            elif isinstance(input, str) and input.split(".")[-1] in ("PDF", "pdf"):
-                file_path = (
-                    self._download_from_url(input)
-                    if input.startswith("http")
-                    else input
-                )
-                for page_idx, page_img in enumerate(self.pdf_reader.read(file_path)):
-                    batch.append(page_img, file_path, page_idx)
-                    if len(batch) == self.batch_size:
-                        yield batch
-                        batch = ImgBatch()
             elif isinstance(input, str):
-                file_path = (
-                    self._download_from_url(input)
-                    if input.startswith("http")
-                    else input
-                )
-                file_list = self._get_files_list(file_path)
-                for file_path in file_list:
+                suffix = input.split(".")[-1].lower()
+                if suffix in self.PDF_SUFFIX:
+                    file_path = (
+                        self._download_from_url(input)
+                        if input.startswith("http")
+                        else input
+                    )
+                    for page_idx, page_img in enumerate(
+                        self.pdf_reader.read(file_path)
+                    ):
+                        batch.append(page_img, file_path, page_idx)
+                        if len(batch) == self.batch_size:
+                            yield batch
+                            batch = ImgBatch()
+                elif suffix in self.IMG_SUFFIX:
+                    file_path = (
+                        self._download_from_url(input)
+                        if input.startswith("http")
+                        else input
+                    )
                     batch.append(file_path, file_path, None)
                     if len(batch) == self.batch_size:
                         yield batch
                         batch = ImgBatch()
+                else:
+                    file_list = self._get_files_list(input)
+                    yield from self.sample(file_list)
             else:
                 logging.warning(
                     f"Not supported input data type! Only `numpy.ndarray` and `str` are supported! So has been ignored: {input}."