Browse Source

ImageBatchSampler: support sampling from PDF files (#2717)

Tingquan Gao 10 months ago
parent
commit
50b61968f4

+ 16 - 0
paddlex/inference/common/batch_sampler/image_batch_sampler.py

@@ -20,6 +20,7 @@ import numpy as np
 from ....utils import logging
 from ....utils.download import download
 from ....utils.cache import CACHE_DIR
+from ...utils.io import PDFReader
 from .base_batch_sampler import BaseBatchSampler
 
 
@@ -27,6 +28,10 @@ class ImageBatchSampler(BaseBatchSampler):
 
     SUFFIX = ["jpg", "png", "jpeg", "JPEG", "JPG", "bmp"]
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.pdf_reader = PDFReader()
+
     # XXX: auto download for url
     def _download_from_url(self, in_path):
         file_name = Path(in_path).name
@@ -62,6 +67,17 @@ class ImageBatchSampler(BaseBatchSampler):
                 if len(batch) == self.batch_size:
                     yield batch
                     batch = []
+            elif isinstance(input, str) and input.split(".")[-1] in ("PDF", "pdf"):
+                file_path = (
+                    self._download_from_url(input)
+                    if input.startswith("http")
+                    else input
+                )
+                for page_img in self.pdf_reader.read(file_path):
+                    batch.append(page_img)
+                    if len(batch) == self.batch_size:
+                        yield batch
+                        batch = []
             elif isinstance(input, str):
                 file_path = (
                     self._download_from_url(input)

+ 2 - 4
paddlex/inference/utils/io/readers.py

@@ -90,7 +90,7 @@ class PDFReader(_BaseReader):
         super().__init__(backend, **bk_args)
 
     def read(self, in_path):
-        return self._backend.read_file(str(in_path))
+        yield from self._backend.read_file(str(in_path))
 
     def _init_backend(self, bk_type, bk_args):
         return PDFReaderBackend(**bk_args)
@@ -233,15 +233,13 @@ class PDFReaderBackend(_BaseReaderBackend):
         self.mat = fitz.Matrix(zoom_x, zoom_y).prerotate(rotate)
 
     def read_file(self, in_path):
-        images = []
         for page in fitz.open(in_path):
             pix = page.get_pixmap(matrix=self.mat, alpha=False)
             getpngdata = pix.tobytes(output="png")
             # decode as np.uint8
             image_array = np.frombuffer(getpngdata, dtype=np.uint8)
             img_cv = cv2.imdecode(image_array, cv2.IMREAD_ANYCOLOR)
-            images.append(img_cv)
-        return images
+            yield img_cv
 
 
 class _VideoReaderBackend(_BaseReaderBackend):