Browse Source

fix ocr when bs > 1

gaotingquan 11 tháng trước cách đây
mục cha
commit
a3e828b315
1 tập tin đã thay đổi với 3 bổ sung3 xóa
  1. 3 3
      paddlex/inference/components/task_related/text_rec.py

+ 3 - 3
paddlex/inference/components/task_related/text_rec.py

@@ -283,9 +283,9 @@ class CTCLabelDecode(BaseRecLabelDecode):
 
     def apply(self, pred):
         """apply"""
-        preds = np.array(pred[0])
-        preds_idx = preds.argmax(axis=2)
-        preds_prob = preds.max(axis=2)
+        preds = np.array(pred)
+        preds_idx = preds.argmax(axis=-1).squeeze(axis=1)
+        preds_prob = preds.max(axis=-1).squeeze(axis=1)
         text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
         return [{"rec_text": t[0], "rec_score": t[1]} for t in text]