|
|
@@ -36,7 +36,7 @@ from ...results import DetResult
|
|
|
class LayoutParsingPipeline(BasePipeline):
|
|
|
"""Layout Parsing Pipeline"""
|
|
|
|
|
|
- entities = "layout_parsing"
|
|
|
+ entities = ["layout_parsing", "seal_recognition", "table_recognition"]
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
@@ -66,6 +66,42 @@ class LayoutParsingPipeline(BasePipeline):
|
|
|
|
|
|
self._crop_by_boxes = CropByBoxes()
|
|
|
|
|
|
def set_used_models_flag(self, config: Dict) -> None:
    """Set the flags for which models to use based on the configuration.

    The pipeline name is stored on the instance and decides how the
    recognition flags are resolved:

    * ``"layout_parsing"``: ``use_general_ocr``, ``use_seal_recognition``
      and ``use_table_recognition`` are read from ``config`` and default
      to ``False`` when the key is absent.
    * ``"seal_recognition"``: only ``use_seal_recognition`` is enabled.
    * ``"table_recognition"``: only ``use_table_recognition`` is enabled.

    ``use_doc_preprocessor`` is read from ``config`` for every pipeline
    name and defaults to ``False``. Any other pipeline name leaves all
    three recognition flags ``False``.

    Args:
        config (Dict): A dictionary containing configuration settings.
            Must contain the key ``"pipeline_name"``.

    Returns:
        None

    Raises:
        KeyError: If ``config`` has no ``"pipeline_name"`` entry.
    """
    pipeline_name = config["pipeline_name"]
    self.pipeline_name = pipeline_name

    # Start with every optional recognition model switched off; the
    # branch below turns on only what the selected pipeline needs.
    self.use_doc_preprocessor = config.get("use_doc_preprocessor", False)
    self.use_general_ocr = False
    self.use_seal_recognition = False
    self.use_table_recognition = False

    if pipeline_name == "layout_parsing":
        self.use_general_ocr = config.get("use_general_ocr", False)
        self.use_seal_recognition = config.get("use_seal_recognition", False)
        self.use_table_recognition = config.get("use_table_recognition", False)
    elif pipeline_name == "seal_recognition":
        # The config switches are ignored here: seal recognition is forced on.
        self.use_seal_recognition = True
    elif pipeline_name == "table_recognition":
        # The config switches are ignored here: table recognition is forced on.
        self.use_table_recognition = True
|
|
|
+
|
|
|
def inintial_predictor(self, config: Dict) -> None:
|
|
|
"""Initializes the predictor based on the provided configuration.
|
|
|
|
|
|
@@ -76,36 +112,25 @@ class LayoutParsingPipeline(BasePipeline):
|
|
|
None
|
|
|
"""
|
|
|
|
|
|
+ self.set_used_models_flag(config)
|
|
|
+
|
|
|
layout_det_config = config["SubModules"]["LayoutDetection"]
|
|
|
self.layout_det_model = self.create_model(layout_det_config)
|
|
|
|
|
|
- self.use_doc_preprocessor = False
|
|
|
- if "use_doc_preprocessor" in config:
|
|
|
- self.use_doc_preprocessor = config["use_doc_preprocessor"]
|
|
|
-
|
|
|
if self.use_doc_preprocessor:
|
|
|
doc_preprocessor_config = config["SubPipelines"]["DocPreprocessor"]
|
|
|
self.doc_preprocessor_pipeline = self.create_pipeline(
|
|
|
doc_preprocessor_config
|
|
|
)
|
|
|
|
|
|
- self.use_general_ocr = False
|
|
|
- if "use_general_ocr" in config:
|
|
|
- self.use_general_ocr = config["use_general_ocr"]
|
|
|
if self.use_general_ocr:
|
|
|
general_ocr_config = config["SubPipelines"]["GeneralOCR"]
|
|
|
self.general_ocr_pipeline = self.create_pipeline(general_ocr_config)
|
|
|
|
|
|
- self.use_seal_recognition = False
|
|
|
- if "use_seal_recognition" in config:
|
|
|
- self.use_seal_recognition = config["use_seal_recognition"]
|
|
|
if self.use_seal_recognition:
|
|
|
seal_ocr_config = config["SubPipelines"]["SealOCR"]
|
|
|
self.seal_ocr_pipeline = self.create_pipeline(seal_ocr_config)
|
|
|
|
|
|
- self.use_table_recognition = False
|
|
|
- if "use_table_recognition" in config:
|
|
|
- self.use_table_recognition = config["use_table_recognition"]
|
|
|
if self.use_table_recognition:
|
|
|
table_structure_config = config["SubModules"]["TableStructureRecognition"]
|
|
|
self.table_structure_model = self.create_model(table_structure_config)
|
|
|
@@ -171,6 +196,24 @@ class LayoutParsingPipeline(BasePipeline):
|
|
|
|
|
|
return True
|
|
|
|
|
|
def convert_input_params(self, input_params: Dict) -> None:
    """Convert input parameters based on the pipeline name.

    ``input_params`` is modified in place. For the single-purpose
    pipelines, the switches of the other recognition branches are forced
    to ``False``:

    * ``"seal_recognition"``: ``use_general_ocr`` and
      ``use_table_recognition`` are set to ``False``.
    * ``"table_recognition"``: ``use_general_ocr`` and
      ``use_seal_recognition`` are set to ``False``.

    For any other pipeline name the dictionary is left untouched.

    Args:
        input_params (Dict): The input parameters dictionary.

    Returns:
        None
    """
    # Flags that must be switched off for each single-purpose pipeline.
    disabled_flags = {
        "seal_recognition": ("use_general_ocr", "use_table_recognition"),
        "table_recognition": ("use_general_ocr", "use_seal_recognition"),
    }
    for flag in disabled_flags.get(self.pipeline_name, ()):
        input_params[flag] = False
|
|
|
+
|
|
|
def predict(
|
|
|
self,
|
|
|
input: str | list[str] | np.ndarray | list[np.ndarray],
|
|
|
@@ -211,6 +254,8 @@ class LayoutParsingPipeline(BasePipeline):
|
|
|
"use_table_recognition": use_table_recognition,
|
|
|
}
|
|
|
|
|
|
+ self.convert_input_params(input_params)
|
|
|
+
|
|
|
if use_doc_orientation_classify or use_doc_unwarping:
|
|
|
input_params["use_doc_preprocessor"] = True
|
|
|
else:
|