# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import copy
import math
import re
from pathlib import Path
from typing import List
import numpy as np
from PIL import Image, ImageDraw
from ...common.result import (
BaseCVResult,
HtmlMixin,
JsonMixin,
MarkdownMixin,
XlsxMixin,
)
class LayoutParsingResultV2(BaseCVResult, HtmlMixin, XlsxMixin, MarkdownMixin):
    """Per-page result of the layout parsing (v2) pipeline.

    Fields are read with item access (``self["model_settings"]``,
    ``self["parsing_res_list"]``, ...) and rendered by the ``_to_*`` methods
    as visualization images, str, JSON, HTML, XLSX and Markdown.
    """

    def __init__(self, data) -> None:
        """Initialize the result from ``data`` and set up the output mixins.

        Args:
            data: Payload forwarded to ``BaseCVResult.__init__``.
        """
        super().__init__(data)
        HtmlMixin.__init__(self)
        XlsxMixin.__init__(self)
        MarkdownMixin.__init__(self)
        JsonMixin.__init__(self)
        self.title_pattern = self._build_title_pattern()

    def _build_title_pattern(self):
        """Compile the regex that splits a title into numbering and text.

        A match has three groups: (1) the leading numbering, e.g. ``1.2``,
        ``3、``, ``(4)``, ``（五）``, ``六、`` or ``IV.``; (2) the whitespace
        after it; (3) the remaining title text.

        Returns:
            re.Pattern: The compiled pattern.
        """
        chinese_numerals = "一二三四五六七八九十百千万亿零壹贰叁肆伍陆柒捌玖拾"
        numbering_pattern = (
            r"(?:"
            # Arabic, optionally multi-level: "1", "1.2", "3.", "4、"
            r"[1-9][0-9]*(?:\.[1-9][0-9]*)*[\.、]?|"
            # Parenthesized, with ASCII or full-width parentheses: "(1)", "（一）"
            r"[\(（](?:[1-9][0-9]*|["
            + chinese_numerals
            + r"]+)[\)）]|"
            # Chinese numerals: "一", "二、"
            + r"["
            + chinese_numerals
            + r"]+[、\.]?|"
            # Roman numerals must be followed by "." or a word boundary;
            # otherwise a title such as "Introduction" would be split into
            # the numbering "I" and the text "ntroduction".
            + r"(?:I|II|III|IV|V|VI|VII|VIII|IX|X)(?:\.|\b)"
            + r")"
        )
        return re.compile(r"^\s*(" + numbering_pattern + r")(\s*)(.*)$")

    def _get_input_fn(self):
        """Return the input file name, suffixed with the page index if set.

        When ``self["page_index"]`` is not None, ``<stem><suffix>`` becomes
        ``<stem>_<page_index><suffix>`` (only the file name is kept).
        """
        fn = super()._get_input_fn()
        page_idx = self["page_index"]
        if page_idx is None:
            return fn
        fp = Path(fn)
        return f"{fp.stem}_{page_idx}{fp.suffix}"

    def _to_img(self) -> dict[str, np.ndarray]:
        """Build the visualization images for this page.

        Returns:
            dict[str, np.ndarray]: Image name -> image. Besides the images of
            the enabled sub-results it always contains ``layout_det_res`` and
            ``layout_order_res`` (each parsed block filled with its label
            color and annotated with its ``index``). Note that the images
            drawn here (``table_cell_img``, ``layout_order_res``) are
            ``PIL.Image.Image`` objects despite the annotation.
        """
        from .utils import get_show_color

        model_settings = self["model_settings"]
        res_img_dict = {}

        if model_settings["use_doc_preprocessor"]:
            for name, img in self["doc_preprocessor_res"].img.items():
                res_img_dict[name] = img
        res_img_dict["layout_det_res"] = self["layout_det_res"].img["res"]
        if model_settings["use_region_detection"]:
            res_img_dict["region_det_res"] = self["region_det_res"].img["res"]
        if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]:
            res_img_dict["overall_ocr_res"] = self["overall_ocr_res"].img["ocr_res_img"]

        if model_settings["use_table_recognition"] and len(self["table_res_list"]) > 0:
            # Draw on a copy (defensive: Image.fromarray may share the array's
            # buffer, and the preprocessed image is reused below).
            table_cell_img = Image.fromarray(
                copy.deepcopy(self["doc_preprocessor_res"]["output_img"])
            )
            table_draw = ImageDraw.Draw(table_cell_img)
            rectangle_color = (255, 0, 0)
            for table_res in self["table_res_list"]:
                for box in table_res["cell_box_list"]:
                    x1, y1, x2, y2 = [int(pos) for pos in box]
                    table_draw.rectangle(
                        [x1, y1, x2, y2], outline=rectangle_color, width=2
                    )
            res_img_dict["table_cell_img"] = table_cell_img

        if model_settings["use_seal_recognition"] and len(self["seal_res_list"]) > 0:
            for seal_res in self["seal_res_list"]:
                seal_region_id = seal_res["seal_region_id"]
                res_img_dict[f"seal_res_region{seal_region_id}"] = seal_res.img[
                    "ocr_res_img"
                ]

        # Reading-order visualization; the channel axis is reversed before
        # handing the array to PIL (presumably BGR -> RGB).
        image = Image.fromarray(self["doc_preprocessor_res"]["output_img"][:, :, ::-1])
        draw = ImageDraw.Draw(image, "RGBA")
        for block in self["parsing_res_list"]:
            bbox = block.bbox
            draw.rectangle(bbox, fill=get_show_color(block.label))
            if block.index is not None:
                draw.text((bbox[2] + 2, bbox[1] - 10), str(block.index), fill="red")
        res_img_dict["layout_order_res"] = image
        return res_img_dict

    def _collect_sub_results(self, view: str) -> dict:
        """Gather the sub-results through one of their output views.

        Args:
            view: Name of the view attribute on each sub-result (``"str"`` or
                ``"json"``); the ``"res"`` entry of that view is collected.

        Returns:
            dict: In this order: ``doc_preprocessor_res`` (if enabled),
            ``layout_det_res``, ``overall_ocr_res`` (if general OCR or table
            recognition is enabled), then ``table_res_list``,
            ``seal_res_list`` and ``formula_res_list`` (each only if enabled
            and non-empty).
        """
        model_settings = self["model_settings"]
        data = {}

        def res_of(sub_result):
            return getattr(sub_result, view)["res"]

        if model_settings["use_doc_preprocessor"]:
            data["doc_preprocessor_res"] = res_of(self["doc_preprocessor_res"])
        data["layout_det_res"] = res_of(self["layout_det_res"])
        if model_settings["use_general_ocr"] or model_settings["use_table_recognition"]:
            data["overall_ocr_res"] = res_of(self["overall_ocr_res"])
        for flag, key in (
            ("use_table_recognition", "table_res_list"),
            ("use_seal_recognition", "seal_res_list"),
            ("use_formula_recognition", "formula_res_list"),
        ):
            if model_settings[flag] and len(self[key]) > 0:
                data[key] = [res_of(sub_result) for sub_result in self[key]]
        return data

    def _to_str(self, *args, **kwargs) -> dict[str, str]:
        """Convert the result to a dictionary and then to a string.

        Args:
            *args: Additional positional arguments passed to ``JsonMixin._to_str``.
            **kwargs: Additional keyword arguments passed to ``JsonMixin._to_str``.

        Returns:
            Dict[str, str]: The string representation produced by ``JsonMixin._to_str``.
        """
        data = {
            "input_path": self["input_path"],
            "page_index": self["page_index"],
            "model_settings": self["model_settings"],
        }
        data.update(self._collect_sub_results("str"))
        return JsonMixin._to_str(data, *args, **kwargs)

    def _to_json(self, *args, **kwargs) -> dict[str, str]:
        """Convert the result to a JSON dictionary.

        ``parsing_res_list`` is flattened to ``block_label`` /
        ``block_content`` / ``block_bbox`` entries.

        Args:
            *args: Positional arguments passed to ``JsonMixin._to_json``.
            **kwargs: Keyword arguments passed to ``JsonMixin._to_json``.

        Returns:
            Dict[str, str]: The JSON representation produced by ``JsonMixin._to_json``.
        """
        data = {
            "input_path": self["input_path"],
            "page_index": self["page_index"],
            "model_settings": self["model_settings"],
            "parsing_res_list": [
                {
                    "block_label": parsing_res.label,
                    "block_content": parsing_res.content,
                    "block_bbox": parsing_res.bbox,
                }
                for parsing_res in self["parsing_res_list"]
            ],
        }
        data.update(self._collect_sub_results("json"))
        return JsonMixin._to_json(data, *args, **kwargs)

    def _collect_table_views(self, view: str) -> dict:
        """Map ``table_<region id>`` to ``<table_res>.<view>["pred"]`` for every table.

        Returns an empty dict when table recognition is disabled or found no tables.
        """
        model_settings = self["model_settings"]
        if not (
            model_settings["use_table_recognition"] and len(self["table_res_list"]) > 0
        ):
            return {}
        return {
            f"table_{table_res['table_region_id']}": getattr(table_res, view)["pred"]
            for table_res in self["table_res_list"]
        }

    def _to_html(self) -> dict[str, str]:
        """Convert the recognized tables to their HTML representation.

        Returns:
            Dict[str, str]: ``table_<region id>`` -> HTML of that table.
        """
        return self._collect_table_views("html")

    def _to_xlsx(self) -> dict[str, str]:
        """Convert the recognized tables to their XLSX representation.

        Returns:
            Dict[str, str]: ``table_<region id>`` -> XLSX result of that table.
        """
        return self._collect_table_views("xlsx")

    def _to_markdown(self) -> dict:
        """Render the page's parsing result as Markdown.

        Blocks of ``self["parsing_res_list"]`` are rendered in list order by a
        per-label handler; blocks whose label has no handler are skipped.
        Two consecutive ``text`` blocks are joined without a paragraph break
        when the second one's ``seg_start_flag`` is False.

        Returns:
            dict: ``markdown_texts`` (the Markdown string),
            ``page_continuation_flags`` (the ``seg_start_flag`` of the first
            block and the ``seg_end_flag`` of the last block of the page) and
            ``markdown_images`` (``{path: image}`` for ``self["imgs_in_doc"]``).
        """

        def format_title(title):
            """Turn a title into a Markdown heading.

            If the title starts with numbering, exactly one space is put
            between numbering and text. The heading level is one plus the
            number of dots left after stripping trailing dots, so
            ``1.2 Foo`` -> ``### 1.2 Foo`` and ``Foo`` -> ``## Foo``.
            """
            match = self.title_pattern.match(title)
            if match:
                numbering = match.group(1).strip()
                title_content = match.group(3).lstrip()
                title = numbering + " " + title_content
            title = title.rstrip(".")
            level = title.count(".") + 1
            return f"#{'#' * level} {title}".replace("-\n", "").replace("\n", " ")

        def format_vision_title(block):
            # Table/figure/chart titles are emitted as plain content.
            return block.content

        def format_image(block):
            # Markdown image reference; the path is presumably one of the keys
            # of ``markdown_images`` built below.
            image_path = "".join(block.image.keys())
            image_path = image_path.replace("-\n", "").replace("\n", " ")
            return f"![]({image_path})"

        def format_first_line(block, templates, format_func, splitter):
            """Apply ``format_func`` to the first non-blank piece of the content
            (pieces split on ``splitter``) if its lower-cased text is in ``templates``."""
            lines = block.content.split(splitter)
            for idx, line in enumerate(lines):
                if line.strip() == "":
                    continue
                if line.lower() in templates:
                    lines[idx] = format_func(line)
                break
            return splitter.join(lines)

        def format_table(block):
            return "\n" + block.content

        def get_seg_flag(block, prev_block):
            """Return ``(seg_start_flag, seg_end_flag)`` for ``block``.

            ``seg_end_flag`` is False when the block's text segment ends
            within 10 units of the right context edge. ``seg_start_flag`` is
            False when the segment starts within 10 units of the left context
            edge and, if there is a previous block, that block's text also ran
            to its right edge, it had more than one line and the gap between
            the two blocks is smaller than the wider of them.
            """
            seg_start_flag = True
            seg_end_flag = True
            block_box = block.bbox
            context_left = block_box[0]
            context_right = block_box[2]
            seg_start = block.seg_start_coordinate
            seg_end = block.seg_end_coordinate

            if prev_block is None:
                if seg_start - context_left < 10:
                    seg_start_flag = False
            else:
                prev_bbox = prev_block.bbox
                prev_seg_end = prev_block.seg_end_coordinate
                prev_end_space_small = abs(prev_bbox[2] - prev_seg_end) < 10
                prev_lines_more_than_one = prev_block.num_of_lines > 1
                if context_left < prev_bbox[2]:
                    # The block starts left of the previous block's right edge:
                    # widen the context to span both blocks.
                    context_left = min(prev_bbox[0], context_left)
                    context_right = max(prev_bbox[2], context_right)
                    prev_end_space_small = abs(context_right - prev_seg_end) < 10
                    edge_distance = 0
                else:
                    edge_distance = abs(block_box[0] - prev_bbox[2])
                current_start_space_small = seg_start - context_left < 10
                if (
                    prev_end_space_small
                    and current_start_space_small
                    and prev_lines_more_than_one
                    and edge_distance < max(prev_block.width, block.width)
                ):
                    seg_start_flag = False

            if context_right - seg_end < 10:
                seg_end_flag = False
            return seg_start_flag, seg_end_flag

        handlers = {
            "paragraph_title": lambda b: format_title(b.content),
            "abstract_title": lambda b: format_title(b.content),
            "reference_title": lambda b: format_title(b.content),
            "content_title": lambda b: format_title(b.content),
            "doc_title": lambda b: f"# {b.content}".replace("-\n", "").replace(
                "\n", " "
            ),
            "table_title": format_vision_title,
            "figure_title": format_vision_title,
            "chart_title": format_vision_title,
            "text": lambda b: b.content.replace("\n\n", "\n").replace("\n", "\n\n"),
            "abstract": lambda b: format_first_line(
                b, ["摘要", "abstract"], lambda line: f"## {line}\n", " "
            ),
            "content": lambda b: b.content.replace("-\n", " \n").replace("\n", " \n"),
            "image": format_image,
            "chart": format_image,
            "formula": lambda b: f"$${b.content}$$",
            "table": format_table,
            "reference": lambda b: format_first_line(
                b, ["参考文献", "references"], lambda line: f"## {line}", "\n"
            ),
            "algorithm": lambda b: b.content.strip("\n"),
            "seal": lambda b: f"Words of Seals:\n{b.content}",
        }

        markdown_content = ""
        last_label = None
        prev_block = None
        page_first_element_seg_start_flag = None
        page_last_element_seg_end_flag = None
        for block in self["parsing_res_list"]:
            seg_start_flag, seg_end_flag = get_seg_flag(block, prev_block)
            if page_first_element_seg_start_flag is None:
                page_first_element_seg_start_flag = seg_start_flag
            handler = handlers.get(block.label)
            if handler:
                prev_block = block
                rendered = handler(block)
                if block.label == last_label == "text" and not seg_start_flag:
                    # Continuation of the previous text block: no paragraph break.
                    markdown_content += rendered
                elif markdown_content:
                    markdown_content += "\n\n" + rendered
                else:
                    markdown_content += rendered
                last_label = block.label
            page_last_element_seg_end_flag = seg_end_flag

        return {
            "markdown_texts": markdown_content,
            "page_continuation_flags": (
                page_first_element_seg_start_flag,
                page_last_element_seg_end_flag,
            ),
            "markdown_images": {
                img["path"]: img["img"] for img in self["imgs_in_doc"]
            },
        }
class LayoutParsingBlock:
    """One layout element of a page: its label, bounding box and content.

    The block also caches orientation-dependent geometry (``start_coordinate``
    / ``end_coordinate`` along its orientation and
    ``secondary_orientation_*_coordinate`` across it) and can temporarily
    absorb child blocks, growing its bbox to the union of their boxes.
    """

    def __init__(self, label, bbox, content="") -> None:
        """Create a block.

        Args:
            label: Layout label of the block (e.g. ``"text"``).
            bbox: ``[x1, y1, x2, y2]``; stored in ``self.bbox`` as a list of ints.
            content: Recognized content of the block.
        """
        self.label = label
        self.order_label = "other"
        self.bbox = [int(item) for item in bbox]
        self.content = content
        # Extent of the block's text segment. Nothing in this class sets it;
        # presumably assigned by the caller once text lines are known.
        self.seg_start_coordinate = float("inf")
        self.seg_end_coordinate = float("-inf")
        # Size is measured on the raw bbox, i.e. before int() truncation.
        self.width = bbox[2] - bbox[0]
        self.height = bbox[3] - bbox[1]
        self.area = self.width * self.height
        self.num_of_lines = 1
        self.image = None
        self.index = None
        self.visual_index = None
        self.orientation = self.get_bbox_orientation()
        self.child_blocks = []
        self.update_orientation_info()

    def __str__(self) -> str:
        return f"{self.__dict__}"

    def __repr__(self) -> str:
        _str = f"\n\n#################\nlabel:\t{self.label}\nregion_label:\t{self.order_label}\nbbox:\t{self.bbox}\ncontent:\t{self.content}\n#################"
        return _str

    def to_dict(self) -> dict:
        """Return the block's attribute dict (the live ``__dict__``, not a copy)."""
        return self.__dict__

    def update_orientation_info(self) -> None:
        """Refresh the orientation-dependent geometry from ``self.bbox``.

        Blocks whose ``order_label`` is ``"vision"`` are forced to horizontal.
        For a horizontal block the primary axis is x (``start/end_coordinate``
        are x1/x2) and the secondary axis is y; a vertical block is the other
        way round.
        """
        if self.order_label == "vision":
            self.orientation = "horizontal"
        if self.orientation == "horizontal":
            self.secondary_orientation = "vertical"
            self.short_side_length = self.height
            self.long_side_length = self.width
            self.start_coordinate = self.bbox[0]
            self.end_coordinate = self.bbox[2]
            self.secondary_orientation_start_coordinate = self.bbox[1]
            self.secondary_orientation_end_coordinate = self.bbox[3]
        else:
            self.secondary_orientation = "horizontal"
            self.short_side_length = self.width
            self.long_side_length = self.height
            self.start_coordinate = self.bbox[1]
            self.end_coordinate = self.bbox[3]
            self.secondary_orientation_start_coordinate = self.bbox[0]
            self.secondary_orientation_end_coordinate = self.bbox[2]

    def append_child_block(self, child_block: "LayoutParsingBlock") -> None:
        """Absorb ``child_block`` (and its own children) into this block.

        The first time a child is added, the current bbox is saved in
        ``ori_bbox``; the bbox then grows to the union with the child's bbox.
        If the child had children itself, they are detached from it (which
        restores its original bbox) and adopted here, so the hierarchy stays
        one level deep.
        """
        if not self.child_blocks:
            self.ori_bbox = list(self.bbox)
        x1, y1, x2, y2 = self.bbox
        child_x1, child_y1, child_x2, child_y2 = child_block.bbox
        # Keep bbox a list, matching the type set in __init__.
        self.bbox = [
            min(x1, child_x1),
            min(y1, child_y1),
            max(x2, child_x2),
            max(y2, child_y2),
        ]
        self.update_orientation_info()
        adopted = [child_block]
        if child_block.child_blocks:
            # Must run after the union above: it shrinks the child's bbox back.
            adopted.extend(child_block.get_child_blocks())
        self.child_blocks.extend(adopted)

    def get_child_blocks(self) -> list:
        """Detach and return the absorbed child blocks.

        The bbox, and the geometry derived from it, is restored to what it
        was before the first child was appended. A block that never had
        children is left unchanged and an empty list is returned.
        """
        ori_bbox = getattr(self, "ori_bbox", None)
        if ori_bbox is not None:
            self.bbox = list(ori_bbox)
            # Keep start/end coordinates consistent with the restored bbox.
            self.update_orientation_info()
        child_blocks, self.child_blocks = self.child_blocks, []
        return child_blocks

    def get_centroid(self) -> tuple:
        """Return the ``(x, y)`` center of the current bbox."""
        x1, y1, x2, y2 = self.bbox
        return ((x1 + x2) / 2, (y1 + y2) / 2)

    def get_bbox_orientation(self, orientation_ratio: float = 1.0) -> str:
        """Classify the block as horizontal or vertical.

        Args:
            orientation_ratio: Factor applied to the width before comparing it
                with the height. Defaults to 1.0.

        Returns:
            str: ``"horizontal"`` if ``width * orientation_ratio >= height``,
            otherwise ``"vertical"``.
        """
        if self.width * orientation_ratio >= self.height:
            return "horizontal"
        return "vertical"
class LayoutParsingRegion:
    """A group of blocks inside one region bbox, ordered as a unit.

    ``config`` is a deep copy of ``block_label_mapping`` extended with the
    region bbox, the region orientation and, per label category, the indices
    of this region's blocks belonging to that category.
    """

    def __init__(
        self,
        region_bbox,
        blocks: "List[LayoutParsingBlock] | None" = None,
        block_label_mapping=None,
    ) -> None:
        """Create a region and build its ordering config.

        Args:
            region_bbox: ``[x1, y1, x2, y2]`` of the region.
            blocks: Blocks inside the region; defaults to a new empty list.
            block_label_mapping: Label-category configuration. Whenever
                ``blocks`` is non-empty it must contain the ``*_labels`` keys
                read by ``update_layout_order_config_block_index``. Defaults
                to an empty mapping.
        """
        self.region_bbox = region_bbox
        # A fresh list per region: a shared ``[]`` default would leak blocks
        # between instances as soon as anyone appends to ``self.blocks``.
        self.blocks = [] if blocks is None else blocks
        # Placeholder first: update_config() computes the real orientation
        # and must not be overwritten afterwards.
        self.orientation = None
        self.update_config({} if block_label_mapping is None else block_label_mapping)
        self.calculate_bbox_metrics()

    def update_config(self, block_label_mapping):
        """Rebuild ``config``, ``block_map`` and ``orientation`` from ``self.blocks``.

        The region is ``"horizontal"`` when the number of horizontal blocks
        whose ``order_label`` is neither ``"vision"`` nor ``"vision_title"``
        is at least half of the block count minus the vision and vision-title
        blocks; otherwise it is ``"vertical"``.
        """
        self.block_map = {}
        self.config = copy.deepcopy(block_label_mapping)
        self.config["region_bbox"] = self.region_bbox
        horizontal_text_block_num = 0
        for idx, block in enumerate(self.blocks):
            if (
                block.order_label not in ["vision", "vision_title"]
                and block.orientation == "horizontal"
            ):
                horizontal_text_block_num += 1
            self.block_map[idx] = block
            self.update_layout_order_config_block_index(block.label, idx)
        text_block_num = (
            len(self.blocks)
            - len(self.config.get("vision_block_idxes", []))
            - len(self.config.get("vision_title_block_idxes", []))
        )
        self.orientation = (
            "horizontal"
            if horizontal_text_block_num >= text_block_num * 0.5
            else "vertical"
        )
        self.config["region_orientation"] = self.orientation

    def calculate_bbox_metrics(self):
        """Compute distance/angle metrics of the region relative to the origin.

        Sets ``euclidean_distance`` (origin to the ``(x1, y1)`` corner),
        ``center_euclidean_distance`` (origin to the bbox center) and
        ``angle_rad`` (``atan2`` of the center).
        """
        x1, y1, x2, y2 = self.region_bbox
        x_center, y_center = (x1 + x2) / 2, (y1 + y2) / 2
        self.euclidean_distance = math.hypot(x1, y1)
        self.center_euclidean_distance = math.hypot(x_center, y_center)
        self.angle_rad = math.atan2(y_center, x_center)

    def sort(self):
        """Return the blocks ordered by ``xycut_enhanced`` using ``self.config``."""
        from .xycut_enhanced import xycut_enhanced

        return xycut_enhanced(self.blocks, self.config)

    def update_layout_order_config_block_index(
        self, block_label: str, block_idx: int
    ) -> None:
        """Record ``block_idx`` under every category whose labels include ``block_label``.

        Categories are doc_title, paragraph_title, vision, vision_title,
        unordered, text, header and footer. ``config["<category>_labels"]``
        must exist for all of them (KeyError otherwise);
        ``config["<category>_block_idxes"]`` lists are created on first use.
        """
        categories = (
            "doc_title",
            "paragraph_title",
            "vision",
            "vision_title",
            "unordered",
            "text",
            "header",
            "footer",
        )
        # Look up every label list first so a missing key raises before
        # the config is modified.
        label_lists = [self.config[f"{category}_labels"] for category in categories]
        idx_lists = [
            self.config.setdefault(f"{category}_block_idxes", [])
            for category in categories
        ]
        for labels, idxes in zip(label_lists, idx_lists):
            if block_label in labels:
                idxes.append(block_idx)