|
|
@@ -9,6 +9,7 @@ from .base_predictor import BasePredictor
|
|
|
from .predictor import get_predictor
|
|
|
from .token_to_middle_json import result_to_middle_json
|
|
|
from ...utils.enum_class import ModelPath
|
|
|
+from ...utils.models_download_utils import auto_download_and_get_model_root_path
|
|
|
|
|
|
|
|
|
class ModelSingleton:
|
|
|
@@ -28,6 +29,8 @@ class ModelSingleton:
|
|
|
) -> BasePredictor:
|
|
|
key = (backend,)
|
|
|
if key not in self._models:
|
|
|
+ if not model_path:
|
|
|
+ model_path = auto_download_and_get_model_root_path("/","vlm")
|
|
|
self._models[key] = get_predictor(
|
|
|
backend=backend,
|
|
|
model_path=model_path,
|
|
|
@@ -41,7 +44,7 @@ def doc_analyze(
|
|
|
image_writer: DataWriter | None,
|
|
|
predictor: BasePredictor | None = None,
|
|
|
backend="transformers",
|
|
|
- model_path=ModelPath.vlm_root_hf,
|
|
|
+ model_path: str | None = None,
|
|
|
server_url: str | None = None,
|
|
|
):
|
|
|
if predictor is None:
|
|
|
@@ -53,10 +56,10 @@ def doc_analyze(
|
|
|
# load_images_time = round(time.time() - load_images_start, 2)
|
|
|
# logger.info(f"load images cost: {load_images_time}, speed: {round(len(images_base64_list)/load_images_time, 3)} images/s")
|
|
|
|
|
|
- infer_start = time.time()
|
|
|
+ # infer_start = time.time()
|
|
|
results = predictor.batch_predict(images=images_base64_list)
|
|
|
- infer_time = round(time.time() - infer_start, 2)
|
|
|
- logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
|
|
|
+ # infer_time = round(time.time() - infer_start, 2)
|
|
|
+ # logger.info(f"infer finished, cost: {infer_time}, speed: {round(len(results)/infer_time, 3)} page/s")
|
|
|
|
|
|
middle_json = result_to_middle_json(results, images_list, pdf_doc, image_writer)
|
|
|
return middle_json, results
|
|
|
@@ -67,7 +70,7 @@ async def aio_doc_analyze(
|
|
|
image_writer: DataWriter | None,
|
|
|
predictor: BasePredictor | None = None,
|
|
|
backend="transformers",
|
|
|
- model_path=ModelPath.vlm_root_hf,
|
|
|
+ model_path: str | None = None,
|
|
|
server_url: str | None = None,
|
|
|
):
|
|
|
if predictor is None:
|