
check batchsize 1 for latex_rec (#3516)

liuhongen1234567 committed 8 months ago
Commit 86bcfd7043
1 changed file with 37 additions and 3 deletions

+ 37 - 3
paddlex/inference/models/formula_recognition/predictor.py

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 from ....utils import logging
 from ....utils.func_register import FuncRegister
 from ....modules.formula_recognition.model_list import MODELS
@@ -38,6 +39,7 @@ from .result import FormulaRecResult
 
 
 class FormulaRecPredictor(BasicPredictor):
+    """FormulaRecPredictor that inherits from BasicPredictor."""
 
     entities = MODELS
 
@@ -45,7 +47,23 @@ class FormulaRecPredictor(BasicPredictor):
     register = FuncRegister(_FUNC_MAP)
 
     def __init__(self, *args, **kwargs):
+        """Initializes FormulaRecPredictor.
+        Args:
+            *args: Arbitrary positional arguments passed to the superclass.
+            **kwargs: Arbitrary keyword arguments passed to the superclass.
+        """
         super().__init__(*args, **kwargs)
+
+        self.model_names_only_supports_batchsize_of_one = {
+            "LaTeX_OCR_rec",
+        }
+        if self.model_name in self.model_names_only_supports_batchsize_of_one:
+            logging.warning(
+                f"Formula recognition models \"{', '.join(self.model_names_only_supports_batchsize_of_one)}\" only support prediction with a batch_size of one. "
+                "If you set the predictor to a batch_size larger than one, no error will occur, but inference will actually run with a batch_size of one, "
+                f"which leads to slower inference. You are now using {self.config['Global']['model_name']}."
+            )
+
         self.pre_tfs, self.infer, self.post_op = self._build()
 
     def _build_batch_sampler(self):
@@ -91,9 +109,25 @@ class FormulaRecPredictor(BasicPredictor):
             batch_imgs = self.pre_tfs["UniMERNetTestTransform"](imgs=batch_imgs)
             batch_imgs = self.pre_tfs["LatexImageFormat"](imgs=batch_imgs)
 
-        x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
-        batch_preds = self.infer(x=x)
-        batch_preds = [p.reshape([-1]) for p in batch_preds[0]]
+        if self.model_name in self.model_names_only_supports_batchsize_of_one:
+            batch_preds = []
+            max_length = 0
+            for batch_img in batch_imgs:
+                batch_pred_ = self.infer([batch_img])[0].reshape([-1])
+                max_length = max(max_length, batch_pred_.shape[0])
+                batch_preds.append(batch_pred_)
+            for i in range(len(batch_preds)):
+                batch_preds[i] = np.pad(
+                    batch_preds[i],
+                    (0, max_length - batch_preds[i].shape[0]),
+                    mode="constant",
+                    constant_values=0,
+                )
+        else:
+            x = self.pre_tfs["ToBatch"](imgs=batch_imgs)
+            batch_preds = self.infer(x=x)
+            batch_preds = [p.reshape([-1]) for p in batch_preds[0]]
+
         rec_formula = self.post_op(batch_preds)
         return {
             "input_path": batch_data.input_paths,