Prechádzať zdrojové kódy

fix ocr when bs > 1

gaotingquan 11 mesiacov pred
rodič
commit
a3e828b315

+ 3 - 3
paddlex/inference/components/task_related/text_rec.py

@@ -283,9 +283,9 @@ class CTCLabelDecode(BaseRecLabelDecode):
 
     def apply(self, pred):
         """apply"""
-        preds = np.array(pred[0])
-        preds_idx = preds.argmax(axis=2)
-        preds_prob = preds.max(axis=2)
+        preds = np.array(pred)
+        preds_idx = preds.argmax(axis=-1).squeeze(axis=1)
+        preds_prob = preds.max(axis=-1).squeeze(axis=1)
         text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
         return [{"rec_text": t[0], "rec_score": t[1]} for t in text]