Forráskód Böngészése

add batch code for text recognition

liuhongen1234567 10 hónapja
szülő
commit
81c9b2ec43

+ 1 - 2
paddlex/inference/models_new/text_recognition/predictor.py

@@ -21,11 +21,10 @@ from ..common import (
     ResizeByShort,
     Normalize,
     ToCHWImage,
-    ToBatch,
     StaticInfer,
 )
 from ..base import BasicPredictor
-from .processors import OCRReisizeNormImg, CTCLabelDecode
+from .processors import OCRReisizeNormImg, CTCLabelDecode, ToBatch
 from .result import TextRecResult
 
 

+ 40 - 0
paddlex/inference/models_new/text_recognition/processors.py

@@ -15,6 +15,7 @@
 
 import os
 import os.path as osp
+from typing import List
 
 import re
 import numpy as np
@@ -184,3 +185,42 @@ class CTCLabelDecode(BaseRecLabelDecode):
         """add_special_char"""
         character_list = ["blank"] + character_list
         return character_list
+
+
+class ToBatch:
+    """A class for batching and padding images to a uniform width."""
+
+    def __pad_imgs(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
+        """Pad images to the maximum width in the batch.
+
+        Args:
+            imgs (list of np.ndarrays): List of images to pad.
+
+        Returns:
+            list of np.ndarrays: List of padded images.
+        """
+        max_width = max(img.shape[2] for img in imgs)
+        padded_imgs = []
+        for img in imgs:
+            _, height, width = img.shape
+            pad_width = max_width - width
+            padded_img = np.pad(
+                img,
+                ((0, 0), (0, 0), (0, pad_width)),
+                mode="constant",
+                constant_values=0,
+            )
+            padded_imgs.append(padded_img)
+        return padded_imgs
+
+    def __call__(self, imgs: List[np.ndarray]) -> List[np.ndarray]:
+        """Call method to pad images and stack them into a batch.
+
+        Args:
+            imgs (list of np.ndarrays): List of images to process.
+
+        Returns:
+            list of np.ndarrays: List containing a stacked tensor of the padded images.
+        """
+        imgs = self.__pad_imgs(imgs)
+        return [np.stack(imgs, axis=0).astype(dtype=np.float32, copy=False)]