ソースを参照

fix(pdf_parse): handle blocks without lines and enable bf16 on compatible devices

Blocks without lines are now correctly indexed even when they contain textual content rendered
as images. The sorting logic has been updated to accommodate this scenario. Additionally, the
LayoutLMv3 model initialization has been enhanced to utilize bfloat16 precision on devices that
support it, offering potential performance benefits on supported hardware.
myhloli 1 年前
コミット
2145a8b6d2
1 ファイル変更、48 行追加、19 行削除
  1. 48 19
      magic_pdf/pdf_parse_union_core_v2.py

+ 48 - 19
magic_pdf/pdf_parse_union_core_v2.py

@@ -94,16 +94,27 @@ def replace_text_span(pymu_spans, ocr_spans):
     return list(filter(lambda x: x["type"] != ContentType.Text, ocr_spans)) + pymu_spans
 
 
-def model_init(model_name: str):
+def model_init(model_name: str, local_path=None):
     from transformers import LayoutLMv3ForTokenClassification
-    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+        if torch.cuda.is_bf16_supported():
+            supports_bfloat16 = True
+        else:
+            supports_bfloat16 = False
+    else:
+        device = torch.device("cpu")
+        supports_bfloat16 = False
+
     if model_name == "layoutreader":
-        model = (
-            LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
-            # .bfloat16()
-            .to(device)
-            .eval()
-        )
+        if local_path:
+            model = LayoutLMv3ForTokenClassification.from_pretrained(local_path)
+        else:
+            model = LayoutLMv3ForTokenClassification.from_pretrained("hantian/layoutreader")
+        # 检查设备是否支持 bfloat16
+        if supports_bfloat16:
+            model.bfloat16()
+        model.to(device).eval()
     else:
         logger.error("model name not allow")
         exit(1)
@@ -119,9 +130,12 @@ class ModelSingleton:
             cls._instance = super().__new__(cls)
         return cls._instance
 
-    def get_model(self, model_name: str):
+    def get_model(self, model_name: str, local_path=None):
         if model_name not in self._models:
-            self._models[model_name] = model_init(model_name=model_name)
+            if local_path:
+                self._models[model_name] = model_init(model_name=model_name, local_path=local_path)
+            else:
+                self._models[model_name] = model_init(model_name=model_name)
         return self._models[model_name]
 
 
@@ -134,13 +148,11 @@ def do_predict(boxes: List[List[int]], model) -> List[int]:
 
 
 def cal_block_index(fix_blocks, sorted_bboxes):
-    block_without_lines = []
     for block in fix_blocks:
         if block['type'] in ['text', 'title', 'interline_equation']:
             line_index_list = []
             if len(block['lines']) == 0:
-                block_without_lines.append(block)
-                continue
+                block['index'] = sorted_bboxes.index(block['bbox'])
             else:
                 for line in block['lines']:
                     line['index'] = sorted_bboxes.index(line['bbox'])
@@ -151,10 +163,6 @@ def cal_block_index(fix_blocks, sorted_bboxes):
         elif block['type'] in ['table', 'image']:
             block['index'] = sorted_bboxes.index(block['bbox'])
 
-    '''移除没有line的block'''
-    for block in block_without_lines:
-        fix_blocks.remove(block)
-
     return fix_blocks
 
 
@@ -162,9 +170,13 @@ def sort_lines_by_model(fix_blocks, page_w, page_h):
     page_line_list = []
     for block in fix_blocks:
         if block['type'] in ['text', 'title', 'interline_equation']:
-            for line in block['lines']:
-                bbox = line['bbox']
+            if len(block['lines']) == 0:  # 没有line的block(一般是图片形式的文本块),就直接用block的bbox来排序
+                bbox = block['bbox']
                 page_line_list.append(bbox)
+            else:
+                for line in block['lines']:
+                    bbox = line['bbox']
+                    page_line_list.append(bbox)
         elif block['type'] in ['table', 'image']:  # 简单的把表和图都当成一个line处理
             bbox = block['bbox']
             page_line_list.append(bbox)
@@ -175,6 +187,23 @@ def sort_lines_by_model(fix_blocks, page_w, page_h):
     boxes = []
     # logger.info(f"Scale: {x_scale}, {y_scale}, Boxes len: {len(page_line_list)}")
     for left, top, right, bottom in page_line_list:
+        if left < 0:
+            logger.warning(
+                f"left < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+            left = 0
+        if right > page_w:
+            logger.warning(
+                f"right > page_w, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+            right = page_w
+        if top < 0:
+            logger.warning(
+                f"top < 0, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+            top = 0
+        if bottom > page_h:
+            logger.warning(
+                f"bottom > page_h, left: {left}, right: {right}, top: {top}, bottom: {bottom}, page_w: {page_w}, page_h: {page_h}")
+            bottom = page_h
+
         left = round(left * x_scale)
         top = round(top * y_scale)
         right = round(right * x_scale)