Ver Fonte

feat: Add .gitignore, implement grid recovery syntax verification, and enhance HuggingFace model loading with local cache prioritization.

zhch158_admin há 3 dias
pai
commit
76f8e864a8

+ 91 - 3
ocr_tools/universal_doc_parser/models/adapters/docling_layout_adapter.py

@@ -15,6 +15,7 @@
 import cv2
 import numpy as np
 import threading
+import os
 from pathlib import Path
 from typing import Dict, List, Union, Any, Optional
 from PIL import Image
@@ -127,9 +128,96 @@ class DoclingLayoutDetector(BaseLayoutDetector):
                 self._model_path = str(model_path)
                 print(f"📂 Loading model from local path: {self._model_path}")
             else:
-                # 从 HuggingFace 下载
-                print(f"📥 Downloading model from HuggingFace: {model_dir}")
-                self._model_path = snapshot_download(repo_id=model_dir)
+                # HuggingFace 仓库 ID,先检查本地缓存
+                # 获取 HuggingFace 缓存目录
+                hf_home = os.environ.get('HF_HOME', None)
+                if hf_home:
+                    cache_dir = Path(hf_home) / "hub"
+                else:
+                    cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
+                
+                # 将模型 ID 转换为缓存目录格式
+                # 例如: ds4sd/docling-layout-old -> models--ds4sd--docling-layout-old
+                repo_id_escaped = model_dir.replace("/", "--")
+                model_cache_dir = cache_dir / f"models--{repo_id_escaped}"
+                
+                # 先尝试从本地缓存加载(避免不必要的网络请求)
+                local_model_path = None
+                if model_cache_dir.exists() and model_cache_dir.is_dir():
+                    snapshots_dir = model_cache_dir / "snapshots"
+                    if snapshots_dir.exists():
+                        # 获取所有 snapshot 目录,按修改时间排序
+                        snapshots = sorted(
+                            [d for d in snapshots_dir.iterdir() if d.is_dir()],
+                            key=lambda x: x.stat().st_mtime,
+                            reverse=True
+                        )
+                        if snapshots:
+                            # 检查最新的 snapshot 是否完整
+                            latest_snapshot = snapshots[0]
+                            processor_config = latest_snapshot / "preprocessor_config.json"
+                            model_config = latest_snapshot / "config.json"
+                            safetensors_file = latest_snapshot / "model.safetensors"
+                            
+                            if processor_config.exists() and model_config.exists() and safetensors_file.exists():
+                                local_model_path = latest_snapshot
+                
+                if local_model_path:
+                    # 本地缓存存在且完整,直接使用(不进行网络请求)
+                    self._model_path = str(local_model_path)
+                    print(f"📂 Using local cached model: {self._model_path}")
+                    print(f"   (Skipping network check - model already cached)")
+                else:
+                    # 本地缓存不存在或不完整,尝试从 HuggingFace 下载或更新
+                    print(f"📥 Model not found in local cache, downloading from HuggingFace: {model_dir}")
+                    try:
+                        # snapshot_download 会自动检查本地缓存,如果存在且是最新的,不会重新下载
+                        # 只有在需要更新或首次下载时才会下载
+                        self._model_path = snapshot_download(repo_id=model_dir)
+                        print(f"✅ Model downloaded/updated: {self._model_path}")
+                    except Exception as e:
+                        # HuggingFace 访问失败,再次尝试查找本地缓存(可能之前检查时遗漏)
+                        print(f"⚠️ Failed to download from HuggingFace: {e}")
+                        print(f"🔍 Trying to find local cached model again...")
+                        
+                        if model_cache_dir.exists() and model_cache_dir.is_dir():
+                            snapshots_dir = model_cache_dir / "snapshots"
+                            if snapshots_dir.exists():
+                                snapshots = sorted(
+                                    [d for d in snapshots_dir.iterdir() if d.is_dir()],
+                                    key=lambda x: x.stat().st_mtime,
+                                    reverse=True
+                                )
+                                if snapshots:
+                                    local_model_path = snapshots[0]
+                                    processor_config = local_model_path / "preprocessor_config.json"
+                                    model_config = local_model_path / "config.json"
+                                    safetensors_file = local_model_path / "model.safetensors"
+                                    
+                                    if processor_config.exists() and model_config.exists() and safetensors_file.exists():
+                                        self._model_path = str(local_model_path)
+                                        print(f"✅ Found local cached model: {self._model_path}")
+                                    else:
+                                        raise FileNotFoundError(
+                                            f"Local cached model found but missing required files in {local_model_path}. "
+                                            f"Required: preprocessor_config.json, config.json, model.safetensors"
+                                        )
+                                else:
+                                    raise FileNotFoundError(
+                                        f"No snapshots found in {snapshots_dir}. "
+                                        f"Please download the model first or check your network connection."
+                                    )
+                            else:
+                                raise FileNotFoundError(
+                                    f"Cache directory exists but no snapshots found: {model_cache_dir}. "
+                                    f"Please download the model first or check your network connection."
+                                )
+                        else:
+                            raise FileNotFoundError(
+                                f"Model not found in local cache: {model_cache_dir}. "
+                                f"Please download the model first or check your network connection. "
+                                f"Original error: {e}"
+                            )
             
             # 检查必要文件
             processor_config = Path(self._model_path) / "preprocessor_config.json"