|
@@ -3,6 +3,7 @@ import os
|
|
|
import cv2
|
|
import cv2
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
import onnxruntime
|
|
import onnxruntime
|
|
|
|
|
+from loguru import logger
|
|
|
|
|
|
|
|
from mineru.backend.pipeline.model_list import AtomicModel
|
|
from mineru.backend.pipeline.model_list import AtomicModel
|
|
|
from mineru.utils.enum_class import ModelPath
|
|
from mineru.utils.enum_class import ModelPath
|
|
@@ -63,6 +64,10 @@ class PaddleTableClsModel:
|
|
|
|
|
|
|
|
def predict(self, img):
|
|
def predict(self, img):
|
|
|
x = self.preprocess(img)
|
|
x = self.preprocess(img)
|
|
|
- (result,) = self.sess.run(None, {"x": x})
|
|
|
|
|
- label = self.labels[np.argmax(result)]
|
|
|
|
|
- return label
|
|
|
|
|
|
|
+ result = self.sess.run(None, {"x": x})
|
|
|
|
|
+ idx = np.argmax(result)
|
|
|
|
|
+ conf = float(np.max(result))
|
|
|
|
|
+ # logger.debug(f"Table classification result: {self.labels[idx]} with confidence {conf:.4f}")
|
|
|
|
|
+ if idx == 0 and conf < 0.8:
|
|
|
|
|
+ idx = 1
|
|
|
|
|
+ return self.labels[idx]
|