
feat: remove model path parameter from vlm_doc_analyze and streamline model loading

myhloli, 5 months ago
parent commit 8737ebb2e2

demo/demo.py  +1 -2

@@ -114,8 +114,7 @@ def do_parse(
             pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
             local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
             image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
-            model_path = auto_download_and_get_model_root_path('/', 'vlm')
-            middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, model_path=model_path, server_url=server_url)
+            middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url)
 
             pdf_info = middle_json["pdf_info"]
 

mineru/backend/vlm/vlm_analyze.py  +8 -5

@@ -9,6 +9,7 @@ from .base_predictor import BasePredictor
 from .predictor import get_predictor
 from .token_to_middle_json import result_to_middle_json
 from ...utils.enum_class import ModelPath
+from ...utils.models_download_utils import auto_download_and_get_model_root_path
 
 
 class ModelSingleton:
@@ -28,6 +29,8 @@ class ModelSingleton:
     ) -> BasePredictor:
         key = (backend,)
         if key not in self._models:
+            if not model_path:
+                model_path = auto_download_and_get_model_root_path("/","vlm")
             self._models[key] = get_predictor(
                 backend=backend,
                 model_path=model_path,
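
Because the default is now resolved lazily inside ModelSingleton.get_model, an explicit model_path still takes precedence when the caller forwards one, and the auto-download is skipped. A sketch of pointing at a locally downloaded checkpoint (the directory is hypothetical):

    # hypothetical local checkpoint directory; when given, auto_download_and_get_model_root_path is not called
    middle_json, infer_result = vlm_doc_analyze(
        pdf_bytes,
        image_writer=image_writer,
        backend="transformers",
        model_path="/models/my-vlm-checkpoint",
    )

Note that the cache key is (backend,), so whichever model_path is seen first for a given backend fixes the predictor for the rest of the process.
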
@@ -41,7 +44,7 @@ def doc_analyze(
     image_writer: DataWriter | None,
     predictor: BasePredictor | None = None,
     backend="transformers",
-    model_path=ModelPath.vlm_root_hf,
+    model_path: str | None = None,
     server_url: str | None = None,
 ):
     if predictor is None:
@@ -53,10 +56,10 @@ def doc_analyze(
     # load_images_time = round(time.time() - load_images_start, 2)
     # logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
 
-    infer_start = time.time()
+    # infer_start = time.time()
     results = predictor.batch_predict(images=images_base64_list)
-    infer_time = round(time.time() - infer_start, 2)
-    logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
+    # infer_time = round(time.time() - infer_start, 2)
+    # logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
 
     middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
     return middle_json, results
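
The built-in inference timing log is commented out by this commit; a caller that still wants the number can time the call externally. A minimal sketch, assuming the loguru logger used by the removed log line (this measures the whole doc_analyze call, including image loading and middle-json conversion, not only batch_predict):

    import time
    from loguru import logger

    start = time.time()
    middle_json, results = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend="transformers")
    elapsed = round(time.time() - start, 2)
    logger.info(f"doc_analyze finished, cost: {elapsed}s for {len(results)} pages")
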
@@ -67,7 +70,7 @@ async def aio_doc_analyze(
     image_writer: DataWriter | None,
     predictor: BasePredictor | None = None,
     backend="transformers",
-    model_path=ModelPath.vlm_root_hf,
+    model_path: str | None = None,
     server_url: str | None = None,
 ):
     if predictor is None:
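
aio_doc_analyze gets the same optional model_path default. A minimal async sketch; the input file and output directory are illustrative:

    import asyncio

    from mineru.backend.vlm.vlm_analyze import aio_doc_analyze
    from mineru.data.data_reader_writer import FileBasedDataWriter

    async def main():
        pdf_bytes = open("sample.pdf", "rb").read()             # illustrative input file
        image_writer = FileBasedDataWriter("./output/images")   # illustrative output directory
        # model_path omitted: the model is downloaded/located automatically on first use
        middle_json, results = await aio_doc_analyze(
            pdf_bytes, image_writer=image_writer, backend="transformers"
        )

    asyncio.run(main())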

mineru/cli/common.py  +1 -3

@@ -16,7 +16,6 @@ from mineru.backend.pipeline.pipeline_analyze import doc_analyze as pipeline_doc
 from mineru.data.data_reader_writer import FileBasedDataWriter
 from mineru.utils.draw_bbox import draw_layout_bbox, draw_span_bbox
 from mineru.utils.enum_class import MakeMode
-from mineru.utils.models_download_utils import auto_download_and_get_model_root_path
 from mineru.utils.pdf_image_tools import images_bytes_to_pdf_bytes
 
 pdf_suffixes = [".pdf"]
@@ -173,8 +172,7 @@ def do_parse(
             pdf_bytes = convert_pdf_bytes_to_bytes_by_pypdfium2(pdf_bytes, start_page_id, end_page_id)
             local_image_dir, local_md_dir = prepare_env(output_dir, pdf_file_name, parse_method)
             image_writer, md_writer = FileBasedDataWriter(local_image_dir), FileBasedDataWriter(local_md_dir)
-            model_path = auto_download_and_get_model_root_path('/', 'vlm')
-            middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, model_path=model_path, server_url=server_url)
+            middle_json, infer_result = vlm_doc_analyze(pdf_bytes, image_writer=image_writer, backend=backend, server_url=server_url)
 
             pdf_info = middle_json["pdf_info"]