@@ -15,6 +15,7 @@
 import cv2
 import numpy as np
 import threading
+import os
 from pathlib import Path
 from typing import Dict, List, Union, Any, Optional
 from PIL import Image
@@ -127,9 +128,96 @@ class DoclingLayoutDetector(BaseLayoutDetector):
             self._model_path = str(model_path)
             print(f"📂 Loading model from local path: {self._model_path}")
         else:
-            # Download from HuggingFace
-            print(f"📥 Downloading model from HuggingFace: {model_dir}")
-            self._model_path = snapshot_download(repo_id=model_dir)
+            # HuggingFace repo ID; check the local cache first
+            # Resolve the HuggingFace cache directory
+            hf_home = os.environ.get('HF_HOME', None)
+            if hf_home:
+                cache_dir = Path(hf_home) / "hub"
+            else:
+                cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
+
+            # Convert the model ID to the cache directory naming scheme
+            # e.g. ds4sd/docling-layout-old -> models--ds4sd--docling-layout-old
+            repo_id_escaped = model_dir.replace("/", "--")
+            model_cache_dir = cache_dir / f"models--{repo_id_escaped}"
+
+            # Try the local cache first (avoids an unnecessary network request)
+            local_model_path = None
+            if model_cache_dir.exists() and model_cache_dir.is_dir():
+                snapshots_dir = model_cache_dir / "snapshots"
+                if snapshots_dir.exists():
+                    # Collect all snapshot directories, newest first by modification time
+                    snapshots = sorted(
+                        [d for d in snapshots_dir.iterdir() if d.is_dir()],
+                        key=lambda x: x.stat().st_mtime,
+                        reverse=True
+                    )
+                    if snapshots:
+                        # Check whether the latest snapshot is complete
+                        latest_snapshot = snapshots[0]
+                        processor_config = latest_snapshot / "preprocessor_config.json"
+                        model_config = latest_snapshot / "config.json"
+                        safetensors_file = latest_snapshot / "model.safetensors"
+
+                        if processor_config.exists() and model_config.exists() and safetensors_file.exists():
+                            local_model_path = latest_snapshot
+
+            if local_model_path:
+                # Local cache exists and is complete; use it directly (no network request)
+                self._model_path = str(local_model_path)
+                print(f"📂 Using local cached model: {self._model_path}")
+                print(f"   (Skipping network check - model already cached)")
+            else:
+                # Local cache is missing or incomplete; try to download or update from HuggingFace
+                print(f"📥 Model not found in local cache, downloading from HuggingFace: {model_dir}")
+                try:
+                    # snapshot_download checks the local cache itself and does not re-download
+                    # when the cached copy is current; it only fetches on first use or when an update is needed
+                    self._model_path = snapshot_download(repo_id=model_dir)
+                    print(f"✅ Model downloaded/updated: {self._model_path}")
+                except Exception as e:
+                    # HuggingFace access failed; look for a local cached snapshot again (the earlier check may have missed it)
+                    print(f"⚠️ Failed to download from HuggingFace: {e}")
+                    print(f"🔍 Trying to find local cached model again...")
+
+                    if model_cache_dir.exists() and model_cache_dir.is_dir():
+                        snapshots_dir = model_cache_dir / "snapshots"
+                        if snapshots_dir.exists():
+                            snapshots = sorted(
+                                [d for d in snapshots_dir.iterdir() if d.is_dir()],
+                                key=lambda x: x.stat().st_mtime,
+                                reverse=True
+                            )
+                            if snapshots:
+                                local_model_path = snapshots[0]
+                                processor_config = local_model_path / "preprocessor_config.json"
+                                model_config = local_model_path / "config.json"
+                                safetensors_file = local_model_path / "model.safetensors"
+
+                                if processor_config.exists() and model_config.exists() and safetensors_file.exists():
+                                    self._model_path = str(local_model_path)
+                                    print(f"✅ Found local cached model: {self._model_path}")
+                                else:
+                                    raise FileNotFoundError(
+                                        f"Local cached model found but missing required files in {local_model_path}. "
+                                        f"Required: preprocessor_config.json, config.json, model.safetensors"
+                                    )
+                            else:
+                                raise FileNotFoundError(
+                                    f"No snapshots found in {snapshots_dir}. "
+                                    f"Please download the model first or check your network connection."
+                                )
+                        else:
+                            raise FileNotFoundError(
+                                f"Cache directory exists but no snapshots found: {model_cache_dir}. "
+                                f"Please download the model first or check your network connection."
+                            )
+                    else:
+                        raise FileNotFoundError(
+                            f"Model not found in local cache: {model_cache_dir}. "
+                            f"Please download the model first or check your network connection. "
+                            f"Original error: {e}"
+                        )
 
         # Check required files
         processor_config = Path(self._model_path) / "preprocessor_config.json"