|
|
@@ -0,0 +1,472 @@
|
|
|
"""PP-DocLayoutV3 layout adapter (conforms to the BaseLayoutDetector interface).

Loads the PaddlePaddle PP-DocLayoutV3 layout model via HuggingFace transformers,
consistent with the Docling adapter: AutoProcessor + AutoModelForObjectDetection,
with preprocessing done by the transformers image_processor and post-processing by
post_process_object_detection — no hand-rolled inference pipeline.

- Hugging Face: https://huggingface.co/PaddlePaddle/PP-DocLayoutV3_safetensors
- Local cache: HF_HOME/hub/models--PaddlePaddle--PP-DocLayoutV3_safetensors/snapshots/<hash>

Dependency: transformers must include pp_doclayout_v3 (already supported on the
main branch). If a KeyError/ValueError is raised, run:
    pip install -U transformers
or: pip install git+https://github.com/huggingface/transformers.git
"""

import cv2
import numpy as np
import threading
import os
from pathlib import Path
from typing import Dict, List, Union, Any, Optional
from PIL import Image

# Support both package-relative execution and running this file as a script.
try:
    from .base import BaseLayoutDetector
except ImportError:
    from base import BaseLayoutDetector

# Serializes HuggingFace processor/model construction across threads.
_model_init_lock = threading.Lock()
|
|
|
class PPDocLayoutV3Detector(BaseLayoutDetector):
    """PP-DocLayoutV3 layout detection adapter (PaddleOCR-VL-1.5 layout-analysis module).

    Wraps the HuggingFace transformers port of PaddlePaddle's PP-DocLayoutV3
    object-detection model.  Preprocessing and post-processing are delegated
    to the transformers ``AutoProcessor`` (``post_process_object_detection``);
    this class only handles model resolution, category mapping and result
    formatting.
    """

    # Official id2label from the model's config.json (25 classes, no offset).
    # NOTE(review): several distinct ids map to the same label string
    # (8/9 -> "footer", 12/13 -> "header", 5/15 -> "formula", 22/23 -> "text");
    # verify against the upstream config.json that this collapsing is intended.
    PPDOCLAYOUT_V3_LABELS = {
        0: "abstract",
        1: "algorithm",
        2: "aside_text",
        3: "chart",
        4: "content",
        5: "formula",
        6: "doc_title",
        7: "figure_title",
        8: "footer",
        9: "footer",
        10: "footnote",
        11: "formula_number",
        12: "header",
        13: "header",
        14: "image",
        15: "formula",
        16: "number",
        17: "paragraph_title",
        18: "reference",
        19: "reference_content",
        20: "seal",
        21: "table",
        22: "text",
        23: "text",
        24: "vision_footnote",
    }

    # PP-DocLayoutV3 label -> MinerU-style category.
    CATEGORY_MAP = {
        "abstract": "text",
        "algorithm": "text",
        "aside_text": "text",
        "chart": "image_body",
        "content": "text",
        "formula": "interline_equation",
        "doc_title": "title",
        "figure_title": "image_caption",
        "footer": "footer",
        "footnote": "page_footnote",
        "formula_number": "interline_equation",
        "header": "header",
        "image": "image_body",
        "number": "text",
        "paragraph_title": "title",
        "reference": "text",
        "reference_content": "text",
        "seal": "image_body",
        "table": "table_body",
        "text": "text",
        "vision_footnote": "page_footnote",
    }

    def __init__(self, config: Dict[str, Any]):
        """Store configuration; heavy model loading is deferred to ``initialize()``.

        Args:
            config: adapter configuration.  Recognized keys:
                ``model_dir`` (HF repo id or local directory),
                ``device`` (e.g. "cpu" / "cuda"),
                ``conf`` (detection score threshold, default 0.3),
                ``num_threads`` (torch CPU thread count, default 4).
        """
        super().__init__(config)
        self.model = None
        self.image_processor = None
        self._device = None
        self._threshold = 0.3
        self._num_threads = 4
        self._model_path = None

    def initialize(self):
        """Initialize the model (same approach as Docling: AutoProcessor + AutoModelForObjectDetection).

        Resolution order for the model files: an existing local directory, the
        newest complete snapshot in the HuggingFace cache, then a fresh
        ``snapshot_download``.

        Raises:
            RuntimeError: if the installed transformers lacks pp_doclayout_v3 support.
            FileNotFoundError: if a required model file is missing after resolution.
        """
        try:
            import torch
            from transformers import AutoProcessor, AutoModelForObjectDetection
            from huggingface_hub import snapshot_download

            model_dir = self.config.get("model_dir", "PaddlePaddle/PP-DocLayoutV3_safetensors")
            device = self.config.get("device", "cpu")
            self._threshold = self.config.get("conf", 0.3)
            self._num_threads = self.config.get("num_threads", 4)

            self._device = torch.device(device)
            if device == "cpu":
                # Cap intra-op parallelism so CPU inference stays predictable.
                torch.set_num_threads(self._num_threads)

            model_path = Path(model_dir)
            if model_path.exists() and model_path.is_dir():
                # Caller supplied an on-disk model directory directly.
                self._model_path = str(model_path)
                print(f"📂 Loading PP-DocLayoutV3 from local path: {self._model_path}")
            else:
                # Look for a usable snapshot in the HF cache before downloading.
                hf_home = os.environ.get("HF_HOME", None)
                if hf_home:
                    cache_dir = Path(hf_home) / "hub"
                else:
                    cache_dir = Path.home() / ".cache" / "huggingface" / "hub"

                repo_id_escaped = model_dir.replace("/", "--")
                model_cache_dir = cache_dir / f"models--{repo_id_escaped}"
                local_model_path = None

                if model_cache_dir.exists() and model_cache_dir.is_dir():
                    snapshots_dir = model_cache_dir / "snapshots"
                    if snapshots_dir.exists():
                        # Newest snapshot first (by mtime).
                        snapshots = sorted(
                            [d for d in snapshots_dir.iterdir() if d.is_dir()],
                            key=lambda x: x.stat().st_mtime,
                            reverse=True,
                        )
                        if snapshots:
                            latest = snapshots[0]
                            # Only accept a snapshot containing all required files.
                            if (latest / "preprocessor_config.json").exists() and (
                                latest / "config.json"
                            ).exists() and (latest / "model.safetensors").exists():
                                local_model_path = latest

                if local_model_path:
                    self._model_path = str(local_model_path)
                    print(f"📂 Using local cached PP-DocLayoutV3: {self._model_path}")
                else:
                    print(f"📥 Downloading PP-DocLayoutV3 from HuggingFace: {model_dir}")
                    self._model_path = snapshot_download(repo_id=model_dir)
                    print(f"✅ PP-DocLayoutV3 downloaded/updated: {self._model_path}")

            # Fail fast if any required model file is missing.
            for name in ("preprocessor_config.json", "config.json", "model.safetensors"):
                if not (Path(self._model_path) / name).exists():
                    raise FileNotFoundError(f"Missing {name} in {self._model_path}")

            # HF from_pretrained is not guaranteed thread-safe; guard construction.
            with _model_init_lock:
                self.image_processor = AutoProcessor.from_pretrained(
                    self._model_path,
                    trust_remote_code=True,
                )
                self.model = AutoModelForObjectDetection.from_pretrained(
                    self._model_path,
                    trust_remote_code=True,
                    device_map=self._device,
                )
                self.model.eval()

            print("✅ PP-DocLayoutV3 Detector initialized")
            print(f"   - Model: {type(self.model).__name__}")
            print(f"   - Device: {self._device}")
            print(f"   - Threshold: {self._threshold}")

        except ValueError as e:
            err_str = str(e)
            # Old transformers raises ValueError for unknown model types.
            if "pp_doclayout_v3" in err_str or "does not recognize" in err_str or "Unrecognized processing" in err_str:
                raise RuntimeError(
                    "当前 transformers 版本不支持 pp_doclayout_v3 或 PPDocLayoutV3ImageProcessor。"
                    "PP-DocLayoutV3 已在 Hugging Face transformers 主分支中支持,请安装包含该模型的版本:\n"
                    "  pip install -U transformers\n"
                    "或从源码安装最新版:\n"
                    "  pip install git+https://github.com/huggingface/transformers.git"
                ) from e
            raise
        except ImportError as e:
            print(f"❌ 依赖缺失: {e}")
            raise

    def cleanup(self):
        """Release model/processor references so they can be garbage-collected."""
        self.model = None
        self.image_processor = None
        self._model_path = None

    @staticmethod
    def _to_pil(image: Union[np.ndarray, Image.Image]):
        """Normalize input to an RGB PIL image.

        Returns:
            Tuple of (RGB PIL image, (orig_width, orig_height)).

        Note: ndarray input is assumed to be BGR (cv2.imread convention) when
        3-channel — TODO confirm against all callers.
        """
        if isinstance(image, np.ndarray):
            if len(image.shape) == 3 and image.shape[2] == 3:
                image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            else:
                image_rgb = image
            return Image.fromarray(image_rgb).convert("RGB"), (image.shape[1], image.shape[0])
        return image.convert("RGB"), image.size

    def _format_detections(
        self,
        result: Dict[str, Any],
        page_w: int,
        page_h: int,
        orig_w: int,
        orig_h: int,
    ) -> List[Dict[str, Any]]:
        """Convert one image's post-processed detections to MinerU-format dicts.

        Shared by ``_detect_raw`` and ``detect_batch``.  Clips boxes to the
        page, drops boxes under 10px in either dimension and boxes covering
        more than 95% of the original image area.
        """
        formatted = []
        for score, label_id, box in zip(result["scores"], result["labels"], result["boxes"]):
            confidence = float(score.item())
            label_id_int = int(label_id.item())
            original_label = self.PPDOCLAYOUT_V3_LABELS.get(
                label_id_int, f"unknown_{label_id_int}"
            )
            mineru_category = self.CATEGORY_MAP.get(original_label, "text")

            bbox_float = [float(b.item()) for b in box]
            x1 = min(page_w, max(0, bbox_float[0]))
            y1 = min(page_h, max(0, bbox_float[1]))
            x2 = min(page_w, max(0, bbox_float[2]))
            y2 = min(page_h, max(0, bbox_float[3]))
            bbox = [int(x1), int(y1), int(x2), int(y2)]
            width = bbox[2] - bbox[0]
            height = bbox[3] - bbox[1]

            if width < 10 or height < 10:
                continue  # too small to be a meaningful layout region
            if width * height > orig_w * orig_h * 0.95:
                continue  # near-full-page box: almost certainly spurious

            # Clockwise 4-point polygon derived from the axis-aligned bbox.
            poly = [bbox[0], bbox[1], bbox[2], bbox[1], bbox[2], bbox[3], bbox[0], bbox[3]]
            formatted.append({
                "category": mineru_category,
                "bbox": bbox,
                "confidence": confidence,
                "raw": {
                    "original_label": original_label,
                    "original_label_id": label_id_int,
                    "poly": poly,
                    "width": width,
                    "height": height,
                },
            })
        return formatted

    def _detect_raw(
        self,
        image: Union[np.ndarray, Image.Image],
        ocr_spans: Optional[List[Dict[str, Any]]] = None,
    ) -> List[Dict[str, Any]]:
        """Raw layout detection, fully delegated to transformers image_processor + model + post_process_object_detection.

        Args:
            image: BGR ndarray (cv2 convention) or PIL image.
            ocr_spans: unused here; accepted for interface compatibility.

        Returns:
            List of detection dicts (see ``_format_detections``).

        Raises:
            RuntimeError: if ``initialize()`` has not been called.
        """
        import torch

        if self.model is None or self.image_processor is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")
        assert self.image_processor is not None

        pil_image, (orig_w, orig_h) = self._to_pil(image)

        with torch.inference_mode():
            # post_process expects (height, width) per image.
            target_sizes = torch.tensor([pil_image.size[::-1]])
            inputs = self.image_processor(images=[pil_image], return_tensors="pt").to(self._device)
            outputs = self.model(**inputs)
            results = self.image_processor.post_process_object_detection(
                outputs,
                target_sizes=target_sizes,
                threshold=self._threshold,
            )

        w, h = pil_image.size
        return self._format_detections(results[0], w, h, orig_w, orig_h)

    def detect_batch(
        self,
        images: List[Union[np.ndarray, Image.Image]],
    ) -> List[List[Dict[str, Any]]]:
        """Batch layout detection (same pipeline as ``_detect_raw``, one forward pass).

        Args:
            images: list of BGR ndarrays or PIL images.

        Returns:
            One detection list per input image, in input order.

        Raises:
            RuntimeError: if ``initialize()`` has not been called.
        """
        import torch

        if self.model is None or self.image_processor is None:
            raise RuntimeError("Model not initialized. Call initialize() first.")
        assert self.image_processor is not None
        if not images:
            return []

        pil_images = []
        orig_sizes = []
        for image in images:
            pil_img, size = self._to_pil(image)
            pil_images.append(pil_img)
            orig_sizes.append(size)

        with torch.inference_mode():
            target_sizes = torch.tensor([img.size[::-1] for img in pil_images])
            inputs = self.image_processor(images=pil_images, return_tensors="pt").to(self._device)
            outputs = self.model(**inputs)
            results_list = self.image_processor.post_process_object_detection(
                outputs,
                target_sizes=target_sizes,
                threshold=self._threshold,
            )

        all_formatted = []
        for pil_img, results, (orig_w, orig_h) in zip(pil_images, results_list, orig_sizes):
            w, h = pil_img.size
            all_formatted.append(self._format_detections(results, w, h, orig_w, orig_h))
        return all_formatted

    def visualize(
        self,
        img: np.ndarray,
        results: List[Dict],
        output_path: Optional[str] = None,
        show_confidence: bool = True,
        min_confidence: float = 0.0,
    ) -> np.ndarray:
        """Draw detection results on a copy of ``img`` (same look as DoclingLayoutDetector).

        Args:
            img: BGR ndarray to draw on (not modified in place).
            results: detection dicts from ``_detect_raw``/``detect_batch``.
            output_path: if given, the visualization is also written to this path.
            show_confidence: append the score to each label when True.
            min_confidence: detections below this score are not drawn.

        Returns:
            The annotated BGR image.
        """
        import random

        vis_img = img.copy()
        # Fixed BGR colors for the known MinerU categories; unseen categories
        # get a random color, stable within one call.
        predefined_colors = {
            "text": (153, 0, 76),
            "title": (102, 102, 255),
            "header": (128, 128, 128),
            "footer": (128, 128, 128),
            "page_footnote": (200, 200, 200),
            "table_body": (204, 204, 0),
            "table_caption": (255, 255, 102),
            "image_body": (153, 255, 51),
            "image_caption": (102, 178, 255),
            "interline_equation": (0, 255, 0),
            "code": (102, 0, 204),
            "abandon": (100, 100, 100),
        }
        filtered = [r for r in results if r["confidence"] >= min_confidence]
        if not filtered:
            return vis_img
        category_colors = {}
        for res in filtered:
            cat = res["category"]
            if cat not in category_colors:
                category_colors[cat] = (
                    predefined_colors.get(cat)
                    or (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255))
                )
        for res in filtered:
            x1, y1, x2, y2 = res["bbox"]
            cat = res["category"]
            color = category_colors[cat]
            orig = res.get("raw", {}).get("original_label", cat)
            cv2.rectangle(vis_img, (x1, y1), (x2, y2), color, 2)
            label = f"{orig}->{cat} {res['confidence']:.2f}" if show_confidence else f"{orig}->{cat}"
            ls, _ = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
            # Filled background strip above the box so the label stays readable.
            cv2.rectangle(vis_img, (x1, y1 - ls[1] - 4), (x1 + ls[0], y1), color, -1)
            cv2.putText(vis_img, label, (x1, y1 - 2), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1, cv2.LINE_AA)
        if output_path:
            Path(output_path).parent.mkdir(parents=True, exist_ok=True)
            cv2.imwrite(output_path, vis_img)
        return vis_img
|
|
|
+
|
|
|
+
|
|
|
# Manual smoke test: run layout detection on one image and visualize the result.
if __name__ == "__main__":
    import sys

    # Test configuration (an HF repo ID or a local HF_HOME cache path both work).
    config = {
        "model_dir": "PaddlePaddle/PP-DocLayoutV3_safetensors",
        "device": "cpu",
        "conf": 0.3,
        "num_threads": 4,
    }

    # Initialize the detector.
    print("🔧 Initializing PP-DocLayoutV3 Detector...")
    detector = PPDocLayoutV3Detector(config)
    detector.initialize()

    # Image path: overridable from the command line; falls back to the sample.
    default_img_path = "/Users/zhch158/workspace/data/流水分析/施博深/bank_statement_yusys_v3/施博深/施博深_page_007.png"
    img_path = sys.argv[1] if len(sys.argv) > 1 else default_img_path

    print(f"\n📖 Loading image: {img_path}")
    img = cv2.imread(img_path)

    if img is None:
        print(f"❌ Failed to load image: {img_path}")
        sys.exit(1)

    print(f"   Image shape: {img.shape}")

    # Run detection.
    print("\n🔍 Detecting layout...")
    results = detector.detect(img)

    print(f"\n✅ 检测到 {len(results)} 个区域:")
    for i, res in enumerate(results, 1):
        print(
            f"  [{i}] {res['category']}: "
            f"score={res['confidence']:.3f}, "
            f"bbox={res['bbox']}, "
            f"original={res['raw']['original_label']}"
        )

    # Per-category counts.
    category_counts = {}
    for res in results:
        cat = res["category"]
        category_counts[cat] = category_counts.get(cat, 0) + 1

    print(f"\n📊 类别统计 (MinerU格式):")
    for cat, count in sorted(category_counts.items()):
        print(f"  - {cat}: {count}")

    # Visualization.
    if len(results) > 0:
        print("\n🎨 Generating visualization...")

        output_dir = Path(__file__).resolve().parent.parent.parent / "tests" / "output"
        output_dir.mkdir(parents=True, exist_ok=True)
        output_path = output_dir / f"{Path(img_path).stem}_PP_DocLayoutV3_layout_vis.jpg"

        vis_img = detector.visualize(
            img,
            results,
            output_path=str(output_path),
            show_confidence=True,
            min_confidence=0.0,
        )

        print(f"💾 Visualization saved to: {output_path}")

    # Cleanup.
    detector.cleanup()
    print("\n✅ 测试完成!")
|