|
|
@@ -18,7 +18,21 @@ from magic_pdf.libs.config_reader import get_local_layoutreader_model_dir
|
|
|
from magic_pdf.libs.convert_utils import dict_to_list
|
|
|
from magic_pdf.libs.hash_utils import compute_md5
|
|
|
from magic_pdf.libs.local_math import float_equal
|
|
|
+from magic_pdf.libs.pdf_image_tools import cut_image_to_pil_image
|
|
|
from magic_pdf.model.magic_model import MagicModel
|
|
|
+
|
|
|
# Stop albumentations from checking for updates.
os.environ['NO_ALBUMENTATIONS_UPDATE'] = '1'
# Disable the YOLO logger.
os.environ['YOLO_VERBOSE'] = 'False'

try:
    import torchtext
except ImportError:
    # torchtext is optional here; without it there is no warning to silence.
    pass
else:
    # Feature-detect the helper instead of comparing ``__version__`` strings:
    # string comparison is lexicographic, so "0.9.0" >= "0.18.0" is True and
    # the old check could call the helper on a release that may not provide
    # it, raising an AttributeError that ``except ImportError`` does not catch.
    _disable_torchtext_warning = getattr(
        torchtext, 'disable_torchtext_deprecation_warning', None
    )
    if callable(_disable_torchtext_warning):
        _disable_torchtext_warning()
|
|
|
+from magic_pdf.model.sub_modules.model_init import AtomModelSingleton
|
|
|
+
|
|
|
from magic_pdf.para.para_split_v3 import para_split
|
|
|
from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
|
|
|
from magic_pdf.pre_proc.construct_page_dict import \
|
|
|
@@ -74,7 +88,150 @@ def __replace_STX_ETX(text_str: str):
|
|
|
return text_str
|
|
|
|
|
|
|
|
|
-def txt_spans_extract(pdf_page, inline_equations, interline_equations):
|
|
|
def chars_to_content(span):
    """Collapse ``span['chars']`` into ``span['content']`` and drop the char list.

    The chars are ordered by the x coordinate of their bbox centre. A single
    space is inserted before a char whose left edge lies more than one average
    char width to the right of the previous char's right edge. The span is
    modified in place: ``content`` is set and ``chars`` is deleted.

    Args:
        span: dict holding a ``chars`` list. Each char is a dict with ``'c'``
            (its text) and ``'bbox'`` (indexable; x0 at index 0, x1 at index 2).

    Returns:
        None.
    """
    # Order chars by the x coordinate of their bbox centre.
    chars = sorted(
        span['chars'],
        key=lambda ch: (ch['bbox'][0] + ch['bbox'][2]) / 2,
    )
    span['chars'] = chars

    if not chars:
        span['content'] = ''
        del span['chars']
        return

    # Average char width: the gap threshold for inserting a space.
    avg_width = sum(ch['bbox'][2] - ch['bbox'][0] for ch in chars) / len(chars)

    pieces = []
    # Track the previous char's right edge directly. Looking it up with
    # ``list.index(char) - 1`` was quadratic, picked the wrong neighbour when
    # two char dicts compared equal, and wrapped around to the last char for
    # the first one.
    prev_x1 = None
    for ch in chars:
        if prev_x1 is not None and ch['bbox'][0] - prev_x1 > avg_width:
            pieces.append(' ')
        pieces.append(ch['c'])
        prev_x1 = ch['bbox'][2]

    span['content'] = __replace_STX_ETX(''.join(pieces))
    del span['chars']
|
|
|
+
|
|
|
+
|
|
|
+LINE_STOP_FLAG = ('.', '!', '?', '。', '!', '?', ')', ')', '"', '”', ':', ':', ';', ';', ']', '】', '}', '}', '>', '》', '、', ',', ',')
|
|
|
def fill_char_in_spans(spans, all_chars):
    """Hand each char to the first span that accepts it, then build span text.

    For every char, the spans are tried in order and the char is appended to
    ``span['chars']`` of the first one for which ``calculate_char_in_span``
    returns a truthy value; a char accepted by no span is dropped. Afterwards
    ``chars_to_content`` is called on every span.

    Args:
        spans: list of span dicts, each with a ``bbox`` and a ``chars`` list.
        all_chars: iterable of char dicts with ``'c'`` and ``'bbox'`` keys.

    Returns:
        None; the spans are modified in place.
    """
    for char in all_chars:
        # Whether the char is a line-stop mark does not depend on the span, so
        # decide it once per char rather than once per (char, span) pair.
        is_line_stop = char['c'] in LINE_STOP_FLAG
        for span in spans:
            if calculate_char_in_span(char['bbox'], span['bbox'], is_line_stop):
                span['chars'].append(char)
                break

    for span in spans:
        chars_to_content(span)
|
|
|
+
|
|
|
+
|
|
|
# Judge membership by the char's centre point, which is more robust.
def calculate_char_in_span(char_bbox, span_bbox, char_is_line_stop_flag):
    """Decide whether a char belongs to a span.

    A char qualifies only when its vertical centre lies strictly inside the
    span and is less than a quarter of the span height away from the span's
    vertical centre. Given that, it is accepted when either

    * its horizontal centre lies strictly inside the span, or
    * it is a line-stop mark and its left edge lies strictly inside the last
      ``span_height`` units before the span's right edge. This gives trailing
      punctuation a chance to join the span even though its centre falls
      outside.

    Args:
        char_bbox: indexable ``(x0, y0, x1, y1)`` of the char.
        span_bbox: indexable ``(x0, y0, x1, y1)`` of the span.
        char_is_line_stop_flag: truthy when the char is a line-stop mark.

    Returns:
        bool: ``True`` when the char belongs to the span, otherwise ``False``.
        The previous version fell through and returned ``None`` when the
        centre test failed for a char that is not a line-stop mark.
    """
    char_center_x = (char_bbox[0] + char_bbox[2]) / 2
    char_center_y = (char_bbox[1] + char_bbox[3]) / 2
    span_center_y = (span_bbox[1] + span_bbox[3]) / 2
    span_height = span_bbox[3] - span_bbox[1]

    # Both acceptance rules share the same vertical requirement: the char's
    # centre line may not be off the span's centre line by a quarter of the
    # span height or more.
    vertically_aligned = (
        span_bbox[1] < char_center_y < span_bbox[3]
        and abs(char_center_y - span_center_y) < span_height / 4
    )
    if not vertically_aligned:
        return False

    # Main rule: the horizontal centre of the char is inside the span.
    if span_bbox[0] < char_center_x < span_bbox[2]:
        return True

    # Fallback for line-stop marks only: the left edge is inside the span and
    # close to its right boundary.
    if char_is_line_stop_flag:
        return (span_bbox[2] - span_height) < char_bbox[0] < span_bbox[2]

    return False
|
|
|
+
|
|
|
+
|
|
|
def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
    """Fill text spans with chars read from the PDF page, using OCR as fallback.

    Spans whose bbox overlaps a block by more than 0.5 (as computed by
    ``calculate_overlap_area_in_bbox1_area_ratio``) are selected; among those,
    spans of type ``ContentType.Text`` get their ``content`` rebuilt from the
    page's ``rawdict`` chars via ``fill_char_in_spans``. Text spans left with
    empty content are removed from ``spans``, recognised with the OCR model,
    and appended back only when the OCR score is above 0.5 and the text is
    non-empty.

    Args:
        pdf_page: page object providing ``get_text('rawdict', ...)``; it is
            also passed to ``cut_image_to_pil_image``.
        spans: list of span dicts with ``bbox`` and ``type``. Modified in place.
        all_bboxes: layout blocks; ``block[0:4]`` is used as the bbox and
            ``block[7]`` is compared against ``BlockType`` members
            (presumably the block type — confirm against the caller).
        all_discarded_blocks: discarded blocks; ``block[0:4]`` is the bbox.
        lang: forwarded to the OCR model as ``lang``.

    Returns:
        The same ``spans`` list that was passed in.
    """
    useful_spans = []
    unuseful_spans = []
    for span in spans:
        in_layout_block = any(
            block[7] not in [BlockType.ImageBody, BlockType.TableBody, BlockType.InterlineEquation]
            and calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5
            for block in all_bboxes
        )
        if in_layout_block:
            useful_spans.append(span)
            # Queue each span once. Previously a span overlapping both a
            # layout block and a discarded block was queued twice, so the
            # second chars_to_content call raised KeyError on the deleted
            # 'chars' key and spans.remove() would be called twice for it.
            continue
        in_discarded_block = any(
            calculate_overlap_area_in_bbox1_area_ratio(span['bbox'], block[0:4]) > 0.5
            for block in all_discarded_blocks
        )
        if in_discarded_block:
            unuseful_spans.append(span)

    # Only text spans are filled; spans in layout blocks come first, then the
    # ones in discarded blocks.
    text_spans = [
        span
        for span in useful_spans + unuseful_spans
        if span['type'] in [ContentType.Text]
    ]
    if not text_spans:
        # Nothing to fill, so skip the char-level extraction of the page.
        return spans

    text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)['blocks']

    # TODO: once the chars are fetched, first drop those with a large skew angle.
    all_pymu_chars = [
        char
        for block in text_blocks
        for line in block['lines']
        for pymu_span in line['spans']
        for char in pymu_span['chars']
    ]

    for span in text_spans:
        span['chars'] = []
    fill_char_in_spans(text_spans, all_pymu_chars)

    empty_spans = [span for span in text_spans if len(span['content']) == 0]
    if not empty_spans:
        return spans

    # Initialise the OCR model only when some span is still empty.
    atom_model_manager = AtomModelSingleton()
    ocr_model = atom_model_manager.get_atom_model(
        atom_model_name="ocr",
        ocr_show_log=False,
        det_db_box_thresh=0.3,
        lang=lang
    )

    for span in empty_spans:
        # The span goes back into the list only if OCR yields usable text.
        spans.remove(span)
        # Crop the span's bbox from the page.
        span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode="cv2")
        ocr_res = ocr_model.ocr(span_img, det=False)
        if len(ocr_res) > 0 and len(ocr_res[0]) > 0:
            ocr_text, ocr_score = ocr_res[0][0]
            if ocr_score > 0.5 and len(ocr_text) > 0:
                span['content'] = ocr_text
                spans.append(span)

    return spans
|
|
|
+
|
|
|
+
|
|
|
+def txt_spans_extract_v1(pdf_page, inline_equations, interline_equations):
|
|
|
text_raw_blocks = pdf_page.get_text('dict', flags=fitz.TEXTFLAGS_TEXT)['blocks']
|
|
|
char_level_text_blocks = pdf_page.get_text('rawdict', flags=fitz.TEXTFLAGS_TEXT)[
|
|
|
'blocks'
|
|
|
@@ -464,18 +621,16 @@ def remove_outside_spans(spans, all_bboxes, all_discarded_blocks):
|
|
|
|
|
|
|
|
|
def parse_page_core(
|
|
|
- page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
|
|
|
+ page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
|
|
|
):
|
|
|
need_drop = False
|
|
|
drop_reason = []
|
|
|
|
|
|
"""从magic_model对象中获取后面会用到的区块信息"""
|
|
|
- # img_blocks = magic_model.get_imgs(page_id)
|
|
|
- # table_blocks = magic_model.get_tables(page_id)
|
|
|
-
|
|
|
img_groups = magic_model.get_imgs_v2(page_id)
|
|
|
table_groups = magic_model.get_tables_v2(page_id)
|
|
|
|
|
|
+ """对image和table的区块分组"""
|
|
|
img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
|
|
|
img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
|
|
|
)
|
|
|
@@ -519,38 +674,20 @@ def parse_page_core(
|
|
|
page_h,
|
|
|
)
|
|
|
|
|
|
+ """获取所有的spans信息"""
|
|
|
spans = magic_model.get_all_spans(page_id)
|
|
|
|
|
|
- """根据parse_mode,构造spans"""
|
|
|
- if parse_mode == SupportedPdfParseMethod.TXT:
|
|
|
- """ocr 中文本类的 span 用 pymu spans 替换!"""
|
|
|
- pymu_spans = txt_spans_extract(page_doc, inline_equations, interline_equations)
|
|
|
- spans = replace_text_span(pymu_spans, spans)
|
|
|
- elif parse_mode == SupportedPdfParseMethod.OCR:
|
|
|
- pass
|
|
|
- else:
|
|
|
- raise Exception('parse_mode must be txt or ocr')
|
|
|
-
|
|
|
"""在删除重复span之前,应该通过image_body和table_body的block过滤一下image和table的span"""
|
|
|
"""顺便删除大水印并保留abandon的span"""
|
|
|
spans = remove_outside_spans(spans, all_bboxes, all_discarded_blocks)
|
|
|
|
|
|
- """删除重叠spans中置信度较低的那些"""
|
|
|
- spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
|
|
|
- """删除重叠spans中较小的那些"""
|
|
|
- spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
|
|
|
- """对image和table截图"""
|
|
|
- spans = ocr_cut_image_and_table(
|
|
|
- spans, page_doc, page_id, pdf_bytes_md5, imageWriter
|
|
|
- )
|
|
|
-
|
|
|
"""先处理不需要排版的discarded_blocks"""
|
|
|
discarded_block_with_spans, spans = fill_spans_in_blocks(
|
|
|
all_discarded_blocks, spans, 0.4
|
|
|
)
|
|
|
fix_discarded_blocks = fix_discarded_block(discarded_block_with_spans)
|
|
|
|
|
|
- """如果当前页面没有bbox则跳过"""
|
|
|
+ """如果当前页面没有有效的bbox则跳过"""
|
|
|
if len(all_bboxes) == 0:
|
|
|
logger.warning(f'skip this page, not found useful bbox, page_id: {page_id}')
|
|
|
return ocr_construct_page_component_v2(
|
|
|
@@ -568,7 +705,32 @@ def parse_page_core(
|
|
|
drop_reason,
|
|
|
)
|
|
|
|
|
|
- """将span填入blocks中"""
|
|
|
+ """删除重叠spans中置信度较低的那些"""
|
|
|
+ spans, dropped_spans_by_confidence = remove_overlaps_low_confidence_spans(spans)
|
|
|
+ """删除重叠spans中较小的那些"""
|
|
|
+ spans, dropped_spans_by_span_overlap = remove_overlaps_min_spans(spans)
|
|
|
+
|
|
|
+ """根据parse_mode,构造spans,主要是文本类的字符填充"""
|
|
|
+ if parse_mode == SupportedPdfParseMethod.TXT:
|
|
|
+
|
|
|
+ """之前的公式替换方案"""
|
|
|
+ # pymu_spans = txt_spans_extract_v1(page_doc, inline_equations, interline_equations)
|
|
|
+ # spans = replace_text_span(pymu_spans, spans)
|
|
|
+
|
|
|
+ """ocr 中文本类的 span 用 pymu spans 替换!"""
|
|
|
+ spans = txt_spans_extract_v2(page_doc, spans, all_bboxes, all_discarded_blocks, lang)
|
|
|
+
|
|
|
+ elif parse_mode == SupportedPdfParseMethod.OCR:
|
|
|
+ pass
|
|
|
+ else:
|
|
|
+ raise Exception('parse_mode must be txt or ocr')
|
|
|
+
|
|
|
+ """对image和table截图"""
|
|
|
+ spans = ocr_cut_image_and_table(
|
|
|
+ spans, page_doc, page_id, pdf_bytes_md5, imageWriter
|
|
|
+ )
|
|
|
+
|
|
|
+ """span填充进block"""
|
|
|
block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
|
|
|
|
|
|
"""对block进行fix操作"""
|
|
|
@@ -618,6 +780,7 @@ def pdf_parse_union(
|
|
|
start_page_id=0,
|
|
|
end_page_id=None,
|
|
|
debug_mode=False,
|
|
|
+ lang=None,
|
|
|
):
|
|
|
pdf_bytes_md5 = compute_md5(dataset.data_bits())
|
|
|
|
|
|
@@ -654,7 +817,7 @@ def pdf_parse_union(
|
|
|
"""解析pdf中的每一页"""
|
|
|
if start_page_id <= page_id <= end_page_id:
|
|
|
page_info = parse_page_core(
|
|
|
- page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
|
|
|
+ page, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode, lang
|
|
|
)
|
|
|
else:
|
|
|
page_info = page.get_page_info()
|