Эх сурвалжийг харах

refactor: streamline model path handling and enhance file retrieval logic

myhloli 5 сар өмнө
parent
commit
7bb8f0e971

+ 0 - 22
mineru/backend/pipeline/config_reader.py

@@ -67,28 +67,6 @@ def parse_bucket_key(s3_full_path: str):
     return bucket, key
 
 
-def get_local_models_dir():
-    config = read_config()
-    models_dir = config.get('models-dir')
-    if models_dir is None:
-        logger.warning(f"'models-dir' not found in {CONFIG_FILE_NAME}, use '/tmp/models' as default")
-        return '/tmp/models'
-    else:
-        return models_dir
-
-
-def get_local_layoutreader_model_dir():
-    config = read_config()
-    layoutreader_model_dir = config.get('layoutreader-model-dir')
-    if layoutreader_model_dir is None or not os.path.exists(layoutreader_model_dir):
-        home_dir = os.path.expanduser('~')
-        layoutreader_at_modelscope_dir_path = os.path.join(home_dir, '.cache/modelscope/hub/ppaanngggg/layoutreader')
-        logger.warning(f"'layoutreader-model-dir' not exists, use {layoutreader_at_modelscope_dir_path} as default")
-        return layoutreader_at_modelscope_dir_path
-    else:
-        return layoutreader_model_dir
-
-
 def get_device():
     device_mode = os.getenv('MINERU_DEVICE_MODE', None)
     if device_mode is not None:

+ 5 - 7
mineru/backend/pipeline/model_init.py

@@ -9,10 +9,8 @@ from ...model.mfd.yolo_v8 import YOLOv8MFDModel
 from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.table.rapid_table import RapidTableModel
-
-doclayout_yolo = "Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt"
-yolo_v8_mfd =  "MFD/YOLO/yolo_v8_ft.pt"
-unimernet_small = "MFR/unimernet_hf_small_2503"
+from ...utils.enum_class import ModelPath
+from ...utils.models_download_utils import get_file_from_repos
 
 
 def table_model_init(lang=None):
@@ -150,14 +148,14 @@ class MineruPipelineModel:
             self.mfd_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFD,
                 mfd_weights=str(
-                    os.path.join(models_dir, yolo_v8_mfd)
+                    os.path.join(models_dir, get_file_from_repos(ModelPath.yolo_v8_mfd))
                 ),
                 device=self.device,
             )
 
             # 初始化公式解析模型
             mfr_weight_dir = str(
-                os.path.join(models_dir, unimernet_small)
+                os.path.join(models_dir, get_file_from_repos(ModelPath.unimernet_small))
             )
 
             self.mfr_model = atom_model_manager.get_atom_model(
@@ -170,7 +168,7 @@ class MineruPipelineModel:
         self.layout_model = atom_model_manager.get_atom_model(
             atom_model_name=AtomicModel.Layout,
             doclayout_yolo_weights=str(
-                os.path.join(models_dir, doclayout_yolo)
+                os.path.join(models_dir, get_file_from_repos(ModelPath.doclayout_yolo))
             ),
             device=self.device,
         )

+ 8 - 4
mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -9,7 +9,9 @@ import numpy as np
 import yaml
 from loguru import logger
 
-from mineru.backend.pipeline.config_reader import get_device, get_local_models_dir
+from mineru.backend.pipeline.config_reader import get_device
+from mineru.utils.enum_class import ModelPath
+from mineru.utils.models_download_utils import get_file_from_repos
 from ....utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
 from .tools.infer.predict_system import TextSystem
 from .tools.infer import pytorchocr_utility as utility
@@ -74,9 +76,11 @@ class PytorchPaddleOCR(TextSystem):
         with open(models_config_path) as file:
             config = yaml.safe_load(file)
             det, rec, dict_file = get_model_params(self.lang, config)
-        ocr_models_dir = os.path.join(get_local_models_dir(), 'OCR', 'paddleocr_torch')
-        kwargs['det_model_path'] = os.path.join(ocr_models_dir, det)
-        kwargs['rec_model_path'] = os.path.join(ocr_models_dir, rec)
+        ocr_models_dir = ModelPath.pytorch_paddle
+        det_model_path = get_file_from_repos(f"{ocr_models_dir}/{det}")
+        rec_model_path = get_file_from_repos(f"{ocr_models_dir}/{rec}")
+        kwargs['det_model_path'] = det_model_path
+        kwargs['rec_model_path'] = rec_model_path
         kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)
         # kwargs['rec_batch_num'] = 8
 

+ 4 - 3
mineru/utils/block_sort.py

@@ -7,8 +7,9 @@ from typing import List
 import torch
 from loguru import logger
 
-from mineru.backend.pipeline.config_reader import get_device, get_local_layoutreader_model_dir
-from mineru.utils.enum_class import BlockType
+from mineru.backend.pipeline.config_reader import get_device
+from mineru.utils.enum_class import BlockType, ModelPath
+from mineru.utils.models_download_utils import get_file_from_repos
 
 
 def sort_blocks_by_bbox(blocks, page_w, page_h, footnote_blocks):
@@ -187,7 +188,7 @@ def model_init(model_name: str):
     device = torch.device(device_name)
     if model_name == 'layoutreader':
         # 检测modelscope的缓存目录是否存在
-        layoutreader_model_dir = get_local_layoutreader_model_dir()
+        layoutreader_model_dir = get_file_from_repos(ModelPath.layout_reader)
         if os.path.exists(layoutreader_model_dir):
             model = LayoutLMv3ForTokenClassification.from_pretrained(
                 layoutreader_model_dir

+ 13 - 1
mineru/utils/enum_class.py

@@ -42,4 +42,16 @@ class CategoryId:
 class MakeMode:
     MM_MD = 'mm_markdown'
     NLP_MD = 'nlp_markdown'
-    STANDARD_FORMAT = 'standard_format'
+    STANDARD_FORMAT = 'standard_format'
+
+
+class ModelPath:
+    # Constants consumed by get_file_from_repos: repository ids plus the
+    # repo-relative paths of the individual model files/directories.
+
+    # Repository ids selected for repo_mode='pipeline' (ModelScope / Hugging Face).
+    pipeline_root_modelscope = "OpenDataLab/PDF-Extract-Kit-1.0"
+    pipeline_root_hf = "opendatalab/PDF-Extract-Kit-1.0"
+    # Paths below are passed as `relative_path` to get_file_from_repos.
+    # The two YOLO entries end in .pt (single weight files).
+    doclayout_yolo = "models/Layout/YOLO/doclayout_yolo_docstructbench_imgsz1280_2501.pt"
+    yolo_v8_mfd = "models/MFD/YOLO/yolo_v8_ft.pt"
+    # No file extension: presumably model directories - verify against the repo layout.
+    unimernet_small = "models/MFR/unimernet_hf_small_2503"
+    # Used as a prefix; the OCR det/rec file names are appended by the caller.
+    pytorch_paddle = "models/OCR/paddleocr_torch"
+    layout_reader = "models/ReadingOrder/layout_reader"
+    # Repository ids selected for repo_mode='vlm' (Hugging Face / ModelScope).
+    vlm_root_hf = "opendatalab/MinerU-VLM-1.0"
+    vlm_root_modelscope = "OpenDataLab/MinerU-VLM-1.0"

+ 153 - 0
mineru/utils/models_download_utils.py

@@ -0,0 +1,153 @@
+import os
+import hashlib
+import requests
+from typing import List, Union
+from huggingface_hub import hf_hub_download, model_info
+from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
+
+from mineru.utils.enum_class import ModelPath
+
+
+def _sha256sum(path, chunk_size=8192):
+    """Return the hexadecimal SHA-256 digest of the file at ``path``.
+
+    The file is consumed in ``chunk_size``-byte reads so that large model
+    weights never have to be held in memory at once.
+    """
+    digest = hashlib.sha256()
+    with open(path, "rb") as stream:
+        # iter() with a sentinel stops at the first empty read, i.e. at EOF.
+        for piece in iter(lambda: stream.read(chunk_size), b""):
+            digest.update(piece)
+    return digest.hexdigest()
+def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
+    """Resolve a repo-relative file or directory to a local path.
+
+    Hugging Face is tried first: the local cache is checked (and verified
+    against the LFS SHA-256 when one is published), then the file is
+    downloaded if needed. ModelScope is used as a fallback only for paths that
+    are not present in the Hugging Face file listing, e.g. because the listing
+    could not be fetched.
+
+    :param relative_path: file or directory path relative to the repo root
+    :param repo_mode: repository family to use, 'pipeline' or 'vlm'
+    :return: local path of the file, or the absolute path of the local
+        directory mirroring ``relative_path`` when a directory was requested
+    :raises ValueError: on an unsupported ``repo_mode`` or when a Hugging Face
+        download fails its SHA-256 check
+    :raises FileNotFoundError: when nothing could be obtained from either source
+
+    Errors raised by the Hugging Face download call itself are not caught
+    here and propagate to the caller.
+    """
+    model_source = os.getenv('MINERU_MODEL_SOURCE', None)
+
+    # Map each repo mode to its repository id per model source.
+    repo_mapping = {
+        'pipeline': {
+            'huggingface': ModelPath.pipeline_root_hf,
+            'modelscope': ModelPath.pipeline_root_modelscope,
+            'default': ModelPath.pipeline_root_hf
+        },
+        'vlm': {
+            'huggingface': ModelPath.vlm_root_hf,
+            'modelscope': ModelPath.vlm_root_modelscope,
+            'default': ModelPath.vlm_root_hf
+        }
+    }
+
+    if repo_mode not in repo_mapping:
+        raise ValueError(f"Unsupported repo_mode: {repo_mode}, must be 'pipeline' or 'vlm'")
+
+    # An unset or unrecognised MINERU_MODEL_SOURCE falls back to the default repo.
+    repo = repo_mapping[repo_mode].get(model_source, repo_mapping[repo_mode]['default'])
+
+    input_clean = relative_path.strip('/')
+    # Fetch the remote file listing (with LFS metadata) from Hugging Face.
+    try:
+        info = model_info(repo, files_metadata=True)
+        siblings_dict = {f.rfilename: f for f in info.siblings}
+    except Exception as e:
+        # Best effort: without a listing we fall through to the ModelScope path.
+        siblings_dict = {}
+        print(f"[Warn] 获取 Huggingface 仓库结构失败,错误: {e}")
+
+    # 1. Decide whether the input names a single file or a directory prefix.
+    if input_clean in siblings_dict:
+        is_file = True
+        all_paths = [input_clean]
+    else:
+        is_file = False
+        all_paths = [k for k in siblings_dict if k.startswith(input_clean + "/") and not k.endswith("/")]
+    # No match in the listing (or no listing): trust the input if it looks
+    # like a file name; a bare directory cannot be expanded without a listing.
+    if not all_paths:
+        is_file = os.path.splitext(input_clean)[1] != ""
+        all_paths = [input_clean] if is_file else []
+    cache_home = str(HUGGINGFACE_HUB_CACHE)
+
+    # (repo-relative path, local path) pairs, in resolution order.
+    resolved = []
+
+    # ---- Hugging Face ----
+    # hf_ok stays True only if every requested path is served by Hugging Face.
+    hf_ok = bool(all_paths)
+    for relpath in all_paths:
+        if relpath not in siblings_dict:
+            hf_ok = False
+            continue
+        meta = siblings_dict[relpath]
+        sha256 = meta.lfs.sha256 if meta.lfs else ""
+        file_path = None
+        try:
+            # Cache lookup only; never touches the network.
+            cached = hf_hub_download(repo_id=repo, filename=relpath, local_files_only=True)
+            if sha256 and os.path.exists(cached) and _sha256sum(cached) == sha256:
+                file_path = cached
+        except Exception as e:
+            print(f"[Info] Huggingface {relpath} 获取失败: {e}")
+        if file_path is None:
+            # Cache miss, unverifiable or corrupt copy: let the hub client fetch it.
+            file_path = hf_hub_download(repo_id=repo, filename=relpath, force_download=False)
+            print("file_path = ", file_path)
+            if sha256 and _sha256sum(file_path) != sha256:
+                raise ValueError(f"Huggingface下载后校验失败: {relpath}")
+        resolved.append((relpath, file_path))
+
+    # ---- ModelScope fallback ----
+    if not hf_ok:
+        if "/" in repo:
+            org_name, model_name = repo.split("/", 1)
+        else:
+            org_name, model_name = "modelscope", repo
+        for relpath in all_paths:
+            # Layout: <hf cache>/modelscope-fallback/<org>/<model>/<relative path>
+            target_dir = os.path.join(cache_home, "modelscope-fallback", org_name, model_name, os.path.dirname(relpath))
+            os.makedirs(target_dir, exist_ok=True)
+            local_path = os.path.join(target_dir, os.path.basename(relpath))
+            remote_len = 0
+            sha256 = ""
+            try:
+                get_meta_url = f"https://www.modelscope.cn/api/v1/models/{org_name}/{model_name}/repo/raw?Revision=master&FilePath={relpath}&Needmeta=true"
+                resp = requests.get(get_meta_url, timeout=15)
+                if resp.ok:
+                    meta_content = resp.json()["Data"]["MetaContent"]
+                    remote_len = meta_content["Size"]
+                    sha256 = meta_content["Sha256"]
+            except Exception as e:
+                print(f"[Info] modelscope {relpath} 获取失败: {e}")
+            # Reuse an existing copy only when remote metadata confirms its hash.
+            if remote_len > 0 and os.path.exists(local_path) and sha256 == _sha256sum(local_path):
+                resolved.append((relpath, local_path))
+                continue
+            try:
+                modelscope_url = f"https://www.modelscope.cn/api/v1/models/{org_name}/{model_name}/repo?Revision=master&FilePath={relpath}"
+                with requests.get(modelscope_url, stream=True, timeout=30) as resp:
+                    resp.raise_for_status()
+                    with open(local_path, 'wb') as f:
+                        for chunk in resp.iter_content(1024*1024):
+                            if chunk:
+                                f.write(chunk)
+                # Accept the download when the size is unknown or matches.
+                if remote_len == 0 or os.path.getsize(local_path) == remote_len:
+                    resolved.append((relpath, local_path))
+            except Exception as e:
+                print(f"[Error] ModelScope下载失败: {relpath} {e}")
+
+    if not resolved:
+        raise FileNotFoundError(f"{relative_path} 在 Huggingface 和 ModelScope 都未能获取")
+    first_rel, first_local = resolved[0]
+    if is_file:
+        return first_local
+    # Directory request: climb from the first resolved file up to the local
+    # directory that corresponds to the requested prefix, so that a file in a
+    # nested subdirectory still yields the requested directory itself.
+    local_dir = os.path.abspath(first_local)
+    nested = first_rel[len(input_clean) + 1:]
+    for _ in range(nested.count("/") + 1):
+        local_dir = os.path.dirname(local_dir)
+    return local_dir
+if __name__ == '__main__':
+    # Manual smoke test: resolve one single file and one directory, in that
+    # order, and print the local path obtained for each.
+    for sample in ("models/README.md", "models/OCR/paddleocr_torch/"):
+        print("本地文件绝对路径:", get_file_from_repos(sample))