Bläddra i källkod

refactor: replace get_file_from_repos with auto_download_and_get_model_root_path in multiple files

myhloli 5 månader sedan
förälder
incheckning
284cec041a

+ 4 - 6
mineru/backend/pipeline/model_init.py

@@ -10,7 +10,7 @@ from ...model.mfr.unimernet.Unimernet import UnimernetModel
 from ...model.ocr.paddleocr2pytorch.pytorch_paddle import PytorchPaddleOCR
 from ...model.table.rapid_table import RapidTableModel
 from ...utils.enum_class import ModelPath
-from ...utils.models_download_utils import get_file_from_repos
+from ...utils.models_download_utils import auto_download_and_get_model_root_path
 
 
 def table_model_init(lang=None):
@@ -144,15 +144,13 @@ class MineruPipelineModel:
             self.mfd_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFD,
                 mfd_weights=str(
-                    get_file_from_repos(ModelPath.yolo_v8_mfd)
+                    os.path.join(auto_download_and_get_model_root_path(ModelPath.yolo_v8_mfd), ModelPath.yolo_v8_mfd)
                 ),
                 device=self.device,
             )
 
             # 初始化公式解析模型
-            mfr_weight_dir = str(
-                get_file_from_repos(ModelPath.unimernet_small)
-            )
+            mfr_weight_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.unimernet_small), ModelPath.unimernet_small)
 
             self.mfr_model = atom_model_manager.get_atom_model(
                 atom_model_name=AtomicModel.MFR,
@@ -164,7 +162,7 @@ class MineruPipelineModel:
         self.layout_model = atom_model_manager.get_atom_model(
             atom_model_name=AtomicModel.Layout,
             doclayout_yolo_weights=str(
-                get_file_from_repos(ModelPath.doclayout_yolo)
+                os.path.join(auto_download_and_get_model_root_path(ModelPath.doclayout_yolo), ModelPath.doclayout_yolo)
             ),
             device=self.device,
         )

+ 9 - 5
mineru/cli/common.py

@@ -16,7 +16,7 @@ from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc
 from mineru.data.data_reader_writer import FileBasedDataWriter
 from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox
 from mineru.utils.enum_class import MakeMode
-from mineru.utils.models_download_utils import get_file_from_repos
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
 
 pdf_suffixes = [".pdf"]
@@ -168,7 +168,7 @@ def do_parse(
             pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
             local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
             image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
-            model_path = get_file_from_repos('/','vlm')
+            model_path = auto_download_and_get_model_root_path('/', 'vlm')
             middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, model_path=model_path, server_url=server_url)
 
             pdf_info = middle_json["pdf_info"]
@@ -219,10 +219,14 @@ def do_parse(
 
 
 if __name__ == "__main__":
-    pdf_path = "../../demo/pdfs/demo2.pdf"
-    # pdf_path = "C:/Users/zhaoxiaomeng/Downloads/input_img_0.jpg"
+    # pdf_path = "../../demo/pdfs/demo3.pdf"
+    pdf_path = "C:/Users/zhaoxiaomeng/Downloads/4546d0e2-ba60-40a5-a17e-b68555cec741.pdf"
 
     try:
-       do_parse("./output", [Path(pdf_path).stem], [read_fn(Path(pdf_path))],["ch"], end_page_id=1, backend='vlm-huggingface')
+       do_parse("./output", [Path(pdf_path).stem], [read_fn(Path(pdf_path))],["ch"],
+                end_page_id=10,
+                backend='vlm-huggingface'
+                # backend = 'pipeline'
+                )
     except Exception as e:
         logger.exception(e)

+ 7 - 4
mineru/model/ocr/paddleocr2pytorch/pytorch_paddle.py

@@ -1,6 +1,6 @@
 # Copyright (c) Opendatalab. All rights reserved.
 import copy
-import os.path
+import os
 import warnings
 from pathlib import Path
 
@@ -11,7 +11,7 @@ from loguru import logger
 
 from mineru.utils.config_reader import get_device
 from mineru.utils.enum_class import ModelPath
-from mineru.utils.models_download_utils import get_file_from_repos
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 from ....utils.ocr_utils import check_img, preprocess_image, sorted_boxes, merge_det_boxes, update_det_boxes, get_rotate_crop_image
 from .tools.infer.predict_system import TextSystem
 from .tools.infer import pytorchocr_utility as utility
@@ -77,8 +77,11 @@ class PytorchPaddleOCR(TextSystem):
             config = yaml.safe_load(file)
             det, rec, dict_file = get_model_params(self.lang, config)
         ocr_models_dir = ModelPath.pytorch_paddle
-        det_model_path = get_file_from_repos(f"{ocr_models_dir}/{det}")
-        rec_model_path = get_file_from_repos(f"{ocr_models_dir}/{rec}")
+
+        det_model_path = f"{ocr_models_dir}/{det}"
+        det_model_path = os.path.join(auto_download_and_get_model_root_path(det_model_path), det_model_path)
+        rec_model_path = f"{ocr_models_dir}/{rec}"
+        rec_model_path = os.path.join(auto_download_and_get_model_root_path(rec_model_path), rec_model_path)
         kwargs['det_model_path'] = det_model_path
         kwargs['rec_model_path'] = rec_model_path
         kwargs['rec_char_dict_path'] = os.path.join(root_dir, 'pytorchocr', 'utils', 'resources', 'dict', dict_file)

+ 3 - 2
mineru/model/table/rapid_table.py

@@ -1,15 +1,16 @@
+import os
 import cv2
 import numpy as np
 from loguru import logger
 from rapid_table import RapidTable, RapidTableInput
 
 from mineru.utils.enum_class import ModelPath
-from mineru.utils.models_download_utils import get_file_from_repos
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 
 class RapidTableModel(object):
     def __init__(self, ocr_engine):
-        slanet_plus_model_path = get_file_from_repos(ModelPath.slanet_plus)
+        slanet_plus_model_path = os.path.join(auto_download_and_get_model_root_path(ModelPath.slanet_plus), ModelPath.slanet_plus)
         input_args = RapidTableInput(model_type='slanet_plus', model_path=slanet_plus_model_path)
         self.table_model = RapidTable(input_args)
         self.ocr_engine = ocr_engine

+ 2 - 2
mineru/utils/block_sort.py

@@ -9,7 +9,7 @@ from loguru import logger
 
 from mineru.utils.config_reader import get_device
 from mineru.utils.enum_class import BlockType, ModelPath
-from mineru.utils.models_download_utils import get_file_from_repos
+from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 
 
 def sort_blocks_by_bbox(blocks, page_w, page_h, footnote_blocks):
@@ -188,7 +188,7 @@ def model_init(model_name: str):
     device = torch.device(device_name)
     if model_name == 'layoutreader':
         # 检测modelscope的缓存目录是否存在
-        layoutreader_model_dir = get_file_from_repos(ModelPath.layout_reader)
+        layoutreader_model_dir = os.path.join(auto_download_and_get_model_root_path(ModelPath.layout_reader), ModelPath.layout_reader)
         if os.path.exists(layoutreader_model_dir):
             model = LayoutLMv3ForTokenClassification.from_pretrained(
                 layoutreader_model_dir

+ 7 - 7
mineru/utils/models_download_utils.py

@@ -5,7 +5,7 @@ from modelscope import snapshot_download as ms_snapshot_download
 from mineru.utils.config_reader import get_local_models_dir
 from mineru.utils.enum_class import ModelPath
 
-def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
+def auto_download_and_get_model_root_path(relative_path: str, repo_mode='pipeline') -> str:
     """
     支持文件或目录的可靠下载。
     - 如果输入文件: 返回本地文件绝对路径
@@ -14,7 +14,7 @@ def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
     :param relative_path: 文件或目录相对路径
     :return: 本地文件绝对路径或相对路径
     """
-    model_source = os.getenv('MINERU_MODEL_SOURCE', None)
+    model_source = os.getenv('MINERU_MODEL_SOURCE', "huggingface")
 
     if model_source == 'local':
         local_models_config = get_local_models_dir()
@@ -54,10 +54,10 @@ def get_file_from_repos(relative_path: str, repo_mode='pipeline') -> str:
     relative_path = relative_path.strip('/')
     cache_dir = snapshot_download(repo, allow_patterns=[relative_path, relative_path+"/*"])
 
-    return cache_dir + "/" + relative_path
+    return cache_dir
+
 
 if __name__ == '__main__':
-    path1 = get_file_from_repos("models/README.md")
-    print("本地文件绝对路径:", path1)
-    path2 = get_file_from_repos("models/OCR/paddleocr_torch/")
-    print("本地文件绝对路径:", path2)
+    path1 = "models/README.md"
+    root = auto_download_and_get_model_root_path(path1)
+    print("本地文件绝对路径:", os.path.join(root, path1))