Przeglądaj źródła

Merge pull request #2754 from opendatalab/release-2.0.6

Release 2.0.6
Xiaomeng Zhao 5 miesięcy temu
rodzic
commit
48085d02c4

+ 2 - 5
.github/ISSUE_TEMPLATE/bug_report.yml

@@ -109,14 +109,11 @@ body:
   - type: dropdown
     id: software_version
     attributes:
-      label: Software version | 软件版本 (magic-pdf --version)
+      label: Software version | 软件版本 (mineru --version)
       #multiple: false
       options:
         -
-        - "1.0.x"
-        - "1.1.x"
-        - "1.2.x"
-        - "1.3.x"
+        - "2.0.x"
     validations:
       required: true
 

Plik diff jest za duży
+ 0 - 2
README.md


Plik diff jest za duży
+ 0 - 2
README_zh-CN.md


+ 26 - 19
mineru/backend/vlm/vlm_magic_model.py

@@ -1,6 +1,8 @@
 import re
 from typing import Literal
 
+from loguru import logger
+
 from mineru.utils.boxbase import bbox_distance, is_in
 from mineru.utils.enum_class import ContentType, BlockType, SplitFlag
 from mineru.backend.vlm.vlm_middle_json_mkcontent import merge_para_with_text
@@ -22,25 +24,30 @@ class MagicModel:
         # 解析每个块
         for index, block_info in enumerate(block_infos):
             block_bbox = block_info[0].strip()
-            x1, y1, x2, y2 = map(int, block_bbox.split())
-            x_1, y_1, x_2, y_2 = (
-                int(x1 * width / 1000),
-                int(y1 * height / 1000),
-                int(x2 * width / 1000),
-                int(y2 * height / 1000),
-            )
-            if x_2 < x_1:
-                x_1, x_2 = x_2, x_1
-            if y_2 < y_1:
-                y_1, y_2 = y_2, y_1
-            block_bbox = (x_1, y_1, x_2, y_2)
-            block_type = block_info[1].strip()
-            block_content = block_info[2].strip()
-
-            # print(f"坐标: {block_bbox}")
-            # print(f"类型: {block_type}")
-            # print(f"内容: {block_content}")
-            # print("-" * 50)
+            try:
+                x1, y1, x2, y2 = map(int, block_bbox.split())
+                x_1, y_1, x_2, y_2 = (
+                    int(x1 * width / 1000),
+                    int(y1 * height / 1000),
+                    int(x2 * width / 1000),
+                    int(y2 * height / 1000),
+                )
+                if x_2 < x_1:
+                    x_1, x_2 = x_2, x_1
+                if y_2 < y_1:
+                    y_1, y_2 = y_2, y_1
+                block_bbox = (x_1, y_1, x_2, y_2)
+                block_type = block_info[1].strip()
+                block_content = block_info[2].strip()
+
+                # print(f"坐标: {block_bbox}")
+                # print(f"类型: {block_type}")
+                # print(f"内容: {block_content}")
+                # print("-" * 50)
+            except Exception as e:
+                # 如果解析失败,可能是因为格式不正确,跳过这个块
+                logger.warning(f"Invalid block format: {block_info}, error: {e}")
+                continue
 
             span_type = "unknown"
             if block_type in [

+ 1 - 1
mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -58,7 +58,7 @@ class PytorchPaddleOCR(TextSystem):
 
         device = get_device()
         if device == 'cpu' and self.lang in ['ch', 'ch_server', 'japan', 'chinese_cht']:
-            logger.warning("The current device in use is CPU. To ensure the speed of parsing, the language is automatically switched to ch_lite.")
+            # logger.warning("The current device in use is CPU. To ensure the speed of parsing, the language is automatically switched to ch_lite.")
             self.lang = 'ch_lite'
 
         if self.lang in latin_lang:

+ 1 - 1
mineru/model/vlm_sglang_model/model.py

@@ -62,7 +62,7 @@ class Mineru2QwenForCausalLM(nn.Module):
 
         # load vision tower
         mm_vision_tower = self.config.mm_vision_tower
-        model_root_path = auto_download_and_get_model_root_path("/", "vlm")
+        model_root_path = auto_download_and_get_model_root_path(mm_vision_tower, "vlm")
         mm_vision_tower = f"{model_root_path}/{mm_vision_tower}"
 
         if "clip" in mm_vision_tower:

+ 31 - 1
mineru/utils/format_utils.py

@@ -132,6 +132,35 @@ def otsl_parse_texts(texts, tokens):
     r_idx = 0
     c_idx = 0
 
+    # Check and complete the matrix
+    if split_row_tokens:
+        max_cols = max(len(row) for row in split_row_tokens)
+
+        # Insert additional <ecel> to tags
+        for row_idx, row in enumerate(split_row_tokens):
+            while len(row) < max_cols:
+                row.append(OTSL_ECEL)
+
+        # Insert additional <ecel> to texts
+        new_texts = []
+        text_idx = 0
+
+        for row_idx, row in enumerate(split_row_tokens):
+            for col_idx, token in enumerate(row):
+                new_texts.append(token)
+                if text_idx < len(texts) and texts[text_idx] == token:
+                    text_idx += 1
+                    if (text_idx < len(texts) and 
+                        texts[text_idx] not in [OTSL_NL, OTSL_FCEL, OTSL_ECEL, OTSL_LCEL, OTSL_UCEL, OTSL_XCEL]):
+                        new_texts.append(texts[text_idx])
+                        text_idx += 1
+
+            new_texts.append(OTSL_NL)
+            if text_idx < len(texts) and texts[text_idx] == OTSL_NL:
+                text_idx += 1
+
+        texts = new_texts
+
     def count_right(tokens, c_idx, r_idx, which_tokens):
         span = 0
         c_idx_iter = c_idx
@@ -235,10 +264,11 @@ def export_to_html(table_data: TableData):
 
     body = ""
 
+    grid = table_data.grid
     for i in range(nrows):
         body += "<tr>"
         for j in range(ncols):
-            cell: TableCell = table_data.grid[i][j]
+            cell: TableCell = grid[i][j]
 
             rowspan, rowstart = (
                 cell.row_span,

+ 6 - 2
mineru/utils/models_download_utils.py

@@ -57,8 +57,12 @@ def auto_download_and_get_model_root_path(relative_path: str, repo_mode='pipelin
         relative_path = relative_path.strip('/')
         cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
     elif repo_mode == 'vlm':
-        # VLM 模式下,直接下载整个模型目录
-        cache_dir = snapshot_download(repo)
+        # VLM 模式下,根据 relative_path 的不同处理方式
+        if relative_path == "/":
+            cache_dir = snapshot_download(repo)
+        else:
+            relative_path = relative_path.strip('/')
+            cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
 
     if not cache_dir:
         raise FileNotFoundError(f"Failed to download model: {relative_path} from {repo}")

Niektóre pliki nie zostały wyświetlone z powodu dużej ilości zmienionych plików