|
|
@@ -8,17 +8,25 @@ from rapid_table import RapidTable
|
|
|
class RapidTableModel(object):
|
|
|
def __init__(self, ocr_engine):
|
|
|
self.table_model = RapidTable()
|
|
|
- if ocr_engine is None:
|
|
|
- self.ocr_model_name = "RapidOCR"
|
|
|
- if torch.cuda.is_available():
|
|
|
- from rapidocr_paddle import RapidOCR
|
|
|
- self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
|
|
|
- else:
|
|
|
- from rapidocr_onnxruntime import RapidOCR
|
|
|
- self.ocr_engine = RapidOCR()
|
|
|
+ # if ocr_engine is None:
|
|
|
+ # self.ocr_model_name = "RapidOCR"
|
|
|
+ # if torch.cuda.is_available():
|
|
|
+ # from rapidocr_paddle import RapidOCR
|
|
|
+ # self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
|
|
|
+ # else:
|
|
|
+ # from rapidocr_onnxruntime import RapidOCR
|
|
|
+ # self.ocr_engine = RapidOCR()
|
|
|
+ # else:
|
|
|
+ # self.ocr_model_name = "PaddleOCR"
|
|
|
+ # self.ocr_engine = ocr_engine
|
|
|
+
|
|
|
+ self.ocr_model_name = "RapidOCR"
|
|
|
+ if torch.cuda.is_available():
|
|
|
+ from rapidocr_paddle import RapidOCR
|
|
|
+ self.ocr_engine = RapidOCR(det_use_cuda=True, cls_use_cuda=True, rec_use_cuda=True)
|
|
|
else:
|
|
|
- self.ocr_model_name = "PaddleOCR"
|
|
|
- self.ocr_engine = ocr_engine
|
|
|
+ from rapidocr_onnxruntime import RapidOCR
|
|
|
+ self.ocr_engine = RapidOCR()
|
|
|
|
|
|
def predict(self, image):
|
|
|
|