|
@@ -12,12 +12,11 @@
|
|
|
# See the License for the specific language governing permissions and
|
|
# See the License for the specific language governing permissions and
|
|
|
# limitations under the License.
|
|
# limitations under the License.
|
|
|
|
|
|
|
|
-from typing import Any, Dict, Optional, Union, List
|
|
|
|
|
-import os, sys
|
|
|
|
|
|
|
+from email.mime import image
|
|
|
|
|
+from typing import Any, Dict, Optional, Union, List, Tuple
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
-import cv2
|
|
|
|
|
from ..base import BasePipeline
|
|
from ..base import BasePipeline
|
|
|
-from .utils import get_sub_regions_ocr_res
|
|
|
|
|
|
|
+from .utils import get_sub_regions_ocr_res, sorted_layout_boxes
|
|
|
from ..components import convert_points_to_boxes
|
|
from ..components import convert_points_to_boxes
|
|
|
from .result import LayoutParsingResult
|
|
from .result import LayoutParsingResult
|
|
|
from ....utils import logging
|
|
from ....utils import logging
|
|
@@ -91,6 +90,22 @@ class LayoutParsingPipeline(BasePipeline):
|
|
|
{"model_config_error": "config error for layout_det_model!"},
|
|
{"model_config_error": "config error for layout_det_model!"},
|
|
|
)
|
|
)
|
|
|
self.layout_det_model = self.create_model(layout_det_config)
|
|
self.layout_det_model = self.create_model(layout_det_config)
|
|
|
|
|
+ layout_kwargs = {}
|
|
|
|
|
+ if (threshold := layout_det_config.get("threshold", None)) is not None:
|
|
|
|
|
+ layout_kwargs["threshold"] = threshold
|
|
|
|
|
+ if (layout_nms := layout_det_config.get("layout_nms", None)) is not None:
|
|
|
|
|
+ layout_kwargs["layout_nms"] = layout_nms
|
|
|
|
|
+ if (
|
|
|
|
|
+ layout_unclip_ratio := layout_det_config.get("layout_unclip_ratio", None)
|
|
|
|
|
+ ) is not None:
|
|
|
|
|
+ layout_kwargs["layout_unclip_ratio"] = layout_unclip_ratio
|
|
|
|
|
+ if (
|
|
|
|
|
+ layout_merge_bboxes_mode := layout_det_config.get(
|
|
|
|
|
+ "layout_merge_bboxes_mode", None
|
|
|
|
|
+ )
|
|
|
|
|
+ ) is not None:
|
|
|
|
|
+ layout_kwargs["layout_merge_bboxes_mode"] = layout_merge_bboxes_mode
|
|
|
|
|
+ self.layout_det_model = self.create_model(layout_det_config, **layout_kwargs)
|
|
|
|
|
|
|
|
if self.use_general_ocr or self.use_table_recognition:
|
|
if self.use_general_ocr or self.use_table_recognition:
|
|
|
general_ocr_config = config.get("SubPipelines", {}).get(
|
|
general_ocr_config = config.get("SubPipelines", {}).get(
|
|
@@ -152,7 +167,127 @@ class LayoutParsingPipeline(BasePipeline):
|
|
|
if box_info["label"].lower() in ["formula", "table", "seal"]:
|
|
if box_info["label"].lower() in ["formula", "table", "seal"]:
|
|
|
object_boxes.append(box_info["coordinate"])
|
|
object_boxes.append(box_info["coordinate"])
|
|
|
object_boxes = np.array(object_boxes)
|
|
object_boxes = np.array(object_boxes)
|
|
|
- return get_sub_regions_ocr_res(overall_ocr_res, object_boxes, flag_within=False)
|
|
|
|
|
|
|
+ sub_regions_ocr_res = get_sub_regions_ocr_res(
|
|
|
|
|
+ overall_ocr_res, object_boxes, flag_within=False
|
|
|
|
|
+ )
|
|
|
|
|
+ return sub_regions_ocr_res
|
|
|
|
|
+
|
|
|
|
|
+ def get_layout_parsing_res(
|
|
|
|
|
+ self,
|
|
|
|
|
+ image: list,
|
|
|
|
|
+ layout_det_res: DetResult,
|
|
|
|
|
+ overall_ocr_res: OCRResult,
|
|
|
|
|
+ table_res_list: list,
|
|
|
|
|
+ seal_res_list: list,
|
|
|
|
|
+ formula_res_list: list,
|
|
|
|
|
+ text_det_limit_side_len: Optional[int] = None,
|
|
|
|
|
+ text_det_limit_type: Optional[str] = None,
|
|
|
|
|
+ text_det_thresh: Optional[float] = None,
|
|
|
|
|
+ text_det_box_thresh: Optional[float] = None,
|
|
|
|
|
+ text_det_unclip_ratio: Optional[float] = None,
|
|
|
|
|
+ text_rec_score_thresh: Optional[float] = None,
|
|
|
|
|
+ ) -> list:
|
|
|
|
|
+ """
|
|
|
|
|
+ Retrieves the layout parsing result based on the layout detection result, OCR result, and other recognition results.
|
|
|
|
|
+ Args:
|
|
|
|
|
+ image (list): The input image.
|
|
|
|
|
+ layout_det_res (DetResult): The detection result containing the layout information of the document.
|
|
|
|
|
+ overall_ocr_res (OCRResult): The overall OCR result containing text information.
|
|
|
|
|
+ table_res_list (list): A list of table recognition results.
|
|
|
|
|
+ seal_res_list (list): A list of seal recognition results.
|
|
|
|
|
+ formula_res_list (list): A list of formula recognition results.
|
|
|
|
|
+ text_det_limit_side_len (Optional[int], optional): The maximum side length of the text detection region. Defaults to None.
|
|
|
|
|
+ text_det_limit_type (Optional[str], optional): The type of limit for the text detection region. Defaults to None.
|
|
|
|
|
+ text_det_thresh (Optional[float], optional): The confidence threshold for text detection. Defaults to None.
|
|
|
|
|
+ text_det_box_thresh (Optional[float], optional): The confidence threshold for text detection bounding boxes. Defaults to None
|
|
|
|
|
+ text_det_unclip_ratio (Optional[float], optional): The unclip ratio for text detection. Defaults to None.
|
|
|
|
|
+ text_rec_score_thresh (Optional[float], optional): The score threshold for text recognition. Defaults to None.
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ list: A list of dictionaries representing the layout parsing result.
|
|
|
|
|
+ """
|
|
|
|
|
+ layout_parsing_res = []
|
|
|
|
|
+ matched_ocr_dict = {}
|
|
|
|
|
+ formula_index = 0
|
|
|
|
|
+ table_index = 0
|
|
|
|
|
+ seal_index = 0
|
|
|
|
|
+ image = np.array(image)
|
|
|
|
|
+ image_labels = ["image", "figure", "img", "fig"]
|
|
|
|
|
+ object_boxes = []
|
|
|
|
|
+ for object_box_idx, box_info in enumerate(layout_det_res["boxes"]):
|
|
|
|
|
+ single_box_res = {}
|
|
|
|
|
+ box = box_info["coordinate"]
|
|
|
|
|
+ label = box_info["label"].lower()
|
|
|
|
|
+ single_box_res["layout_bbox"] = box
|
|
|
|
|
+ object_boxes.append(box)
|
|
|
|
|
+ if label == "formula":
|
|
|
|
|
+ single_box_res["formula"] = formula_res_list[formula_index][
|
|
|
|
|
+ "rec_formula"
|
|
|
|
|
+ ]
|
|
|
|
|
+ formula_index += 1
|
|
|
|
|
+ elif label == "table":
|
|
|
|
|
+ single_box_res["table"] = table_res_list[table_index]["pred_html"]
|
|
|
|
|
+ table_index += 1
|
|
|
|
|
+ elif label == "seal":
|
|
|
|
|
+ single_box_res["seal"] = "".join(seal_res_list[seal_index]["rec_texts"])
|
|
|
|
|
+ seal_index += 1
|
|
|
|
|
+ else:
|
|
|
|
|
+ ocr_res_in_box, matched_idxs = get_sub_regions_ocr_res(
|
|
|
|
|
+ overall_ocr_res, [box], return_match_idx=True
|
|
|
|
|
+ )
|
|
|
|
|
+ for matched_idx in matched_idxs:
|
|
|
|
|
+ if matched_ocr_dict.get(matched_idx, None) is None:
|
|
|
|
|
+ matched_ocr_dict[matched_idx] = [object_box_idx]
|
|
|
|
|
+ else:
|
|
|
|
|
+ matched_ocr_dict[matched_idx].append(object_box_idx)
|
|
|
|
|
+ if label in image_labels:
|
|
|
|
|
+ x1, y1, x2, y2 = [int(i) for i in box]
|
|
|
|
|
+ sub_image = image[y1:y2, x1:x2, :]
|
|
|
|
|
+ single_box_res["image"] = sub_image
|
|
|
|
|
+ single_box_res[f"{label}_text"] = "\n".join(
|
|
|
|
|
+ ocr_res_in_box["rec_texts"]
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ single_box_res["text"] = "\n".join(ocr_res_in_box["rec_texts"])
|
|
|
|
|
+ if single_box_res:
|
|
|
|
|
+ layout_parsing_res.append(single_box_res)
|
|
|
|
|
+ for layout_box_ids in matched_ocr_dict.values():
|
|
|
|
|
+ # one ocr is matched to multiple layout boxes, split the text into multiple lines
|
|
|
|
|
+ if len(layout_box_ids) > 1:
|
|
|
|
|
+ for idx in layout_box_ids:
|
|
|
|
|
+ wht_im = np.ones(image.shape, dtype=image.dtype) * 255
|
|
|
|
|
+ box = layout_parsing_res[idx]["layout_bbox"]
|
|
|
|
|
+ x1, y1, x2, y2 = [int(i) for i in box]
|
|
|
|
|
+ wht_im[y1:y2, x1:x2, :] = image[y1:y2, x1:x2, :]
|
|
|
|
|
+ sub_ocr_res = next(
|
|
|
|
|
+ self.general_ocr_pipeline(
|
|
|
|
|
+ wht_im,
|
|
|
|
|
+ text_det_limit_side_len=text_det_limit_side_len,
|
|
|
|
|
+ text_det_limit_type=text_det_limit_type,
|
|
|
|
|
+ text_det_thresh=text_det_thresh,
|
|
|
|
|
+ text_det_box_thresh=text_det_box_thresh,
|
|
|
|
|
+ text_det_unclip_ratio=text_det_unclip_ratio,
|
|
|
|
|
+ text_rec_score_thresh=text_rec_score_thresh,
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
+ layout_parsing_res[idx]["text"] = "\n".join(
|
|
|
|
|
+ sub_ocr_res["rec_texts"]
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ ocr_without_layout_boxes = get_sub_regions_ocr_res(
|
|
|
|
|
+ overall_ocr_res, object_boxes, flag_within=False
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ for ocr_rec_box, ocr_rec_text in zip(
|
|
|
|
|
+ ocr_without_layout_boxes["rec_boxes"], ocr_without_layout_boxes["rec_texts"]
|
|
|
|
|
+ ):
|
|
|
|
|
+ single_box_res = {}
|
|
|
|
|
+ single_box_res["layout_bbox"] = ocr_rec_box
|
|
|
|
|
+ single_box_res["text_without_layout"] = ocr_rec_text
|
|
|
|
|
+ layout_parsing_res.append(single_box_res)
|
|
|
|
|
+
|
|
|
|
|
+ layout_parsing_res = sorted_layout_boxes(layout_parsing_res, w=image.shape[1])
|
|
|
|
|
+
|
|
|
|
|
+ return layout_parsing_res
|
|
|
|
|
|
|
|
def check_model_settings_valid(self, input_params: Dict) -> bool:
|
|
def check_model_settings_valid(self, input_params: Dict) -> bool:
|
|
|
"""
|
|
"""
|
|
@@ -262,6 +397,10 @@ class LayoutParsingPipeline(BasePipeline):
|
|
|
seal_det_box_thresh: Optional[float] = None,
|
|
seal_det_box_thresh: Optional[float] = None,
|
|
|
seal_det_unclip_ratio: Optional[float] = None,
|
|
seal_det_unclip_ratio: Optional[float] = None,
|
|
|
seal_rec_score_thresh: Optional[float] = None,
|
|
seal_rec_score_thresh: Optional[float] = None,
|
|
|
|
|
+ layout_threshold: Optional[Union[float, dict]] = None,
|
|
|
|
|
+ layout_nms: Optional[bool] = None,
|
|
|
|
|
+ layout_unclip_ratio: Optional[Union[float, Tuple[float, float]]] = None,
|
|
|
|
|
+ layout_merge_bboxes_mode: Optional[str] = None,
|
|
|
**kwargs,
|
|
**kwargs,
|
|
|
) -> LayoutParsingResult:
|
|
) -> LayoutParsingResult:
|
|
|
"""
|
|
"""
|
|
@@ -308,7 +447,15 @@ class LayoutParsingPipeline(BasePipeline):
|
|
|
|
|
|
|
|
doc_preprocessor_image = doc_preprocessor_res["output_img"]
|
|
doc_preprocessor_image = doc_preprocessor_res["output_img"]
|
|
|
|
|
|
|
|
- layout_det_res = next(self.layout_det_model(doc_preprocessor_image))
|
|
|
|
|
|
|
+ layout_det_res = next(
|
|
|
|
|
+ self.layout_det_model(
|
|
|
|
|
+ doc_preprocessor_image,
|
|
|
|
|
+ threshold=layout_threshold,
|
|
|
|
|
+ layout_nms=layout_nms,
|
|
|
|
|
+ layout_unclip_ratio=layout_unclip_ratio,
|
|
|
|
|
+ layout_merge_bboxes_mode=layout_merge_bboxes_mode,
|
|
|
|
|
+ )
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
if (
|
|
if (
|
|
|
model_settings["use_general_ocr"]
|
|
model_settings["use_general_ocr"]
|
|
@@ -382,10 +529,24 @@ class LayoutParsingPipeline(BasePipeline):
|
|
|
)
|
|
)
|
|
|
)
|
|
)
|
|
|
formula_res_list = formula_res_all["formula_res_list"]
|
|
formula_res_list = formula_res_all["formula_res_list"]
|
|
|
- print(formula_res_list)
|
|
|
|
|
else:
|
|
else:
|
|
|
formula_res_list = []
|
|
formula_res_list = []
|
|
|
|
|
|
|
|
|
|
+ parsing_res_list = self.get_layout_parsing_res(
|
|
|
|
|
+ doc_preprocessor_image,
|
|
|
|
|
+ layout_det_res=layout_det_res,
|
|
|
|
|
+ overall_ocr_res=overall_ocr_res,
|
|
|
|
|
+ table_res_list=table_res_list,
|
|
|
|
|
+ seal_res_list=seal_res_list,
|
|
|
|
|
+ formula_res_list=formula_res_list,
|
|
|
|
|
+ text_det_limit_side_len=text_det_limit_side_len,
|
|
|
|
|
+ text_det_limit_type=text_det_limit_type,
|
|
|
|
|
+ text_det_thresh=text_det_thresh,
|
|
|
|
|
+ text_det_box_thresh=text_det_box_thresh,
|
|
|
|
|
+ text_det_unclip_ratio=text_det_unclip_ratio,
|
|
|
|
|
+ text_rec_score_thresh=text_rec_score_thresh,
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
single_img_res = {
|
|
single_img_res = {
|
|
|
"input_path": batch_data.input_paths[0],
|
|
"input_path": batch_data.input_paths[0],
|
|
|
"page_index": batch_data.page_indexes[0],
|
|
"page_index": batch_data.page_indexes[0],
|
|
@@ -396,6 +557,7 @@ class LayoutParsingPipeline(BasePipeline):
|
|
|
"table_res_list": table_res_list,
|
|
"table_res_list": table_res_list,
|
|
|
"seal_res_list": seal_res_list,
|
|
"seal_res_list": seal_res_list,
|
|
|
"formula_res_list": formula_res_list,
|
|
"formula_res_list": formula_res_list,
|
|
|
|
|
+ "parsing_res_list": parsing_res_list,
|
|
|
"model_settings": model_settings,
|
|
"model_settings": model_settings,
|
|
|
}
|
|
}
|
|
|
yield LayoutParsingResult(single_img_res)
|
|
yield LayoutParsingResult(single_img_res)
|