|
|
@@ -13,7 +13,7 @@
|
|
|
# limitations under the License.
|
|
|
|
|
|
import os, sys
|
|
|
-from typing import Any, Dict, Optional
|
|
|
+from typing import Any, Dict, Optional, Union, List, Tuple
|
|
|
import numpy as np
|
|
|
import cv2
|
|
|
from ..base import BasePipeline
|
|
|
@@ -96,19 +96,27 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
{"model_config_error": "config error for wireless_table_structure_model!"},
|
|
|
)
|
|
|
self.wireless_table_rec_model = self.create_model(wireless_table_rec_config)
|
|
|
-
|
|
|
+
|
|
|
wired_table_cells_det_config = config.get("SubModules", {}).get(
|
|
|
"WiredTableCellsDetection",
|
|
|
- {"model_config_error": "config error for wired_table_cells_detection_model!"},
|
|
|
+ {
|
|
|
+ "model_config_error": "config error for wired_table_cells_detection_model!"
|
|
|
+ },
|
|
|
+ )
|
|
|
+ self.wired_table_cells_detection_model = self.create_model(
|
|
|
+ wired_table_cells_det_config
|
|
|
)
|
|
|
- self.wired_table_cells_detection_model = self.create_model(wired_table_cells_det_config)
|
|
|
|
|
|
wireless_table_cells_det_config = config.get("SubModules", {}).get(
|
|
|
"WirelessTableCellsDetection",
|
|
|
- {"model_config_error": "config error for wireless_table_cells_detection_model!"},
|
|
|
+ {
|
|
|
+ "model_config_error": "config error for wireless_table_cells_detection_model!"
|
|
|
+ },
|
|
|
+ )
|
|
|
+ self.wireless_table_cells_detection_model = self.create_model(
|
|
|
+ wireless_table_cells_det_config
|
|
|
)
|
|
|
- self.wireless_table_cells_detection_model = self.create_model(wireless_table_cells_det_config)
|
|
|
-
|
|
|
+
|
|
|
self.use_ocr_model = config.get("use_ocr_model", True)
|
|
|
if self.use_ocr_model:
|
|
|
general_ocr_config = config.get("SubPipelines", {}).get(
|
|
|
@@ -218,7 +226,7 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
|
|
|
def predict_doc_preprocessor_res(
|
|
|
self, image_array: np.ndarray, input_params: dict
|
|
|
- ) -> tuple[DocPreprocessorResult, np.ndarray]:
|
|
|
+ ) -> Tuple[DocPreprocessorResult, np.ndarray]:
|
|
|
"""
|
|
|
Preprocess the document image based on input parameters.
|
|
|
|
|
|
@@ -248,15 +256,15 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
|
|
|
def extract_results(self, pred, task):
|
|
|
if task == "cls":
|
|
|
- return pred['label_names'][np.argmax(pred['scores'])]
|
|
|
+ return pred["label_names"][np.argmax(pred["scores"])]
|
|
|
elif task == "det":
|
|
|
threshold = 0.0
|
|
|
result = []
|
|
|
- if 'boxes' in pred and isinstance(pred['boxes'], list):
|
|
|
- for box in pred['boxes']:
|
|
|
- if isinstance(box, dict) and 'score' in box and 'coordinate' in box:
|
|
|
- score = box['score']
|
|
|
- coordinate = box['coordinate']
|
|
|
+ if "boxes" in pred and isinstance(pred["boxes"], list):
|
|
|
+ for box in pred["boxes"]:
|
|
|
+ if isinstance(box, dict) and "score" in box and "coordinate" in box:
|
|
|
+ score = box["score"]
|
|
|
+ coordinate = box["coordinate"]
|
|
|
if isinstance(score, float) and score > threshold:
|
|
|
result.append(coordinate)
|
|
|
return result
|
|
|
@@ -291,8 +299,12 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
table_cells_pred = next(self.wired_table_cells_detection_model(image_array))
|
|
|
elif table_cls_result == "wireless_table":
|
|
|
table_structure_pred = next(self.wireless_table_rec_model(image_array))
|
|
|
- table_cells_pred = next(self.wireless_table_cells_detection_model(image_array))
|
|
|
- table_structure_result = self.extract_results(table_structure_pred, "table_stru")
|
|
|
+ table_cells_pred = next(
|
|
|
+ self.wireless_table_cells_detection_model(image_array)
|
|
|
+ )
|
|
|
+ table_structure_result = self.extract_results(
|
|
|
+ table_structure_pred, "table_stru"
|
|
|
+ )
|
|
|
table_cells_result = self.extract_results(table_cells_pred, "det")
|
|
|
single_table_recognition_res = get_table_recognition_res(
|
|
|
table_box, table_structure_result, table_cells_result, overall_ocr_res
|
|
|
@@ -310,7 +322,7 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
|
|
|
def predict(
|
|
|
self,
|
|
|
- input: str | list[str] | np.ndarray | list[np.ndarray],
|
|
|
+ input: Union[str, List[str], np.ndarray, List[np.ndarray]],
|
|
|
use_doc_orientation_classify: Optional[bool] = None,
|
|
|
use_doc_unwarping: Optional[bool] = None,
|
|
|
use_layout_detection: Optional[bool] = None,
|
|
|
@@ -329,7 +341,7 @@ class TableRecognitionPipelineV2(BasePipeline):
|
|
|
This function predicts the layout parsing result for the given input.
|
|
|
|
|
|
Args:
|
|
|
- input (str | list[str] | np.ndarray | list[np.ndarray]): The input image(s) of pdf(s) to be processed.
|
|
|
+            input (Union[str, List[str], np.ndarray, List[np.ndarray]]): The input image(s) of pdf(s) to be processed.
|
|
|
use_layout_detection (bool): Whether to use layout detection.
|
|
|
use_doc_orientation_classify (bool): Whether to use document orientation classification.
|
|
|
use_doc_unwarping (bool): Whether to use document unwarping.
|