浏览代码

refactor: Optimize fill_char_in_spans using a spatial grid

Jmper 4 月之前
父节点
当前提交
1ee1550460
共有 1 个文件被更改,包括 17 次插入4 次删除
  1. 17 4
      mineru/utils/span_pre_proc.py

+ 17 - 4
mineru/utils/span_pre_proc.py

@@ -1,4 +1,5 @@
 # Copyright (c) Opendatalab. All rights reserved.
+import collections
 import re
 import statistics
 
@@ -187,7 +188,7 @@ def txt_spans_extract(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded
             span['chars'] = []
             new_spans.append(span)
 
-    need_ocr_spans = fill_char_in_spans(new_spans, page_all_chars)
+    need_ocr_spans = fill_char_in_spans(new_spans, page_all_chars, median_span_height)
 
     """对未填充的span进行ocr"""
     if len(need_ocr_spans) > 0:
@@ -208,14 +209,26 @@ def txt_spans_extract(pdf_page, spans, pil_img, scale, all_bboxes, all_discarded
     return spans
 
 
-def fill_char_in_spans(spans, all_chars):
-
+def fill_char_in_spans(spans, all_chars, median_span_height):
     # 简单从上到下排一下序
     spans = sorted(spans, key=lambda x: x['bbox'][1])
 
+    grid_size = median_span_height
+    grid = collections.defaultdict(list)
+    for i, span in enumerate(spans):
+        start_cell = int(span['bbox'][1] / grid_size)
+        end_cell = int(span['bbox'][3] / grid_size)
+        for cell_idx in range(start_cell, end_cell + 1):
+            grid[cell_idx].append(i)
+
     for char in all_chars:
+        char_center_y = (char['bbox'][1] + char['bbox'][3]) / 2
+        cell_idx = int(char_center_y / grid_size)
+
+        candidate_span_indices = grid.get(cell_idx, [])
 
-        for span in spans:
+        for span_idx in candidate_span_indices:
+            span = spans[span_idx]
             if calculate_char_in_span(char['bbox'], span['bbox'], char['char']):
                 span['chars'].append(char)
                 break