@@ -25,6 +25,7 @@ from PIL import Image
 import uuid
 import re
 from pathlib import Path
+from copy import deepcopy
 from typing import Optional, Union, List, Tuple, Dict, Any
 
 from ..ocr.result import OCRResult
 from ...models.object_detection.result import DetResult
@@ -252,6 +253,7 @@ def _adjust_span_text(span: List[str], prepend: bool = False, append: bool = Fal
         span[1] = "\n" + span[1]
     if append:
         span[1] = span[1] + "\n"
+    return span
 
 
 def _format_line(
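
Reviewer note: a span is a plain list, so `span[1] = ...` already mutates it in place; the new `return span` mainly makes the dataflow explicit so call sites can rebind, as `_format_line` now does. A minimal sketch (the `[bbox, text, label]` span layout is taken from the surrounding code):

    span = [[12, 30, 480, 52], "First line of the block", "text"]
    span = _adjust_span_text(span, prepend=True)
    assert span[1].startswith("\n")
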
@@ -277,17 +279,127 @@ def _format_line(
 
     if not is_reference:
         if first_span[0][0] - layout_min > 10:
-            _adjust_span_text(first_span, prepend=True)
+            first_span = _adjust_span_text(first_span, prepend=True)
         if layout_max - end_span[0][2] > 10:
-            _adjust_span_text(end_span, append=True)
+            end_span = _adjust_span_text(end_span, append=True)
     else:
         if first_span[0][0] - layout_min < 5:
-            _adjust_span_text(first_span, prepend=True)
+            first_span = _adjust_span_text(first_span, prepend=True)
         if layout_max - end_span[0][2] > 20:
-            _adjust_span_text(end_span, append=True)
+            end_span = _adjust_span_text(end_span, append=True)
+
+    line[0] = first_span
+    line[-1] = end_span
+
+    return line
+
+
+def split_boxes_if_x_contained(boxes, offset=1e-5):
+    """
+    Check whether any bounding box completely contains another in the
+    x-direction, and split the containing box around the contained one.
+
+    Args:
+        boxes (list of lists): Each element is a list containing a length-4
+            ndarray (the bounding box), the recognized text, and a label.
+        offset (float): A small offset applied at the split points so that the
+            split boxes do not overlap the contained box.
+
+    Returns:
+        list: A new list of boxes, including the split boxes, each keeping the
+            original `rec_text` and `label` attributes.
+    """
+
+    def is_x_contained(box_a, box_b):
+        """Check if box_a completely contains box_b in the x-direction."""
+        return box_a[0][0] <= box_b[0][0] and box_a[0][2] >= box_b[0][2]
+
+    new_boxes = []
+
+    for i in range(len(boxes)):
+        box_a = boxes[i]
+        is_split = False
+        for j in range(len(boxes)):
+            if i == j:
+                continue
+            box_b = boxes[j]
+            if is_x_contained(box_a, box_b):
+                is_split = True
+                # Split box_a based on the x-coordinates of box_b
+                if box_a[0][0] < box_b[0][0]:
+                    w = box_b[0][0] - offset - box_a[0][0]
+                    if w > 1:
+                        new_boxes.append(
+                            [
+                                np.array(
+                                    [
+                                        box_a[0][0],
+                                        box_a[0][1],
+                                        box_b[0][0] - offset,
+                                        box_a[0][3],
+                                    ]
+                                ),
+                                box_a[1],
+                                box_a[2],
+                            ]
+                        )
+                if box_a[0][2] > box_b[0][2]:
+                    w = box_a[0][2] - box_b[0][2] + offset
+                    if w > 1:
+                        box_a = [
+                            np.array(
+                                [
+                                    box_b[0][2] + offset,
+                                    box_a[0][1],
+                                    box_a[0][2],
+                                    box_a[0][3],
+                                ]
+                            ),
+                            box_a[1],
+                            box_a[2],
+                        ]
+        # Append box_a (or its remaining split piece) once per outer iteration.
+        # Checking `j == len(boxes) - 1` inside the inner loop would silently
+        # drop the last box whenever it contains another box, because the
+        # `i == j` continue skips that final check.
+        new_boxes.append(box_a)
+
+    return new_boxes
+
+
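
An illustrative sketch of the split behavior (inputs assumed, not part of the patch): a text span that fully contains a formula span in the x-direction is cut around the formula. Both split pieces initially keep the full original text and "text" label, which is why `_sort_line_by_x_projection` below re-runs recognition on them:

    import numpy as np

    line = [
        [np.array([10, 100, 500, 130]), "E = mc^2 holds because", "text"],
        [np.array([120, 100, 260, 130]), "E = mc^2", "formula"],
    ]
    new_line = split_boxes_if_x_contained(line)
    # -> three boxes: a "text" piece left of x=120, a "text" piece right of
    #    x=260 (the remainder), and the untouched "formula" box.
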
+def _sort_line_by_x_projection(
+    input_img: np.ndarray,
+    general_ocr_pipeline: Any,
+    line: List[List[Union[List[int], str]]],
+) -> List[List[Union[List[int], str]]]:
+    """
+    Sort a line of text spans based on their horizontal position within the
+    layout bounding box, splitting any span that fully contains another span
+    in the x-direction and re-running recognition on the split text crops.
+
+    Args:
+        input_img (ndarray): The input image used for OCR.
+        general_ocr_pipeline (Any): The general OCR pipeline used for text recognition.
+        line (list): A list of spans, where each span is a list containing a bounding box, text, and a label.
+
+    Returns:
+        list: The sorted line of text spans.
+    """
+    split_boxes = split_boxes_if_x_contained(line)
+    split_lines = []
+    if len(line) != len(split_boxes):
+        split_boxes.sort(key=lambda span: span[0][0])
+        text_rec_model = general_ocr_pipeline.text_rec_model
+        for span in split_boxes:
+            if span[2] == "text":
+                crop_img = input_img[
+                    int(span[0][1]) : int(span[0][3]),
+                    int(span[0][0]) : int(span[0][2]),
+                ]
+                span[1] = next(text_rec_model([crop_img]))["rec_text"]
+            split_lines.append(span)
+    else:
+        split_lines = line
+
+    return split_lines
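
Design note: only the "text" fragments produced by the split are re-cropped and re-recognized through `text_rec_model`; a "formula" span keeps the text it already carries (e.g. LaTeX from the formula recognition model). A hedged usage sketch, where `page_img` and `pipeline` stand in for the caller's image and OCR pipeline:

    sorted_line = _sort_line_by_x_projection(page_img, pipeline, line)
    # Spans now read left to right, and each split "text" fragment carries
    # the text recognized from its own crop rather than the whole line.
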
 
 
 def _sort_ocr_res_by_y_projection(
+    input_img: np.ndarray,
+    general_ocr_pipeline: Any,
     label: Any,
     block_bbox: Tuple[int, int, int, int],
     ocr_res: Dict[str, List[Any]],
@@ -297,6 +409,8 @@ def _sort_ocr_res_by_y_projection(
     Sorts OCR results based on their spatial arrangement, grouping them into lines and blocks.
 
     Args:
+        input_img (ndarray): The input image used for OCR.
+        general_ocr_pipeline (Any): The general OCR pipeline used for text recognition.
         label (Any): The label associated with the OCR results. It's not used in the function but might be
             relevant for other parts of the calling context.
         block_bbox (Tuple[int, int, int, int]): A tuple representing the layout bounding box, defined as
@@ -317,12 +431,13 @@ def _sort_ocr_res_by_y_projection(
 
     boxes = ocr_res["boxes"]
     rec_texts = ocr_res["rec_texts"]
+    rec_labels = ocr_res["rec_labels"]
 
     x_min, _, x_max, _ = block_bbox
     inline_x_min = min([box[0] for box in boxes])
     inline_x_max = max([box[2] for box in boxes])
 
-    spans = list(zip(boxes, rec_texts))
+    spans = list(zip(boxes, rec_texts, rec_labels))
 
     spans.sort(key=lambda span: span[0][1])
     spans = [list(span) for span in spans]
@@ -349,16 +464,21 @@ def _sort_ocr_res_by_y_projection(
     if current_line:
         lines.append(current_line)
 
+    new_lines = []
     for line in lines:
         line.sort(key=lambda span: span[0][0])
+
+        ocr_labels = [span[2] for span in line]
+        if "formula" in ocr_labels:
+            line = _sort_line_by_x_projection(input_img, general_ocr_pipeline, line)
         if label == "reference":
             line = _format_line(line, inline_x_min, inline_x_max, is_reference=True)
         else:
             line = _format_line(line, x_min, x_max)
+        new_lines.append(line)
 
     # Flatten lines back into a single list for boxes and texts
-    ocr_res["boxes"] = [span[0] for line in lines for span in line]
-    ocr_res["rec_texts"] = [span[1] + " " for line in lines for span in line]
+    ocr_res["boxes"] = [span[0] for line in new_lines for span in line]
+    ocr_res["rec_texts"] = [span[1] + " " for line in new_lines for span in line]
 
     return ocr_res
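
For context, the grouping that produces `lines` is driven by the threshold passed at the call site (0.7). A hedged sketch of the criterion it is assumed to implement (the actual code sits above this hunk): two spans share a line when their vertical overlap, relative to the shorter span, exceeds the threshold:

    def same_line(span_a, span_b, threshold=0.7):
        y0 = max(span_a[0][1], span_b[0][1])
        y1 = min(span_a[0][3], span_b[0][3])
        overlap = max(0, y1 - y0)
        min_height = min(span_a[0][3] - span_a[0][1], span_b[0][3] - span_b[0][1])
        return overlap / min_height > threshold
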
@@ -417,6 +537,7 @@ def _process_text(input_text: str) -> str:
 
 
 def get_single_block_parsing_res(
+    general_ocr_pipeline: Any,
     overall_ocr_res: OCRResult,
     layout_det_res: DetResult,
     table_res_list: list,
@@ -451,10 +572,16 @@ def get_single_block_parsing_res(
     input_img = overall_ocr_res["doc_preprocessor_res"]["output_img"]
     seal_index = 0
 
-    for box_info in layout_det_res["boxes"]:
+    layout_det_res_list, _ = _remove_overlap_blocks(
+        deepcopy(layout_det_res["boxes"]),
+        threshold=0.5,
+        smaller=True,
+    )
+
+    for box_info in layout_det_res_list:
         block_bbox = box_info["coordinate"]
         label = box_info["label"]
-        rec_res = {"boxes": [], "rec_texts": [], "flag": False}
+        rec_res = {"boxes": [], "rec_texts": [], "rec_labels": [], "flag": False}
         seg_start_flag = True
         seg_end_flag = True
@@ -503,10 +630,15 @@ def get_single_block_parsing_res(
                     rec_res["rec_texts"].append(
                         overall_ocr_res["rec_texts"][box_no],
                    )
+                    rec_res["rec_labels"].append(
+                        overall_ocr_res["rec_labels"][box_no],
+                    )
                    rec_res["flag"] = True
 
             if rec_res["flag"]:
-                rec_res = _sort_ocr_res_by_y_projection(label, block_bbox, rec_res, 0.7)
+                rec_res = _sort_ocr_res_by_y_projection(
+                    input_img, general_ocr_pipeline, label, block_bbox, rec_res, 0.7
+                )
                 rec_res_first_bbox = rec_res["boxes"][0]
                 rec_res_end_bbox = rec_res["boxes"][-1]
                 if rec_res_first_bbox[0] - block_bbox[0] < 10:
@@ -547,6 +679,20 @@ def get_single_block_parsing_res(
                 },
             )
 
+    if len(layout_det_res_list) == 0:
+        for ocr_rec_box, ocr_rec_text in zip(
+            overall_ocr_res["rec_boxes"], overall_ocr_res["rec_texts"]
+        ):
+            single_block_layout_parsing_res.append(
+                {
+                    "block_label": "text",
+                    "block_content": ocr_rec_text,
+                    "block_bbox": ocr_rec_box,
+                    "seg_start_flag": True,
+                    "seg_end_flag": True,
+                },
+            )
+
     single_block_layout_parsing_res = get_layout_ordering(
         single_block_layout_parsing_res,
         no_mask_labels=[
@@ -875,8 +1021,8 @@ def _remove_overlap_blocks(
                 continue
             # Check for overlap and determine which block to remove
             overlap_box_index = _get_minbox_if_overlap_by_ratio(
-                block1["block_bbox"],
-                block2["block_bbox"],
+                block1["coordinate"],
+                block2["coordinate"],
                 threshold,
                 smaller=smaller,
             )
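
The key switch here is from `block_bbox` to `coordinate`: the function is now called on raw layout detection boxes (whose bbox key is "coordinate") in `get_single_block_parsing_res`, rather than on parsing blocks (whose key is "block_bbox") in `get_layout_ordering`, as removed below. For reference, a hedged sketch of the overlap rule this is assumed to rely on (the real helper is `_get_minbox_if_overlap_by_ratio`): when the intersection covers more than `threshold` of the smaller box's area, the smaller box is the one dropped with smaller=True:

    def overlap_ratio_over_smaller(box1, box2):
        x0, y0 = max(box1[0], box2[0]), max(box1[1], box2[1])
        x1, y1 = min(box1[2], box2[2]), min(box1[3], box2[3])
        inter = max(0, x1 - x0) * max(0, y1 - y0)
        area1 = (box1[2] - box1[0]) * (box1[3] - box1[1])
        area2 = (box2[2] - box2[0]) * (box2[3] - box2[1])
        return inter / min(area1, area2)
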
@@ -1384,11 +1530,6 @@ def get_layout_ordering(
     vision_labels = ["image", "table", "seal", "chart", "figure"]
     vision_title_labels = ["table_title", "chart_title", "figure_title"]
 
-    parsing_res_list, _ = _remove_overlap_blocks(
-        parsing_res_list,
-        threshold=0.5,
-        smaller=True,
-    )
     parsing_res_list, pre_cuts = _get_sub_category(parsing_res_list, title_text_labels)
 
     parsing_res_by_pre_cuts_list = []
|