|
|
@@ -2,12 +2,25 @@ import cv2
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from loguru import logger
|
|
|
-from rapid_table import RapidTable
|
|
|
+from rapid_table import RapidTable, RapidTableInput
|
|
|
+from rapid_table.main import ModelType
|
|
|
|
|
|
|
|
|
class RapidTableModel(object):
|
|
|
- def __init__(self, ocr_engine):
|
|
|
- self.table_model = RapidTable()
|
|
|
+ def __init__(self, ocr_engine, table_sub_model_name):
|
|
|
+ sub_model_list = [model.value for model in ModelType]
|
|
|
+ if table_sub_model_name is None:
|
|
|
+ input_args = RapidTableInput()
|
|
|
+ elif table_sub_model_name in sub_model_list:
|
|
|
+ if torch.cuda.is_available() and table_sub_model_name == "unitable":
|
|
|
+ input_args = RapidTableInput(model_type=table_sub_model_name, use_cuda=True)
|
|
|
+ else:
|
|
|
+ input_args = RapidTableInput(model_type=table_sub_model_name)
|
|
|
+ else:
|
|
|
+ raise ValueError(f"Invalid table_sub_model_name: {table_sub_model_name}. It must be one of {sub_model_list}")
|
|
|
+
|
|
|
+ self.table_model = RapidTable(input_args)
|
|
|
+
|
|
|
# if ocr_engine is None:
|
|
|
# self.ocr_model_name = "RapidOCR"
|
|
|
# if torch.cuda.is_available():
|
|
|
@@ -45,7 +58,11 @@ class RapidTableModel(object):
|
|
|
ocr_result = None
|
|
|
|
|
|
if ocr_result:
|
|
|
- html_code, table_cell_bboxes, elapse = self.table_model(np.asarray(image), ocr_result)
|
|
|
- return html_code, table_cell_bboxes, elapse
|
|
|
+ table_results = self.table_model(np.asarray(image), ocr_result)
|
|
|
+ html_code = table_results.pred_html
|
|
|
+ table_cell_bboxes = table_results.cell_bboxes
|
|
|
+ logic_points = table_results.logic_points
|
|
|
+ elapse = table_results.elapse
|
|
|
+ return html_code, table_cell_bboxes, logic_points, elapse
|
|
|
else:
|
|
|
- return None, None, None
|
|
|
+ return None, None, None, None
|