Просмотр исходного кода

feat: Add .gitignore, implement grid recovery syntax verification, and enhance HuggingFace model loading with local cache prioritization.

zhch158_admin 3 дня назад
Родитель
Commit
76f8e864a8
1 изменённый файл с 91 добавлением и 3 удалениями
  1. 91 3
      ocr_tools/universal_doc_parser/models/adapters/docling_layout_adapter.py

+ 91 - 3
ocr_tools/universal_doc_parser/models/adapters/docling_layout_adapter.py

@@ -15,6 +15,7 @@
 import cv2
 import cv2
 import numpy as np
 import numpy as np
 import threading
 import threading
+import os
 from pathlib import Path
 from pathlib import Path
 from typing import Dict, List, Union, Any, Optional
 from typing import Dict, List, Union, Any, Optional
 from PIL import Image
 from PIL import Image
@@ -127,9 +128,96 @@ class DoclingLayoutDetector(BaseLayoutDetector):
                 self._model_path = str(model_path)
                 self._model_path = str(model_path)
                 print(f"📂 Loading model from local path: {self._model_path}")
                 print(f"📂 Loading model from local path: {self._model_path}")
             else:
             else:
-                # 从 HuggingFace 下载
-                print(f"📥 Downloading model from HuggingFace: {model_dir}")
-                self._model_path = snapshot_download(repo_id=model_dir)
+                # HuggingFace 仓库 ID,先检查本地缓存
+                # 获取 HuggingFace 缓存目录
+                hf_home = os.environ.get('HF_HOME', None)
+                if hf_home:
+                    cache_dir = Path(hf_home) / "hub"
+                else:
+                    cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
+                
+                # 将模型 ID 转换为缓存目录格式
+                # 例如: ds4sd/docling-layout-old -> models--ds4sd--docling-layout-old
+                repo_id_escaped = model_dir.replace("/", "--")
+                model_cache_dir = cache_dir / f"models--{repo_id_escaped}"
+                
+                # 先尝试从本地缓存加载(避免不必要的网络请求)
+                local_model_path = None
+                if model_cache_dir.exists() and model_cache_dir.is_dir():
+                    snapshots_dir = model_cache_dir / "snapshots"
+                    if snapshots_dir.exists():
+                        # 获取所有 snapshot 目录,按修改时间排序
+                        snapshots = sorted(
+                            [d for d in snapshots_dir.iterdir() if d.is_dir()],
+                            key=lambda x: x.stat().st_mtime,
+                            reverse=True
+                        )
+                        if snapshots:
+                            # 检查最新的 snapshot 是否完整
+                            latest_snapshot = snapshots[0]
+                            processor_config = latest_snapshot / "preprocessor_config.json"
+                            model_config = latest_snapshot / "config.json"
+                            safetensors_file = latest_snapshot / "model.safetensors"
+                            
+                            if processor_config.exists() and model_config.exists() and safetensors_file.exists():
+                                local_model_path = latest_snapshot
+                
+                if local_model_path:
+                    # 本地缓存存在且完整,直接使用(不进行网络请求)
+                    self._model_path = str(local_model_path)
+                    print(f"📂 Using local cached model: {self._model_path}")
+                    print(f"   (Skipping network check - model already cached)")
+                else:
+                    # 本地缓存不存在或不完整,尝试从 HuggingFace 下载或更新
+                    print(f"📥 Model not found in local cache, downloading from HuggingFace: {model_dir}")
+                    try:
+                        # snapshot_download 会自动检查本地缓存,如果存在且是最新的,不会重新下载
+                        # 只有在需要更新或首次下载时才会下载
+                        self._model_path = snapshot_download(repo_id=model_dir)
+                        print(f"✅ Model downloaded/updated: {self._model_path}")
+                    except Exception as e:
+                        # HuggingFace 访问失败,再次尝试查找本地缓存(可能之前检查时遗漏)
+                        print(f"⚠️ Failed to download from HuggingFace: {e}")
+                        print(f"🔍 Trying to find local cached model again...")
+                        
+                        if model_cache_dir.exists() and model_cache_dir.is_dir():
+                            snapshots_dir = model_cache_dir / "snapshots"
+                            if snapshots_dir.exists():
+                                snapshots = sorted(
+                                    [d for d in snapshots_dir.iterdir() if d.is_dir()],
+                                    key=lambda x: x.stat().st_mtime,
+                                    reverse=True
+                                )
+                                if snapshots:
+                                    local_model_path = snapshots[0]
+                                    processor_config = local_model_path / "preprocessor_config.json"
+                                    model_config = local_model_path / "config.json"
+                                    safetensors_file = local_model_path / "model.safetensors"
+                                    
+                                    if processor_config.exists() and model_config.exists() and safetensors_file.exists():
+                                        self._model_path = str(local_model_path)
+                                        print(f"✅ Found local cached model: {self._model_path}")
+                                    else:
+                                        raise FileNotFoundError(
+                                            f"Local cached model found but missing required files in {local_model_path}. "
+                                            f"Required: preprocessor_config.json, config.json, model.safetensors"
+                                        )
+                                else:
+                                    raise FileNotFoundError(
+                                        f"No snapshots found in {snapshots_dir}. "
+                                        f"Please download the model first or check your network connection."
+                                    )
+                            else:
+                                raise FileNotFoundError(
+                                    f"Cache directory exists but no snapshots found: {model_cache_dir}. "
+                                    f"Please download the model first or check your network connection."
+                                )
+                        else:
+                            raise FileNotFoundError(
+                                f"Model not found in local cache: {model_cache_dir}. "
+                                f"Please download the model first or check your network connection. "
+                                f"Original error: {e}"
+                            )
             
             
             # 检查必要文件
             # 检查必要文件
             processor_config = Path(self._model_path) / "preprocessor_config.json"
             processor_config = Path(self._model_path) / "preprocessor_config.json"