Эх сурвалжийг харах

refactor(ocr): improve image and table block handling

- Split image and table blocks into separate categories
- Add group_id to image and table blocks- Update block processing logic to handle new categories
- Modify layout splitting and span filling to accommodate new block types
- Adjust block indexing and sorting to consider new structures
myhloli 1 жил өмнө
parent
commit
c34c9d21ef

+ 10 - 10
magic_pdf/dict2md/ocr_mkcontent.py

@@ -70,17 +70,17 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                     para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
                 for block in para_block['blocks']:  # 2nd.拼image_caption
                     if block['type'] == BlockType.ImageCaption:
-                        para_text += merge_para_with_text(block)
+                        para_text += merge_para_with_text(block) + '  \n'
                 for block in para_block['blocks']:  # 2nd.拼image_caption
                     if block['type'] == BlockType.ImageFootnote:
-                        para_text += merge_para_with_text(block)
+                        para_text += merge_para_with_text(block) + '  \n'
         elif para_type == BlockType.Table:
             if mode == 'nlp':
                 continue
             elif mode == 'mm':
                 for block in para_block['blocks']:  # 1st.拼table_caption
                     if block['type'] == BlockType.TableCaption:
-                        para_text += merge_para_with_text(block)
+                        para_text += merge_para_with_text(block) + '  \n'
                 for block in para_block['blocks']:  # 2nd.拼table_body
                     if block['type'] == BlockType.TableBody:
                         for line in block['lines']:
@@ -95,7 +95,7 @@ def ocr_mk_markdown_with_para_core_v2(paras_of_layout,
                                         para_text += f"\n![]({join_path(img_buket_path, span['image_path'])})  \n"
                 for block in para_block['blocks']:  # 3rd.拼table_footnote
                     if block['type'] == BlockType.TableFootnote:
-                        para_text += merge_para_with_text(block)
+                        para_text += merge_para_with_text(block) + '  \n'
 
         if para_text.strip() == '':
             continue
@@ -180,18 +180,18 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
             'text_format': 'latex',
         }
     elif para_type == BlockType.Image:
-        para_content = {'type': 'image'}
+        para_content = {'type': 'image', 'img_caption': [], 'img_footnote': []}
         for block in para_block['blocks']:
             if block['type'] == BlockType.ImageBody:
                 para_content['img_path'] = join_path(
                     img_buket_path,
                     block['lines'][0]['spans'][0]['image_path'])
             if block['type'] == BlockType.ImageCaption:
-                para_content['img_caption'] = merge_para_with_text(block)
+                para_content['img_caption'].append(merge_para_with_text(block))
             if block['type'] == BlockType.ImageFootnote:
-                para_content['img_footnote'] = merge_para_with_text(block)
+                para_content['img_footnote'].append(merge_para_with_text(block))
     elif para_type == BlockType.Table:
-        para_content = {'type': 'table'}
+        para_content = {'type': 'table', 'table_caption': [], 'table_footnote': []}
         for block in para_block['blocks']:
             if block['type'] == BlockType.TableBody:
                 if block["lines"][0]["spans"][0].get('latex', ''):
@@ -200,9 +200,9 @@ def para_to_standard_format_v2(para_block, img_buket_path, page_idx, drop_reason
                     para_content['table_body'] = f"\n\n{block['lines'][0]['spans'][0]['html']}\n\n"
                 para_content['img_path'] = join_path(img_buket_path, block["lines"][0]["spans"][0]['image_path'])
             if block['type'] == BlockType.TableCaption:
-                para_content['table_caption'] = merge_para_with_text(block)
+                para_content['table_caption'].append(merge_para_with_text(block))
             if block['type'] == BlockType.TableFootnote:
-                para_content['table_footnote'] = merge_para_with_text(block)
+                para_content['table_footnote'].append(merge_para_with_text(block))
 
     para_content['page_idx'] = page_idx
 

+ 99 - 28
magic_pdf/pdf_parse_union_core_v2.py

@@ -1,3 +1,4 @@
+import copy
 import os
 import statistics
 import time
@@ -15,7 +16,7 @@ from magic_pdf.libs.convert_utils import dict_to_list
 from magic_pdf.libs.drop_reason import DropReason
 from magic_pdf.libs.hash_utils import compute_md5
 from magic_pdf.libs.local_math import float_equal
-from magic_pdf.libs.ocr_content_type import ContentType
+from magic_pdf.libs.ocr_content_type import ContentType, BlockType
 from magic_pdf.model.magic_model import MagicModel
 from magic_pdf.para.para_split_v3 import para_split
 from magic_pdf.pre_proc.citationmarker_remove import remove_citation_marker
@@ -29,7 +30,7 @@ from magic_pdf.pre_proc.ocr_detect_all_bboxes import \
     ocr_prepare_bboxes_for_layout_split_v2
 from magic_pdf.pre_proc.ocr_dict_merge import (fill_spans_in_blocks,
                                                fix_block_spans,
-                                               fix_discarded_block)
+                                               fix_discarded_block, fix_block_spans_v2)
 from magic_pdf.pre_proc.ocr_span_list_modify import (
     get_qa_need_list_v2, remove_overlaps_low_confidence_spans,
     remove_overlaps_min_spans)
@@ -173,19 +174,6 @@ def do_predict(boxes: List[List[int]], model) -> List[int]:
 
 def cal_block_index(fix_blocks, sorted_bboxes):
     for block in fix_blocks:
-        # if block['type'] in ['text', 'title', 'interline_equation']:
-        #     line_index_list = []
-        #     if len(block['lines']) == 0:
-        #         block['index'] = sorted_bboxes.index(block['bbox'])
-        #     else:
-        #         for line in block['lines']:
-        #             line['index'] = sorted_bboxes.index(line['bbox'])
-        #             line_index_list.append(line['index'])
-        #         median_value = statistics.median(line_index_list)
-        #         block['index'] = median_value
-        #
-        # elif block['type'] in ['table', 'image']:
-        #     block['index'] = sorted_bboxes.index(block['bbox'])
 
         line_index_list = []
         if len(block['lines']) == 0:
@@ -197,9 +185,11 @@ def cal_block_index(fix_blocks, sorted_bboxes):
             median_value = statistics.median(line_index_list)
             block['index'] = median_value
 
-        # 删除图表block中的虚拟line信息
-        if block['type'] in ['table', 'image']:
-            del block['lines']
+        # 删除图表body block中的虚拟line信息, 并用real_lines信息回填
+        if block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
+            block['virtual_lines'] = copy.deepcopy(block['lines'])
+            block['lines'] = copy.deepcopy(block['real_lines'])
+            del block['real_lines']
 
     return fix_blocks
 
@@ -250,7 +240,11 @@ def insert_lines_into_block(block_bbox, line_height, page_w, page_h):
 def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
     page_line_list = []
     for block in fix_blocks:
-        if block['type'] in ['text', 'title', 'interline_equation']:
+        if block['type'] in [
+            BlockType.Text, BlockType.Title, BlockType.InterlineEquation,
+            BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableCaption, BlockType.TableFootnote
+        ]:
             if len(block['lines']) == 0:
                 bbox = block['bbox']
                 lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
@@ -261,8 +255,9 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
                 for line in block['lines']:
                     bbox = line['bbox']
                     page_line_list.append(bbox)
-        elif block['type'] in ['table', 'image']:
+        elif block['type'] in [BlockType.ImageBody, BlockType.TableBody]:
             bbox = block['bbox']
+            block["real_lines"] = copy.deepcopy(block['lines'])
             lines = insert_lines_into_block(bbox, line_height, page_w, page_h)
             block['lines'] = []
             for line in lines:
@@ -316,7 +311,11 @@ def sort_lines_by_model(fix_blocks, page_w, page_h, line_height):
 def get_line_height(blocks):
     page_line_height_list = []
     for block in blocks:
-        if block['type'] in ['text', 'title', 'interline_equation']:
+        if block['type'] in [
+            BlockType.Text, BlockType.Title,
+            BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableCaption, BlockType.TableFootnote
+        ]:
             for line in block['lines']:
                 bbox = line['bbox']
                 page_line_height_list.append(int(bbox[3] - bbox[1]))
@@ -326,6 +325,63 @@ def get_line_height(blocks):
         return 10
 
 
+def process_groups(groups, body_key, caption_key, footnote_key):
+    body_blocks = []
+    caption_blocks = []
+    footnote_blocks = []
+    for i, group in enumerate(groups):
+        group[body_key]['group_id'] = i
+        body_blocks.append(group[body_key])
+        for caption_block in group[caption_key]:
+            caption_block['group_id'] = i
+            caption_blocks.append(caption_block)
+        for footnote_block in group[footnote_key]:
+            footnote_block['group_id'] = i
+            footnote_blocks.append(footnote_block)
+    return body_blocks, caption_blocks, footnote_blocks
+
+
+def process_block_list(blocks, body_type, block_type):
+    indices = [block['index'] for block in blocks]
+    median_index = statistics.median(indices)
+
+    body_bbox = next((block['bbox'] for block in blocks if block.get('type') == body_type), [])
+
+    return {
+        'type': block_type,
+        'bbox': body_bbox,
+        'blocks': blocks,
+        'index': median_index,
+    }
+
+
+def revert_group_blocks(blocks):
+    image_groups = {}
+    table_groups = {}
+    new_blocks = []
+    for block in blocks:
+        if block['type'] in [BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote]:
+            group_id = block['group_id']
+            if group_id not in image_groups:
+                image_groups[group_id] = []
+            image_groups[group_id].append(block)
+        elif block['type'] in [BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote]:
+            group_id = block['group_id']
+            if group_id not in table_groups:
+                table_groups[group_id] = []
+            table_groups[group_id].append(block)
+        else:
+            new_blocks.append(block)
+
+    for group_id, blocks in image_groups.items():
+        new_blocks.append(process_block_list(blocks, BlockType.ImageBody, BlockType.Image))
+
+    for group_id, blocks in table_groups.items():
+        new_blocks.append(process_block_list(blocks, BlockType.TableBody, BlockType.Table))
+
+    return new_blocks
+
+
 def parse_page_core(
     page_doc: PageableData, magic_model, page_id, pdf_bytes_md5, imageWriter, parse_mode
 ):
@@ -333,8 +389,20 @@ def parse_page_core(
     drop_reason = []
 
     """从magic_model对象中获取后面会用到的区块信息"""
-    img_blocks = magic_model.get_imgs(page_id)
-    table_blocks = magic_model.get_tables(page_id)
+    # img_blocks = magic_model.get_imgs(page_id)
+    # table_blocks = magic_model.get_tables(page_id)
+
+    img_groups = magic_model.get_imgs_v2(page_id)
+    table_groups = magic_model.get_tables_v2(page_id)
+
+    img_body_blocks, img_caption_blocks, img_footnote_blocks = process_groups(
+        img_groups, 'image_body', 'image_caption_list', 'image_footnote_list'
+    )
+
+    table_body_blocks, table_caption_blocks, table_footnote_blocks = process_groups(
+        table_groups, 'table_body', 'table_caption_list', 'table_footnote_list'
+    )
+
     discarded_blocks = magic_model.get_discarded(page_id)
     text_blocks = magic_model.get_text_blocks(page_id)
     title_blocks = magic_model.get_title_blocks(page_id)
@@ -370,8 +438,8 @@ def parse_page_core(
     interline_equation_blocks = []
     if len(interline_equation_blocks) > 0:
         all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
-            img_blocks,
-            table_blocks,
+            img_body_blocks, img_caption_blocks, img_footnote_blocks,
+            table_body_blocks, table_caption_blocks, table_footnote_blocks,
             discarded_blocks,
             text_blocks,
             title_blocks,
@@ -381,8 +449,8 @@ def parse_page_core(
         )
     else:
         all_bboxes, all_discarded_blocks = ocr_prepare_bboxes_for_layout_split_v2(
-            img_blocks,
-            table_blocks,
+            img_body_blocks, img_caption_blocks, img_footnote_blocks,
+            table_body_blocks, table_caption_blocks, table_footnote_blocks,
             discarded_blocks,
             text_blocks,
             title_blocks,
@@ -419,7 +487,7 @@ def parse_page_core(
     block_with_spans, spans = fill_spans_in_blocks(all_bboxes, spans, 0.5)
 
     """对block进行fix操作"""
-    fix_blocks = fix_block_spans(block_with_spans, img_blocks, table_blocks)
+    fix_blocks = fix_block_spans_v2(block_with_spans)
 
     """获取所有line并计算正文line的高度"""
     line_height = get_line_height(fix_blocks)
@@ -430,6 +498,9 @@ def parse_page_core(
     """根据line的中位数算block的序列关系"""
     fix_blocks = cal_block_index(fix_blocks, sorted_bboxes)
 
+    """将image和table的block还原回group形式参与后续流程"""
+    fix_blocks = revert_group_blocks(fix_blocks)
+
     """重排block"""
     sorted_blocks = sorted(fix_blocks, key=lambda b: b['index'])
 

+ 31 - 24
magic_pdf/pre_proc/ocr_detect_all_bboxes.py

@@ -60,29 +60,34 @@ def ocr_prepare_bboxes_for_layout_split(img_blocks, table_blocks, discarded_bloc
     return all_bboxes, all_discarded_blocks, drop_reasons
 
 
-def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_blocks, text_blocks,
-                                        title_blocks, interline_equation_blocks, page_w, page_h):
+def add_bboxes(blocks, block_type, bboxes):
+    for block in blocks:
+        x0, y0, x1, y1 = block['bbox']
+        if block_type in [
+            BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
+        ]:
+            bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"], block["group_id"]])
+        else:
+            bboxes.append([x0, y0, x1, y1, None, None, None, block_type, None, None, None, None, block["score"]])
+
+
+def ocr_prepare_bboxes_for_layout_split_v2(
+        img_body_blocks, img_caption_blocks, img_footnote_blocks,
+        table_body_blocks, table_caption_blocks, table_footnote_blocks,
+        discarded_blocks, text_blocks, title_blocks, interline_equation_blocks, page_w, page_h
+):
     all_bboxes = []
-    all_discarded_blocks = []
-    for image in img_blocks:
-        x0, y0, x1, y1 = image['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Image, None, None, None, None, image["score"]])
 
-    for table in table_blocks:
-        x0, y0, x1, y1 = table['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Table, None, None, None, None, table["score"]])
-
-    for text in text_blocks:
-        x0, y0, x1, y1 = text['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Text, None, None, None, None, text["score"]])
-
-    for title in title_blocks:
-        x0, y0, x1, y1 = title['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.Title, None, None, None, None, title["score"]])
-
-    for interline_equation in interline_equation_blocks:
-        x0, y0, x1, y1 = interline_equation['bbox']
-        all_bboxes.append([x0, y0, x1, y1, None, None, None, BlockType.InterlineEquation, None, None, None, None, interline_equation["score"]])
+    add_bboxes(img_body_blocks, BlockType.ImageBody, all_bboxes)
+    add_bboxes(img_caption_blocks, BlockType.ImageCaption, all_bboxes)
+    add_bboxes(img_footnote_blocks, BlockType.ImageFootnote, all_bboxes)
+    add_bboxes(table_body_blocks, BlockType.TableBody, all_bboxes)
+    add_bboxes(table_caption_blocks, BlockType.TableCaption, all_bboxes)
+    add_bboxes(table_footnote_blocks, BlockType.TableFootnote, all_bboxes)
+    add_bboxes(text_blocks, BlockType.Text, all_bboxes)
+    add_bboxes(title_blocks, BlockType.Title, all_bboxes)
+    add_bboxes(interline_equation_blocks, BlockType.InterlineEquation, all_bboxes)
 
     '''block嵌套问题解决'''
     '''文本框与标题框重叠,优先信任文本框'''
@@ -96,12 +101,14 @@ def ocr_prepare_bboxes_for_layout_split_v2(img_blocks, table_blocks, discarded_b
     '''interline_equation框被包含在文本类型框内,且interline_equation比文本区块小很多时信任文本框,这时需要舍弃公式框'''
     # 通过后续大框套小框逻辑删除
 
-    '''discarded_blocks中只保留宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的(限定footnote)'''
+    '''discarded_blocks'''
+    all_discarded_blocks = []
+    add_bboxes(discarded_blocks, BlockType.Discarded, all_discarded_blocks)
+
+    '''footnote识别:宽度超过1/3页面宽度的,高度超过10的,处于页面下半50%区域的'''
     footnote_blocks = []
     for discarded in discarded_blocks:
         x0, y0, x1, y1 = discarded['bbox']
-        all_discarded_blocks.append([x0, y0, x1, y1, None, None, None, BlockType.Discarded, None, None, None, None, discarded["score"]])
-        # 将footnote加入到all_bboxes中,用来计算layout
         if (x1 - x0) > (page_w / 3) and (y1 - y0) > 10 and y0 > (page_h / 2):
             footnote_blocks.append([x0, y0, x1, y1])
 

+ 26 - 0
magic_pdf/pre_proc/ocr_dict_merge.py

@@ -153,6 +153,11 @@ def fill_spans_in_blocks(blocks, spans, radio):
             'type': block_type,
             'bbox': block_bbox,
         }
+        if block_type in [
+            BlockType.ImageBody, BlockType.ImageCaption, BlockType.ImageFootnote,
+            BlockType.TableBody, BlockType.TableCaption, BlockType.TableFootnote
+        ]:
+            block_dict["group_id"] = block[-1]
         block_spans = []
         for span in spans:
             span_bbox = span['bbox']
@@ -201,6 +206,27 @@ def fix_block_spans(block_with_spans, img_blocks, table_blocks):
     return fix_blocks
 
 
+def fix_block_spans_v2(block_with_spans):
+    """1、img_block和table_block因为包含caption和footnote的关系,存在block的嵌套关系
+    需要将caption和footnote的text_span放入相应img_block和table_block内的
+    caption_block和footnote_block中 2、同时需要删除block中的spans字段."""
+    fix_blocks = []
+    for block in block_with_spans:
+        block_type = block['type']
+
+        if block_type in [BlockType.Text, BlockType.Title,
+                          BlockType.ImageCaption, BlockType.ImageFootnote,
+                          BlockType.TableCaption, BlockType.TableFootnote
+                          ]:
+            block = fix_text_block(block)
+        elif block_type in [BlockType.InterlineEquation, BlockType.ImageBody, BlockType.TableBody]:
+            block = fix_interline_block(block)
+        else:
+            continue
+        fix_blocks.append(block)
+    return fix_blocks
+
+
 def fix_discarded_block(discarded_block_with_spans):
     fix_discarded_blocks = []
     for block in discarded_block_with_spans: