Bläddra i källkod

feat(PPDocLayoutV3): 添加PP-DocLayoutV3布局适配器,支持基于HuggingFace的模型加载与处理

zhch158_admin 1 vecka sedan
förälder
incheckning
d6a8cc7e15

+ 5 - 0
ocr_tools/universal_doc_parser/models/adapters/__init__.py

@@ -15,6 +15,7 @@ from .paddle_layout_detector import PaddleLayoutDetector
 from .paddle_vl_adapter import PaddleVLRecognizer
 
 from .docling_layout_adapter import DoclingLayoutDetector
+from .pp_doclayout_v3_layout_adapter import PPDocLayoutV3Detector
 
 # 可选导入 DiT 适配器
 try:
@@ -50,6 +51,8 @@ __all__ = [
     
     # Docling 适配器
     'DoclingLayoutDetector',
+    # PP-DocLayoutV3 适配器
+    'PPDocLayoutV3Detector',
 ]
 
 # 如果 DiT 可用,添加到导出列表
@@ -88,6 +91,8 @@ def get_layout_detector(config: dict):
         return MinerULayoutDetector(config)
     elif module == 'docling':
         return DoclingLayoutDetector(config)
+    elif module == 'pp_doclayout_v3':
+        return PPDocLayoutV3Detector(config)
     elif module == 'dit':
         if not DIT_AVAILABLE:
             raise ImportError("DiT adapter not available. Please ensure detectron2 and ditod are installed.")

+ 472 - 0
ocr_tools/universal_doc_parser/models/adapters/pp_doclayout_v3_layout_adapter.py

@@ -0,0 +1,472 @@
+"""PP-DocLayoutV3 Layout 适配器 (符合 BaseLayoutDetector 规范)
+
+基于 HuggingFace transformers 加载 PaddlePaddle PP-DocLayoutV3 布局模型,与 Docling 适配器一致:
+使用 AutoProcessor + AutoModelForObjectDetection,由 transformers 的 image_processor 做预处理与
+post_process_object_detection 做后处理,不自行实现推理管线。
+
+- Hugging Face: https://huggingface.co/PaddlePaddle/PP-DocLayoutV3_safetensors
+- 本地缓存: HF_HOME/hub/models--PaddlePaddle--PP-DocLayoutV3_safetensors/snapshots/<hash>
+
+依赖: transformers 需包含 pp_doclayout_v3(主分支已支持)。若报 KeyError/ValueError,
+请执行: pip install -U transformers 或 pip install git+https://github.com/huggingface/transformers.git
+"""
+
+import cv2
+import numpy as np
+import threading
+import os
+from pathlib import Path
+from typing import Dict, List, Union, Any, Optional
+from PIL import Image
+
+try:
+    from .base import BaseLayoutDetector
+except ImportError:
+    from base import BaseLayoutDetector
+
+_model_init_lock = threading.Lock()
+
+
+class PPDocLayoutV3Detector(BaseLayoutDetector):
+    """PP-DocLayoutV3 layout-detection adapter (layout-analysis module of PaddleOCR-VL-1.5).
+
+    Loads the model through HuggingFace transformers (``AutoProcessor`` +
+    ``AutoModelForObjectDetection``) and delegates pre-processing and
+    ``post_process_object_detection`` to the transformers image processor,
+    mirroring the Docling adapter — no hand-rolled inference pipeline.
+    """
+
+    # Official id2label of PP-DocLayoutV3 (from the repo's config.json; 25 classes, no id offset).
+    # NOTE(review): ids 8/9 ("footer"), 12/13 ("header"), 5/15 ("formula") and 22/23 ("text")
+    # map to identical names — looks like a transcription slip; confirm against the
+    # id2label in the HF repo's config.json.
+    PPDOCLAYOUT_V3_LABELS = {
+        0: "abstract",
+        1: "algorithm",
+        2: "aside_text",
+        3: "chart",
+        4: "content",
+        5: "formula",
+        6: "doc_title",
+        7: "figure_title",
+        8: "footer",
+        9: "footer",
+        10: "footnote",
+        11: "formula_number",
+        12: "header",
+        13: "header",
+        14: "image",
+        15: "formula",
+        16: "number",
+        17: "paragraph_title",
+        18: "reference",
+        19: "reference_content",
+        20: "seal",
+        21: "table",
+        22: "text",
+        23: "text",
+        24: "vision_footnote",
+    }
+
+    # PP-DocLayoutV3 label name -> MinerU-style category used by the rest of the pipeline.
+    CATEGORY_MAP = {
+        "abstract": "text",
+        "algorithm": "text",
+        "aside_text": "text",
+        "chart": "image_body",
+        "content": "text",
+        "formula": "interline_equation",
+        "doc_title": "title",
+        "figure_title": "image_caption",
+        "footer": "footer",
+        "footnote": "page_footnote",
+        "formula_number": "interline_equation",
+        "header": "header",
+        "image": "image_body",
+        "number": "text",
+        "paragraph_title": "title",
+        "reference": "text",
+        "reference_content": "text",
+        "seal": "image_body",
+        "table": "table_body",
+        "text": "text",
+        "vision_footnote": "page_footnote",
+    }
+
+    def __init__(self, config: Dict[str, Any]):
+        """Store the configuration; heavyweight model loading is deferred to initialize()."""
+        super().__init__(config)
+        self.model = None            # transformers detection model, set in initialize()
+        self.image_processor = None  # AutoProcessor doing pre-/post-processing
+        self._device = None          # torch.device resolved from config["device"]
+        self._threshold = 0.3        # confidence threshold (config key "conf")
+        self._num_threads = 4        # torch CPU thread count (config key "num_threads")
+        self._model_path = None      # resolved local directory holding the weights
+
+    def initialize(self):
+        """Load model and processor (Docling-style: AutoProcessor + AutoModelForObjectDetection).
+
+        Weight-resolution order:
+          1. ``config["model_dir"]`` when it is an existing local directory;
+          2. the newest complete snapshot in the HF hub cache
+             (``$HF_HOME/hub`` or ``~/.cache/huggingface/hub``);
+          3. ``snapshot_download`` from the Hugging Face Hub.
+
+        Raises:
+            RuntimeError: when the installed transformers does not support pp_doclayout_v3.
+            FileNotFoundError: when the resolved model directory lacks required files.
+            ImportError: when torch/transformers/huggingface_hub are missing.
+        """
+        try:
+            import torch
+            from transformers import AutoProcessor, AutoModelForObjectDetection
+            from huggingface_hub import snapshot_download
+
+            model_dir = self.config.get("model_dir", "PaddlePaddle/PP-DocLayoutV3_safetensors")
+            device = self.config.get("device", "cpu")
+            self._threshold = self.config.get("conf", 0.3)
+            self._num_threads = self.config.get("num_threads", 4)
+
+            self._device = torch.device(device)
+            if device == "cpu":
+                # Bound CPU parallelism only when running on CPU.
+                torch.set_num_threads(self._num_threads)
+
+            # Case 1: model_dir is an existing local directory — use it directly.
+            model_path = Path(model_dir)
+            if model_path.exists() and model_path.is_dir():
+                self._model_path = str(model_path)
+                print(f"📂 Loading PP-DocLayoutV3 from local path: {self._model_path}")
+            else:
+                # Case 2: probe the HF hub cache before hitting the network.
+                hf_home = os.environ.get("HF_HOME", None)
+                if hf_home:
+                    cache_dir = Path(hf_home) / "hub"
+                else:
+                    cache_dir = Path.home() / ".cache" / "huggingface" / "hub"
+
+                # HF hub cache layout: models--<org>--<name>/snapshots/<hash>
+                repo_id_escaped = model_dir.replace("/", "--")
+                model_cache_dir = cache_dir / f"models--{repo_id_escaped}"
+                local_model_path = None
+
+                if model_cache_dir.exists() and model_cache_dir.is_dir():
+                    snapshots_dir = model_cache_dir / "snapshots"
+                    if snapshots_dir.exists():
+                        # Prefer the most recently modified snapshot.
+                        snapshots = sorted(
+                            [d for d in snapshots_dir.iterdir() if d.is_dir()],
+                            key=lambda x: x.stat().st_mtime,
+                            reverse=True,
+                        )
+                        if snapshots:
+                            latest = snapshots[0]
+                            # Accept the snapshot only if all required files are present.
+                            if (latest / "preprocessor_config.json").exists() and (
+                                latest / "config.json"
+                            ).exists() and (latest / "model.safetensors").exists():
+                                local_model_path = latest
+
+                if local_model_path:
+                    self._model_path = str(local_model_path)
+                    print(f"📂 Using local cached PP-DocLayoutV3: {self._model_path}")
+                else:
+                    # Case 3: download (or refresh) the snapshot from the Hub.
+                    print(f"📥 Downloading PP-DocLayoutV3 from HuggingFace: {model_dir}")
+                    self._model_path = snapshot_download(repo_id=model_dir)
+                    print(f"✅ PP-DocLayoutV3 downloaded/updated: {self._model_path}")
+
+            # Sanity-check the resolved directory no matter how it was found.
+            for name in ("preprocessor_config.json", "config.json", "model.safetensors"):
+                if not (Path(self._model_path) / name).exists():
+                    raise FileNotFoundError(f"Missing {name} in {self._model_path}")
+
+            # Serialize construction: from_pretrained is guarded by a module-level lock.
+            with _model_init_lock:
+                self.image_processor = AutoProcessor.from_pretrained(
+                    self._model_path,
+                    trust_remote_code=True,
+                )
+                self.model = AutoModelForObjectDetection.from_pretrained(
+                    self._model_path,
+                    trust_remote_code=True,
+                    device_map=self._device,
+                )
+                self.model.eval()
+
+            print(f"✅ PP-DocLayoutV3 Detector initialized")
+            print(f"   - Model: {type(self.model).__name__}")
+            print(f"   - Device: {self._device}")
+            print(f"   - Threshold: {self._threshold}")
+
+        except ValueError as e:
+            # Heuristic: transformers raises ValueError with these phrases when the
+            # installed version does not know the pp_doclayout_v3 architecture.
+            err_str = str(e)
+            if "pp_doclayout_v3" in err_str or "does not recognize" in err_str or "Unrecognized processing" in err_str:
+                raise RuntimeError(
+                    "当前 transformers 版本不支持 pp_doclayout_v3 或 PPDocLayoutV3ImageProcessor。"
+                    "PP-DocLayoutV3 已在 Hugging Face transformers 主分支中支持,请安装包含该模型的版本:\n"
+                    "  pip install -U transformers\n"
+                    "或从源码安装最新版:\n"
+                    "  pip install git+https://github.com/huggingface/transformers.git"
+                ) from e
+            raise
+        except ImportError as e:
+            print(f"❌ 依赖缺失: {e}")
+            raise
+
+    def cleanup(self):
+        """Drop model/processor references so they can be garbage-collected."""
+        self.model = None
+        self.image_processor = None
+        self._model_path = None
+
+    def _detect_raw(
+        self,
+        image: Union[np.ndarray, Image.Image],
+        ocr_spans: Optional[List[Dict[str, Any]]] = None,
+    ) -> List[Dict[str, Any]]:
+        """Raw layout detection, fully delegated to the transformers
+        image_processor + model + post_process_object_detection.
+
+        Args:
+            image: page image — BGR ``np.ndarray`` (OpenCV convention) or PIL image.
+            ocr_spans: accepted for interface compatibility; not used here.
+
+        Returns:
+            List of dicts with MinerU-style ``category``, integer pixel ``bbox``
+            (x1, y1, x2, y2), ``confidence`` and a ``raw`` payload carrying the
+            original label, quad ``poly`` and box size.
+        """
+        import torch
+
+        if self.model is None or self.image_processor is None:
+            raise RuntimeError("Model not initialized. Call initialize() first.")
+        assert self.image_processor is not None
+
+        if isinstance(image, np.ndarray):
+            # 3-channel ndarrays are assumed BGR (OpenCV) — convert to RGB for PIL.
+            if len(image.shape) == 3 and image.shape[2] == 3:
+                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+            else:
+                image_rgb = image
+            pil_image = Image.fromarray(image_rgb).convert("RGB")
+            orig_h, orig_w = image.shape[:2]
+        else:
+            pil_image = image.convert("RGB")
+            orig_w, orig_h = image.size
+
+        with torch.inference_mode():
+            # post_process_object_detection expects (height, width) per image.
+            target_sizes = torch.tensor([pil_image.size[::-1]])
+            inputs = self.image_processor(images=[pil_image], return_tensors="pt").to(self._device)
+            outputs = self.model(**inputs)
+            results = self.image_processor.post_process_object_detection(
+                outputs,
+                target_sizes=target_sizes,
+                threshold=self._threshold,
+            )
+
+        w, h = pil_image.size
+        result = results[0]
+        formatted_results = []
+
+        for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
+            score = float(score.item())
+            label_id_int = int(label_id.item())
+            original_label = self.PPDOCLAYOUT_V3_LABELS.get(
+                label_id_int, f"unknown_{label_id_int}"
+            )
+            # Unknown labels fall back to the generic "text" category.
+            mineru_category = self.CATEGORY_MAP.get(original_label, "text")
+
+            # Clamp the box into the image, then round down to integer pixels.
+            bbox_float = [float(b.item()) for b in box]
+            x1 = min(w, max(0, bbox_float[0]))
+            y1 = min(h, max(0, bbox_float[1]))
+            x2 = min(w, max(0, bbox_float[2]))
+            y2 = min(h, max(0, bbox_float[3]))
+            bbox = [int(x1), int(y1), int(x2), int(y2)]
+            width = bbox[2] - bbox[0]
+            height = bbox[3] - bbox[1]
+
+            # Drop degenerate boxes (<10 px on a side) and near-full-page boxes (>95% area).
+            if width < 10 or height < 10:
+                continue
+            if width * height > orig_w * orig_h * 0.95:
+                continue
+
+            # Axis-aligned quad, clockwise from top-left.
+            poly = [bbox[0], bbox[1], bbox[2], bbox[1], bbox[2], bbox[3], bbox[0], bbox[3]]
+            formatted_results.append({
+                "category": mineru_category,
+                "bbox": bbox,
+                "confidence": score,
+                "raw": {
+                    "original_label": original_label,
+                    "original_label_id": label_id_int,
+                    "poly": poly,
+                    "width": width,
+                    "height": height,
+                },
+            })
+
+        return formatted_results
+
+    def detect_batch(
+        self,
+        images: List[Union[np.ndarray, Image.Image]],
+    ) -> List[List[Dict[str, Any]]]:
+        """Batched detection (same pipeline as _detect_raw, one forward pass for all images).
+
+        Args:
+            images: list of BGR ``np.ndarray`` or PIL images.
+
+        Returns:
+            One result list per input image, in the same format as _detect_raw.
+        """
+        import torch
+
+        if self.model is None or self.image_processor is None:
+            raise RuntimeError("Model not initialized. Call initialize() first.")
+        assert self.image_processor is not None
+        if not images:
+            return []
+
+        # Normalize all inputs to RGB PIL images and remember original sizes (w, h).
+        pil_images = []
+        orig_sizes = []
+        for image in images:
+            if isinstance(image, np.ndarray):
+                if len(image.shape) == 3 and image.shape[2] == 3:
+                    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+                else:
+                    image_rgb = image
+                pil_images.append(Image.fromarray(image_rgb).convert("RGB"))
+                orig_sizes.append((image.shape[1], image.shape[0]))
+            else:
+                pil_images.append(image.convert("RGB"))
+                orig_sizes.append(image.size)
+
+        with torch.inference_mode():
+            # (height, width) per image, as required by post_process_object_detection.
+            target_sizes = torch.tensor([img.size[::-1] for img in pil_images])
+            inputs = self.image_processor(images=pil_images, return_tensors="pt").to(self._device)
+            outputs = self.model(**inputs)
+            results_list = self.image_processor.post_process_object_detection(
+                outputs,
+                target_sizes=target_sizes,
+                threshold=self._threshold,
+            )
+
+        all_formatted = []
+        for pil_img, results, (orig_w, orig_h) in zip(pil_images, results_list, orig_sizes):
+            w, h = pil_img.size
+            formatted = []
+            for score, label_id, box in zip(
+                results["scores"], results["labels"], results["boxes"]
+            ):
+                score = float(score.item())
+                label_id_int = int(label_id.item())
+                original_label = self.PPDOCLAYOUT_V3_LABELS.get(
+                    label_id_int, f"unknown_{label_id_int}"
+                )
+                mineru_category = self.CATEGORY_MAP.get(original_label, "text")
+                # Clamp into the image, round to integer pixels (same as _detect_raw).
+                bbox_float = [float(b.item()) for b in box]
+                x1 = min(w, max(0, bbox_float[0]))
+                y1 = min(h, max(0, bbox_float[1]))
+                x2 = min(w, max(0, bbox_float[2]))
+                y2 = min(h, max(0, bbox_float[3]))
+                bbox = [int(x1), int(y1), int(x2), int(y2)]
+                width = bbox[2] - bbox[0]
+                height = bbox[3] - bbox[1]
+                # Drop degenerate and near-full-page boxes.
+                if width < 10 or height < 10:
+                    continue
+                if width * height > orig_w * orig_h * 0.95:
+                    continue
+                poly = [
+                    bbox[0], bbox[1], bbox[2], bbox[1],
+                    bbox[2], bbox[3], bbox[0], bbox[3],
+                ]
+                formatted.append({
+                    "category": mineru_category,
+                    "bbox": bbox,
+                    "confidence": score,
+                    "raw": {
+                        "original_label": original_label,
+                        "original_label_id": label_id_int,
+                        "poly": poly,
+                        "width": width,
+                        "height": height,
+                    },
+                })
+            all_formatted.append(formatted)
+
+        return all_formatted
+
+    def visualize(
+        self,
+        img: np.ndarray,
+        results: List[Dict],
+        output_path: Optional[str] = None,
+        show_confidence: bool = True,
+        min_confidence: float = 0.0,
+    ) -> np.ndarray:
+        """Draw detection results on a copy of *img* (same API as DoclingLayoutDetector).
+
+        Args:
+            img: BGR image (OpenCV convention).
+            results: detection dicts as produced by _detect_raw/detect_batch.
+            output_path: if given, the annotated image is also written to this path.
+            show_confidence: append the confidence score to each box label.
+            min_confidence: hide results below this confidence.
+
+        Returns:
+            The annotated BGR image (input is not modified).
+        """
+        import random
+
+        vis_img = img.copy()
+        # Fixed BGR colors for common categories; unseen categories get a random color.
+        predefined_colors = {
+            "text": (153, 0, 76),
+            "title": (102, 102, 255),
+            "header": (128, 128, 128),
+            "footer": (128, 128, 128),
+            "page_footnote": (200, 200, 200),
+            "table_body": (204, 204, 0),
+            "table_caption": (255, 255, 102),
+            "image_body": (153, 255, 51),
+            "image_caption": (102, 178, 255),
+            "interline_equation": (0, 255, 0),
+            "code": (102, 0, 204),
+            "abandon": (100, 100, 100),
+        }
+        filtered = [r for r in results if r["confidence"] >= min_confidence]
+        if not filtered:
+            return vis_img
+        category_colors = {}
+        for res in filtered:
+            cat = res["category"]
+            if cat not in category_colors:
+                category_colors[cat] = (
+                    predefined_colors.get(cat)
+                    or (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255))
+                )
+        for res in filtered:
+            x1, y1, x2, y2 = res["bbox"]
+            cat = res["category"]
+            color = category_colors[cat]
+            orig = res.get("raw", {}).get("original_label", cat)
+            cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
+            # Label shows "original_label->mineru_category" plus optional score.
+            label = f"{orig}->{cat} {res['confidence']:.2f}" if show_confidence else f"{orig}->{cat}"
+            ls, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
+            cv2.rectangle(vis_img, (x1, y1 - ls[1] - 4), (x1 + ls[0], y1), color, -1)
+            cv2.putText(vis_img, label, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1, cv2.LINE_AA)
+        if output_path:
+            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+            cv2.imwrite(output_path, vis_img)
+        return vis_img
+
+
+# 测试代码
+if __name__ == "__main__":
+    import sys
+
+    # 测试配置(可使用 HF 仓库 ID 或本地 HF_HOME 缓存路径)
+    config = {
+        "model_dir": "PaddlePaddle/PP-DocLayoutV3_safetensors",
+        "device": "cpu",
+        "conf": 0.3,
+        "num_threads": 4,
+    }
+
+    # 初始化检测器
+    print("🔧 Initializing PP-DocLayoutV3 Detector...")
+    detector = PPDocLayoutV3Detector(config)
+    detector.initialize()
+
+    # 读取测试图像
+    img_path = "/Users/zhch158/workspace/data/流水分析/施博深/bank_statement_yusys_v3/施博深/施博深_page_007.png"
+
+    print(f"\n📖 Loading image: {img_path}")
+    img = cv2.imread(img_path)
+
+    if img is None:
+        print(f"❌ Failed to load image: {img_path}")
+        sys.exit(1)
+
+    print(f"   Image shape: {img.shape}")
+
+    # 执行检测
+    print("\n🔍 Detecting layout...")
+    results = detector.detect(img)
+
+    print(f"\n✅ 检测到 {len(results)} 个区域:")
+    for i, res in enumerate(results, 1):
+        print(
+            f"  [{i}] {res['category']}: "
+            f"score={res['confidence']:.3f}, "
+            f"bbox={res['bbox']}, "
+            f"original={res['raw']['original_label']}"
+        )
+
+    # 统计各类别
+    category_counts = {}
+    for res in results:
+        cat = res["category"]
+        category_counts[cat] = category_counts.get(cat, 0) + 1
+
+    print(f"\n📊 类别统计 (MinerU格式):")
+    for cat, count in sorted(category_counts.items()):
+        print(f"  - {cat}: {count}")
+
+    # 可视化
+    if len(results) > 0:
+        print("\n🎨 Generating visualization...")
+
+        output_dir = Path(__file__).resolve().parent.parent.parent / "tests" / "output"
+        output_dir.mkdir(parents=True, exist_ok=True)
+        output_path = output_dir / f"{Path(img_path).stem}_PP_DocLayoutV3_layout_vis.jpg"
+
+        vis_img = detector.visualize(
+            img,
+            results,
+            output_path=str(output_path),
+            show_confidence=True,
+            min_confidence=0.0,
+        )
+
+        print(f"💾 Visualization saved to: {output_path}")
+
+    # 清理
+    detector.cleanup()
+    print("\n✅ 测试完成!")