Ver Fonte

feat(新增单元格预处理与参数扫描功能): 在ocr_tools/cell_preprocess_lab中新增cell_preprocess_lab.py和cell121_sweep.py文件,分别实现单元格裁剪图的预处理流程和参数扫描功能,支持去水印、去噪、对比度调整及OCR识别,提升OCR处理的灵活性和准确性。

zhch158_admin há 3 dias atrás
pai
commit
130984410f

+ 194 - 0
ocr_tools/cell_preprocess_lab/cell121_sweep.py

@@ -0,0 +1,194 @@
+#!/usr/bin/env python3
+"""cell121 参数扫描:去水印方式 / threshold / contrast / upscale / det 阈值 / 整格 rec。"""
+from __future__ import annotations
+
+import json
+import os
+import sys
+from itertools import product
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import cv2
+import numpy as np
+
+# Put the directory two levels above this script's folder (the repository root)
+# on sys.path so the ocr_utils / ocr_tools imports below resolve when this file
+# is executed directly as a script.
+_repo_root = Path(__file__).resolve().parents[2]
+if str(_repo_root) not in sys.path:
+    sys.path.insert(0, str(_repo_root))
+
+from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config
+from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
+
+# Absolute path of the single cell crop this sweep is run against.
+# NOTE(review): machine-specific path; main() raises FileNotFoundError elsewhere.
+CELL121 = Path(
+    "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/"
+    "bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/"
+    "彭_广东兴宁农村商业银行_page_002_0/cell121_empty_empty.png"
+)
+# Where the preprocessed PNGs and the JSON report are written (next to this script).
+OUT_DIR = Path(__file__).parent / "output/彭_广东兴宁农村商业银行/cell121_sweep"
+# Directory holding the detection / recognition .pth weights used by _make_engine.
+MODEL_DIR = Path(
+    "/Users/zhch158/models/modelscope_cache/models/OpenDataLab/"
+    "PDF-Extract-Kit-1___0/models/OCR/paddleocr_torch"
+)
+
+# Text the OCR result must contain for a trial to count as a full hit.
+TARGET = "20240927"
+
+
+def _upscale(img: np.ndarray, min_side: int) -> np.ndarray:
+    """Enlarge ``img`` uniformly until both sides are at least ``min_side`` pixels.
+
+    The input object itself is returned when it is already large enough;
+    otherwise a bicubic-resized copy is returned. The image is never shrunk.
+    """
+    height, width = img.shape[:2]
+    if min(height, width) >= min_side:
+        return img
+    # max(..., 1) keeps a zero-length side from causing a division by zero.
+    factor = max(1.0, min_side / max(height, 1), min_side / max(width, 1))
+    return cv2.resize(
+        img, None, fx=factor, fy=factor, interpolation=cv2.INTER_CUBIC
+    )
+
+
+def _preprocess(
+    raw: np.ndarray,
+    *,
+    method: str,
+    thresh: Optional[int],
+    contrast: bool,
+    upscale: int,
+) -> np.ndarray:
+    """Remove the watermark from a cell crop, optionally boost contrast, then upscale.
+
+    Args:
+        raw: Cell image as loaded by ``cv2.imread``.
+        method: Watermark-removal method name placed in the watermark config.
+        thresh: Explicit ``threshold`` override. Only applied when ``method`` is
+            ``"threshold"``; ``None`` keeps whatever ``merge_watermark_config``
+            supplies.
+        contrast: When true, run contrast enhancement with
+            ``text_black_target`` forced to 88 and return a 3-channel image.
+        upscale: Minimum side length handed to ``_upscale``.
+
+    Returns:
+        The preprocessed image, upscaled so both sides are >= ``upscale``.
+    """
+    user: Dict[str, Any] = {"enabled": True, "method": method}
+    if method == "threshold" and thresh is not None:
+        user["threshold"] = thresh
+    cfg = merge_watermark_config("cell", user)
+    img, _ = WatermarkProcessor(cfg, scope="cell").process(raw, force=True)
+    if contrast:
+        # cvtColor(BGR2GRAY) rejects single-channel input, so only convert when
+        # a channel axis is present (same guard as run_cell_pipeline in
+        # cell_preprocess_lab.py).
+        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img
+        ce = dict(cfg.get("contrast_enhancement") or {})
+        ce["enabled"] = True
+        ce["text_black_target"] = 88
+        gray = apply_contrast_enhancement_config(gray, ce)
+        img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
+    return _upscale(img, upscale)
+
+
+def _ocr(engine: Any, img: np.ndarray, *, det: bool, rec: bool) -> Dict[str, Any]:
+    """Run ``engine.ocr`` on ``img`` and flatten the result into one report row.
+
+    Args:
+        engine: Object exposing ``ocr(img, det=..., rec=...)``.
+        img: Image passed straight through to the engine.
+        det: Whether text detection is requested.
+        rec: Whether text recognition is requested.
+
+    Returns:
+        A dict with keys ``text`` (all recognised strings concatenated),
+        ``det``, ``rec`` and ``n_boxes`` (number of entries in the first
+        result page). On any exception the same keys are returned with an
+        empty text, ``n_boxes`` 0 and an extra ``error`` message, so every
+        row in the sweep report has the same schema.
+    """
+    try:
+        res = engine.ocr(img, det=det, rec=rec)
+        items = res[0] if res and res[0] else []
+        texts: List[str] = []
+        if rec:
+            # Without recognition the items are bare boxes; reading item[1][0]
+            # there would turn a coordinate into "text", so skip parsing.
+            for item in items:
+                if det:
+                    # det+rec: each item is (box, (text, score)).
+                    if item and len(item) >= 2 and item[1]:
+                        texts.append(str(item[1][0] or ""))
+                elif isinstance(item, (list, tuple)) and len(item) >= 1:
+                    # rec only: each item is (text, score).
+                    texts.append(str(item[0] or ""))
+        return {
+            "text": "".join(texts).strip(),
+            "det": det,
+            "rec": rec,
+            "n_boxes": len(items),
+        }
+    except Exception as e:
+        # Best-effort: one failing trial must not abort the whole sweep.
+        return {"text": "", "error": str(e), "det": det, "rec": rec, "n_boxes": 0}
+
+
+def _make_engine(det_thresh: float) -> Any:
+    """Build a PytorchPaddleOCR engine using ``det_thresh`` as ``det_db_box_thresh``.
+
+    The detection and recognition weights are read from ``MODEL_DIR``.
+    """
+    # Imported lazily so the module can be imported without the OCR stack loaded.
+    from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR
+
+    det_weights = MODEL_DIR / "ch_PP-OCRv5_det_infer.pth"
+    rec_weights = MODEL_DIR / "ch_PP-OCRv4_rec_server_doc_infer.pth"
+    engine_kwargs = {
+        "lang": "ch",
+        "det_model_path": str(det_weights),
+        "rec_model_path": str(rec_weights),
+        "det_db_box_thresh": det_thresh,
+    }
+    return PytorchPaddleOCR(**engine_kwargs)
+
+
+def main() -> None:
+    """Sweep preprocessing / OCR parameters over the cell121 crop and report hits.
+
+    For every combination of watermark method, threshold, contrast flag,
+    upscale size and detection box threshold, the preprocessed image is saved
+    to ``OUT_DIR`` and OCR'd in two modes (det+rec and rec-only). A trial is a
+    "full" hit when the text contains ``TARGET`` and a "partial" hit when the
+    text is at least six characters of digits. The raw image upscaled to 128
+    is OCR'd as a control. Everything is written to
+    ``cell121_sweep_report.json``.
+
+    Raises:
+        FileNotFoundError: ``CELL121`` is missing or cannot be decoded.
+    """
+    if not CELL121.is_file():
+        raise FileNotFoundError(CELL121)
+    raw = cv2.imread(str(CELL121))
+    if raw is None:
+        # cv2.imread reports a decode failure by returning None, not by raising.
+        raise FileNotFoundError(f"无法读取: {CELL121}")
+    OUT_DIR.mkdir(parents=True, exist_ok=True)
+
+    methods = ["threshold", "masked_adaptive"]
+    thresholds = [155, 165, 170, 175, 180, None]
+    contrasts = [False, True]
+    upscales = [64, 96, 128, 192]
+    det_threshs = [0.2, 0.3, 0.4, 0.5]
+    ocr_modes = [("det_rec", True, True), ("whole_rec", False, True)]
+
+    results: List[Dict[str, Any]] = []
+    hits: List[Dict[str, Any]] = []
+    engines: Dict[float, Any] = {}
+
+    def _engine_for(det_th: float) -> Any:
+        # One engine per detection threshold, created on first use.
+        if det_th not in engines:
+            print(f"加载 OCR det_db_box_thresh={det_th} ...")
+            engines[det_th] = _make_engine(det_th)
+        return engines[det_th]
+
+    total = 0
+    for method, thresh, contrast, upscale in product(
+        methods, thresholds, contrasts, upscales
+    ):
+        # An explicit threshold only makes sense for the "threshold" method.
+        if method != "threshold" and thresh is not None:
+            continue
+
+        # Preprocessing does not depend on the detection threshold, so do it
+        # once per combination instead of once per det_th.
+        img = _preprocess(
+            raw, method=method, thresh=thresh, contrast=contrast, upscale=upscale
+        )
+        for det_th in det_threshs:
+            engine = _engine_for(det_th)
+            tag = (
+                f"{method}_t{thresh or 'd'}_c{int(contrast)}_u{upscale}_det{det_th}"
+            )
+            cv2.imwrite(str(OUT_DIR / f"{tag}.png"), img)
+
+            for mode_name, det, rec in ocr_modes:
+                total += 1
+                ocr = _ocr(engine, img, det=det, rec=rec)
+                row = {
+                    "tag": tag,
+                    "method": method,
+                    "threshold": thresh,
+                    "contrast": contrast,
+                    "upscale": upscale,
+                    "det_db_box_thresh": det_th,
+                    "ocr_mode": mode_name,
+                    **ocr,
+                }
+                results.append(row)
+                t = row.get("text", "")
+                if TARGET in t or (len(t) >= 6 and t.isdigit()):
+                    row["match"] = "full" if TARGET in t else "partial"
+                    hits.append(row)
+                    print(f"HIT [{row['match']}] {mode_name} {tag} -> {t!r}")
+
+    # Control run on the raw image (not counted in `total`).
+    for det_th in [0.3, 0.5]:
+        engine = _engine_for(det_th)
+        for mode_name, det, rec in ocr_modes:
+            ocr = _ocr(engine, _upscale(raw, 128), det=det, rec=rec)
+            row = {
+                "tag": "raw_upscale128",
+                "det_db_box_thresh": det_th,
+                "ocr_mode": mode_name,
+                **ocr,
+            }
+            results.append(row)
+            if TARGET in (row.get("text") or ""):
+                # Label control hits like sweep hits so they are classified below.
+                row["match"] = "full"
+                hits.append(row)
+
+    report = {
+        "input": str(CELL121),
+        "target": TARGET,
+        "total_trials": total,
+        "hits": hits,
+        "all_results": results,
+    }
+    out_json = OUT_DIR / "cell121_sweep_report.json"
+    out_json.write_text(json.dumps(report, ensure_ascii=False, indent=2), encoding="utf-8")
+
+    print(f"\n完成 {total} 次 OCR 试验,命中 {len(hits)} 条")
+    print(f"报告: {out_json}")
+    if hits:
+        print("\n最佳命中:")
+        # Full matches first; sorted() is stable, so sweep order is kept
+        # inside each group.
+        for h in sorted(hits, key=lambda hit: hit.get("match") != "full")[:10]:
+            print(f"  {h.get('ocr_mode')} {h.get('tag')}: {h.get('text')!r}")
+    # The hint points at partial results, so it must also appear when only
+    # partial hits were found, not just when there were no hits at all.
+    if not any(h.get("match") == "full" for h in hits):
+        print(f"未出现完整 {TARGET},请查看 cell121_sweep/*.png 与 report 中 partial 结果")
+
+
+if __name__ == "__main__":
+    main()

+ 385 - 0
ocr_tools/cell_preprocess_lab/cell_preprocess_lab.py

@@ -0,0 +1,385 @@
+#!/usr/bin/env python3
+"""
+单元格裁剪图预处理实验:去水印 →(可选去噪/对比度)→ 放大 → OCR。
+
+与 pipeline 二次 OCR 对齐,使用 ocr_tools.pytorch_models.PytorchPaddleOCR(非 paddleocr pip 包)。
+
+用法:
+    python cell_preprocess_lab.py cell219.png -o /tmp/cell_lab
+    python cell_preprocess_lab.py /path/to/tablecell_ocr/ -o /tmp/batch --compare-methods
+    python cell_preprocess_lab.py cell217.png -o /tmp/out --denoise --contrast
+"""
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import sys
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import cv2
+import numpy as np
+import yaml
+
+# Put the repository root and the universal_doc_parser package directory at the
+# front of sys.path so the ocr_utils imports below and the lazy
+# `core.model_factory` import in _get_ocr_engine are attempted against them
+# when this file is run as a script.
+_repo_root = Path(__file__).resolve().parents[2]
+_parser_root = _repo_root / "ocr_tools" / "universal_doc_parser"
+for _p in (_repo_root, _parser_root):
+    if str(_p) not in sys.path:
+        sys.path.insert(0, str(_p))
+
+from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config
+from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
+
+# Scenario YAML used when no --config is given; its `ocr_recognition` section
+# is what _get_ocr_engine hands to ModelFactory.
+_DEFAULT_CONFIG = (
+    _repo_root
+    / "ocr_tools/universal_doc_parser/config/bank_statement_yusys_local.yaml"
+)
+# File extensions (lower-case) that collect_inputs accepts in directory mode.
+_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
+
+# Process-wide cache of the OCR engine, filled by _get_ocr_engine on first use.
+_OCR_ENGINE: Any = None
+# Config path chosen on the command line; set by main(), None until then.
+_CONFIG_PATH: Optional[Path] = None
+
+
+def _get_ocr_engine() -> Any:
+    """Return the shared OCR engine, building it on the first call.
+
+    Same as the main_v2 pipeline: ModelFactory → MinerU OCR atom model.
+
+    Backends are tried in this order and the first one that initialises is
+    cached in ``_OCR_ENGINE``:
+      1. ``ModelFactory.create_ocr_recognizer`` fed with the ``ocr_recognition``
+         section of the scenario YAML;
+      2. ``PytorchPaddleOCR`` built from the OCR_DET_MODEL_PATH /
+         OCR_REC_MODEL_PATH environment variables (only if at least one is set);
+      3. the ``paddleocr`` pip package.
+
+    Raises:
+        FileNotFoundError: the scenario YAML does not exist.
+        ImportError: no backend could be initialised; the message lists the
+            failure of each attempt.
+    """
+    global _OCR_ENGINE
+    if _OCR_ENGINE is not None:
+        return _OCR_ENGINE
+
+    cfg_path = _CONFIG_PATH or _DEFAULT_CONFIG
+    if not cfg_path.is_file():
+        raise FileNotFoundError(f"场景配置不存在: {cfg_path}")
+
+    with open(cfg_path, encoding="utf-8") as f:
+        raw = yaml.safe_load(f) or {}
+    ocr_cfg = raw.get("ocr_recognition") or {}
+
+    # One entry per backend that failed; reported together if all of them fail.
+    errors: List[str] = []
+    # Attempt 1: the project's ModelFactory.
+    try:
+        from core.model_factory import ModelFactory
+
+        recognizer = ModelFactory.create_ocr_recognizer(ocr_cfg)
+        # Use the wrapped `ocr_model` when the recognizer has one, else the
+        # recognizer itself.
+        engine = getattr(recognizer, "ocr_model", recognizer)
+        if engine is None:
+            raise RuntimeError("ocr_model 未初始化")
+        _OCR_ENGINE = engine
+        return _OCR_ENGINE
+    except Exception as e:
+        errors.append(f"ModelFactory/MinerU: {e}")
+
+    # Attempt 2: explicit model paths from the environment (main() copies the
+    # --det-model-path / --rec-model-path options into these variables).
+    det_path = os.environ.get("OCR_DET_MODEL_PATH")
+    rec_path = os.environ.get("OCR_REC_MODEL_PATH")
+    if det_path or rec_path:
+        try:
+            from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR
+
+            kw: Dict[str, Any] = {"lang": ocr_cfg.get("language", "ch")}
+            if det_path:
+                kw["det_model_path"] = det_path
+            if rec_path:
+                kw["rec_model_path"] = rec_path
+            _OCR_ENGINE = PytorchPaddleOCR(**kw)
+            return _OCR_ENGINE
+        except Exception as e2:
+            errors.append(f"PytorchPaddleOCR(env paths): {e2}")
+
+    # Attempt 3: the optional paddleocr pip package.
+    try:
+        from paddleocr import PaddleOCR
+
+        _OCR_ENGINE = PaddleOCR(use_angle_cls=False, lang="ch", show_log=False)
+        return _OCR_ENGINE
+    except Exception as e3:
+        errors.append(
+            f"paddleocr pip(可选 pip install paddleocr): {e3}"
+        )
+
+    raise ImportError(
+        "无法加载 OCR 引擎。请在 mineru 环境中运行,并确保场景 YAML 中 ocr_recognition "
+        f"可正常初始化(与 main_v2 相同)。详情:\n  - " + "\n  - ".join(errors)
+    )
+
+
+def _parse_rec_item(rec_item: Any) -> Tuple[str, float]:
+    """Normalise one recognition result into ``(text, score)``.
+
+    Accepted shapes (lists and tuples are treated alike):
+      * ``(text, score)`` — a single recognised line;
+      * a sequence of such pairs, possibly nested — the texts are concatenated
+        and the score is the character-length-weighted mean of the parts.
+
+    Anything else, including ``None`` or an item whose text is blank after
+    stripping, yields ``("", 0.0)``.
+    """
+    if not isinstance(rec_item, (list, tuple)) or not rec_item:
+        return "", 0.0
+
+    if isinstance(rec_item[0], (list, tuple)):
+        # Container of results. Handled before the length check so a container
+        # holding a single pair, e.g. [("abc", 0.9)], is not dropped.
+        parsed = [_parse_rec_item(sub) for sub in rec_item]
+        parts = [(text, score) for text, score in parsed if text]
+        if not parts:
+            return "", 0.0
+        combined = "".join(text for text, _ in parts)
+        n_chars = sum(len(text) for text, _ in parts)
+        weighted = sum(len(text) * score for text, score in parts)
+        return combined, weighted / max(n_chars, 1)
+
+    if len(rec_item) < 2:
+        return "", 0.0
+    txt = str(rec_item[0] or "").strip()
+    if not txt:
+        return "", 0.0
+    return txt, float(rec_item[1] or 0.0)
+
+
+def _length_weighted_score(lines: List[Dict[str, Any]]) -> float:
+    """Return the mean line score weighted by text length (0.0 for no lines)."""
+    if not lines:
+        return 0.0
+    n_chars = sum(len(ln["text"]) for ln in lines)
+    return sum(len(ln["text"]) * ln["score"] for ln in lines) / max(n_chars, 1)
+
+
+def _ocr_cell(img_bgr: np.ndarray, *, det: bool = True, rec: bool = True) -> Dict[str, Any]:
+    """Whole-cell det+rec, similar to TextFiller._recognize_whole_cell.
+
+    Args:
+        img_bgr: Cell image in BGR channel order.
+        det: Request text detection. Ignored by the paddleocr pip backend.
+        rec: Request text recognition. Ignored by the paddleocr pip backend.
+
+    Returns:
+        A dict with ``text`` (all lines concatenated), ``score``
+        (length-weighted mean of line scores) and ``lines``. The paddleocr pip
+        backend adds ``backend``; other engines add ``mode``. If anything
+        raises, an empty result with ``error`` and ``hint`` is returned
+        instead of propagating, so one bad cell does not abort a batch.
+    """
+    try:
+        engine = _get_ocr_engine()
+        # paddleocr.PaddleOCR differs slightly from the PytorchPaddleOCR /
+        # MinerU interface: it is fed RGB and called with cls= here.
+        if engine.__class__.__name__ == "PaddleOCR":
+            rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+            res = engine.ocr(rgb, cls=False)
+            pip_lines: List[Dict[str, Any]] = []
+            if res and res[0]:
+                for item in res[0]:
+                    if item and len(item) >= 2:
+                        text, score = str(item[1][0]), float(item[1][1])
+                        pip_lines.append({"text": text, "score": score})
+            return {
+                "text": "".join(ln["text"] for ln in pip_lines),
+                "score": _length_weighted_score(pip_lines),
+                "lines": pip_lines,
+                "backend": "paddleocr",
+            }
+
+        res = engine.ocr(img_bgr, det=det, rec=rec)
+        lines: List[Dict[str, Any]] = []
+        # Without recognition there is no text to collect; parsing the bare
+        # boxes would turn coordinates into "text".
+        if rec and res and res[0]:
+            for item in res[0]:
+                if not item or len(item) < 2:
+                    continue
+                if det:
+                    # det+rec: each item is (box, rec_result).
+                    box, rec_part = item[0], item[1]
+                else:
+                    # rec only: the item itself is the (text, score) result and
+                    # there is no box (same layout cell121_sweep._ocr assumes).
+                    box, rec_part = None, item
+                text, score = _parse_rec_item(rec_part)
+                if text:
+                    lines.append({"text": text, "score": score, "box": box})
+        return {
+            "text": "".join(ln["text"] for ln in lines),
+            "score": _length_weighted_score(lines),
+            "lines": lines,
+            "mode": f"det={det},rec={rec}",
+        }
+    except Exception as e:
+        return {
+            "text": "",
+            "score": 0.0,
+            "lines": [],
+            "error": str(e),
+            "hint": "使用: conda activate mineru && python cell_preprocess_lab.py ...",
+        }
+
+
+def _median_denoise(img: np.ndarray) -> np.ndarray:
+    """Return ``img`` passed through ``cv2.medianBlur`` with aperture size 3."""
+    aperture = 3
+    return cv2.medianBlur(img, aperture)
+
+
+def _upscale_min_side(img: np.ndarray, min_side: int = 64) -> np.ndarray:
+    """Enlarge ``img`` uniformly until both sides reach ``min_side`` pixels.
+
+    The input object itself is returned when it is already large enough;
+    otherwise a bicubic-resized copy is returned. The image is never shrunk.
+    """
+    height, width = img.shape[:2]
+    if min(height, width) >= min_side:
+        return img
+    # max(..., 1) keeps a zero-length side from causing a division by zero.
+    factor = max(1.0, min_side / max(height, 1), min_side / max(width, 1))
+    return cv2.resize(
+        img, None, fx=factor, fy=factor, interpolation=cv2.INTER_CUBIC
+    )
+
+
+def run_cell_pipeline(
+    raw_bgr: np.ndarray,
+    *,
+    wm_method: str = "masked_adaptive",
+    apply_denoise: bool = False,
+    apply_contrast: bool = False,
+    upscale_min: int = 64,
+) -> Tuple[Dict[str, np.ndarray], List[str]]:
+    """Run the cell preprocessing stages and keep every intermediate image.
+
+    Stages, in order: ``00_raw`` → ``01_wm`` (watermark removal) → optional
+    ``NN_denoise`` → optional ``NN_contrast`` → ``NN_upscale``, where ``NN`` is
+    a running two-digit index starting at 02.
+
+    Args:
+        raw_bgr: Input cell image.
+        wm_method: Watermark-removal method placed in the watermark config.
+        apply_denoise: Insert the median-denoise stage.
+        apply_contrast: Insert the contrast-enhancement stage.
+        upscale_min: Minimum side length for the final upscale.
+
+    Returns:
+        ``(stages, order)``: a mapping from stage key to image, and the stage
+        keys in execution order.
+    """
+    stages: Dict[str, np.ndarray] = {}
+    order: List[str] = []
+
+    def _record(name: str, image: np.ndarray) -> None:
+        stages[name] = image
+        order.append(name)
+
+    _record("00_raw", raw_bgr.copy())
+
+    wm_cfg = merge_watermark_config("cell", {"enabled": True, "method": wm_method})
+    current, _ = WatermarkProcessor(wm_cfg, scope="cell").process(raw_bgr, force=True)
+    _record("01_wm", current)
+
+    next_index = 2
+    if apply_denoise:
+        current = _median_denoise(current)
+        _record(f"{next_index:02d}_denoise", current.copy())
+        next_index += 1
+
+    if apply_contrast:
+        # Contrast enhancement is given a single-channel image; convert only
+        # when the current image still has a channel axis.
+        if current.ndim == 3:
+            gray = cv2.cvtColor(current, cv2.COLOR_BGR2GRAY)
+        else:
+            gray = current
+        contrast_cfg = dict(wm_cfg.get("contrast_enhancement") or {})
+        contrast_cfg["enabled"] = True
+        enhanced = apply_contrast_enhancement_config(gray, contrast_cfg)
+        current = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2BGR)
+        _record(f"{next_index:02d}_contrast", current.copy())
+        next_index += 1
+
+    _record(f"{next_index:02d}_upscale", _upscale_min_side(current, upscale_min))
+    return stages, order
+
+
+def process_one(
+    input_path: Path,
+    output_dir: Path,
+    *,
+    compare_methods: bool = False,
+    run_ocr: bool = True,
+    apply_denoise: bool = False,
+    apply_contrast: bool = False,
+) -> Dict[str, Any]:
+    """Run the preprocessing pipeline on one cell image and write a JSON report.
+
+    Every stage image is saved as ``<stem>_<stage>.png``. With
+    ``compare_methods`` each watermark method gets its own sub-directory of
+    ``output_dir``. When ``run_ocr`` is true, OCR is run on ``00_raw``,
+    ``01_wm`` and the final stage.
+
+    Args:
+        input_path: Cell image to process.
+        output_dir: Directory for stage images and the report (created if needed).
+        compare_methods: Run both ``threshold`` and ``masked_adaptive``
+            instead of only ``masked_adaptive``.
+        run_ocr: Whether to OCR the selected stages.
+        apply_denoise: Forwarded to ``run_cell_pipeline``.
+        apply_contrast: Forwarded to ``run_cell_pipeline``.
+
+    Returns:
+        The report dict, with ``report_path`` added after the JSON was written.
+
+    Raises:
+        FileNotFoundError: ``cv2.imread`` could not load ``input_path``.
+    """
+    output_dir.mkdir(parents=True, exist_ok=True)
+    raw = cv2.imread(str(input_path))
+    if raw is None:
+        raise FileNotFoundError(f"无法读取: {input_path}")
+
+    report: Dict[str, Any] = {
+        "input": str(input_path),
+        "pipeline_note": (
+            "默认 01_wm→upscale,不做 median 去噪(小格易糊笔画)。"
+            "可用 --denoise / --contrast 对比。"
+        ),
+        "stages": {},
+    }
+    methods = ["threshold", "masked_adaptive"] if compare_methods else ["masked_adaptive"]
+
+    # Stages that are OCR'd in addition to the final (upscale) stage, which is
+    # always OCR'd.
+    ocr_keys = {"00_raw", "01_wm"}
+    for method in methods:
+        sub_dir = output_dir / method if compare_methods else output_dir
+        sub_dir.mkdir(parents=True, exist_ok=True)
+        stage_imgs, order = run_cell_pipeline(
+            raw,
+            wm_method=method,
+            apply_denoise=apply_denoise,
+            apply_contrast=apply_contrast,
+        )
+        method_report: Dict[str, Any] = {"files": {}, "ocr": {}}
+        final_key = order[-1]
+        for key in order:
+            out_path = sub_dir / f"{input_path.stem}_{key}.png"
+            cv2.imwrite(str(out_path), stage_imgs[key])
+            method_report["files"][key] = str(out_path)
+            if run_ocr and (key in ocr_keys or key == final_key):
+                method_report["ocr"][key] = _ocr_cell(stage_imgs[key])
+        if run_ocr:
+            wm_result = method_report["ocr"].get("01_wm")
+            final_result = method_report["ocr"].get(final_key)
+            # An OCR result dict is never empty, so `wm_result or final_result`
+            # could never fall back. Compare the recognised text instead:
+            # prefer 01_wm, but use the final stage when only it produced text.
+            wm_has_text = bool(wm_result and wm_result.get("text"))
+            final_has_text = bool(final_result and final_result.get("text"))
+            if final_has_text and not wm_has_text:
+                method_report["ocr_recommended"] = final_result
+            else:
+                method_report["ocr_recommended"] = wm_result or final_result
+        report["stages"][method] = method_report
+
+    report_path = output_dir / f"{input_path.stem}_lab_report.json"
+    with open(report_path, "w", encoding="utf-8") as f:
+        json.dump(report, f, ensure_ascii=False, indent=2)
+    # Added after dumping, so the file on disk does not contain its own path.
+    report["report_path"] = str(report_path)
+    return report
+
+
+def collect_inputs(path: Path) -> List[Path]:
+    """Resolve the CLI input into the list of cell images to process.
+
+    * A file is returned as-is (no suffix or name filtering).
+    * A directory yields its direct children, sorted, that are regular files
+      with a suffix in ``_IMAGE_SUFFIXES`` and ``"cell"`` in the file name.
+    * Anything else (e.g. a path that does not exist) yields ``[]`` so the
+      caller can report "not found" instead of crashing in ``iterdir()``.
+    """
+    if path.is_file():
+        return [path]
+    if not path.is_dir():
+        return []
+    return [
+        p
+        for p in sorted(path.iterdir())
+        if p.suffix.lower() in _IMAGE_SUFFIXES and "cell" in p.name and p.is_file()
+    ]
+
+
+def _build_arg_parser() -> argparse.ArgumentParser:
+    """Create the command-line parser for the cell preprocessing lab."""
+    cli = argparse.ArgumentParser(description="单元格预处理实验 lab")
+    cli.add_argument("input", type=Path, help="单元格 PNG 或 tablecell_ocr 目录")
+    cli.add_argument("-o", "--output", type=Path, required=True, help="输出目录")
+    cli.add_argument(
+        "-c",
+        "--config",
+        type=Path,
+        default=_DEFAULT_CONFIG,
+        help="场景 YAML(用于加载与 pipeline 相同的 OCR)",
+    )
+    cli.add_argument(
+        "--compare-methods",
+        action="store_true",
+        help="对比 threshold 与 masked_adaptive",
+    )
+    cli.add_argument("--no-ocr", action="store_true", help="跳过 OCR 探测")
+    cli.add_argument(
+        "--denoise",
+        action="store_true",
+        help="在去水印后增加 median 去噪(默认关闭,小图易损笔画)",
+    )
+    cli.add_argument(
+        "--contrast",
+        action="store_true",
+        help="在去噪/放大前增加 text_restore 对比度",
+    )
+    cli.add_argument(
+        "--det-model-path",
+        type=Path,
+        default=None,
+        help="覆盖检测模型 .pth(或设环境变量 OCR_DET_MODEL_PATH)",
+    )
+    cli.add_argument(
+        "--rec-model-path",
+        type=Path,
+        default=None,
+        help="覆盖识别模型 .pth(或设环境变量 OCR_REC_MODEL_PATH)",
+    )
+    return cli
+
+
+def main() -> None:
+    """CLI entry point: parse arguments, process each input, print each report.
+
+    Stores the chosen config path in ``_CONFIG_PATH`` and copies any model-path
+    overrides into the OCR_DET_MODEL_PATH / OCR_REC_MODEL_PATH environment
+    variables, both of which ``_get_ocr_engine`` reads. Exits with status 1
+    when no input image is found.
+    """
+    global _CONFIG_PATH
+    args = _build_arg_parser().parse_args()
+    _CONFIG_PATH = args.config
+
+    overrides = (
+        ("OCR_DET_MODEL_PATH", args.det_model_path),
+        ("OCR_REC_MODEL_PATH", args.rec_model_path),
+    )
+    for env_name, model_path in overrides:
+        if model_path:
+            os.environ[env_name] = str(model_path)
+
+    inputs = collect_inputs(args.input)
+    if not inputs:
+        print(f"未找到输入: {args.input}")
+        sys.exit(1)
+
+    # With several inputs each image gets its own sub-directory named after
+    # its stem; a single input writes straight into the output directory.
+    batch_mode = len(inputs) > 1
+    for inp in inputs:
+        target_dir = args.output / inp.stem if batch_mode else args.output
+        report = process_one(
+            inp,
+            target_dir,
+            compare_methods=args.compare_methods,
+            run_ocr=not args.no_ocr,
+            apply_denoise=args.denoise,
+            apply_contrast=args.contrast,
+        )
+        print(json.dumps(report, ensure_ascii=False, indent=2))
+
+
+if __name__ == "__main__":
+    # With no CLI arguments, synthesise an argv from the built-in defaults
+    # below so the script can be launched directly (e.g. from an IDE).
+    if len(sys.argv) == 1:
+        print("ℹ️  未提供命令行参数,使用默认配置运行...")
+        default_config = {
+            # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell029_whole_78.0111.0111.078.0司.png",
+            "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell121_empty_empty.png",
+            # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell217_lines_取款.png",
+            # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell219_empty_empty.png",
+            "output": "./output/彭_广东兴宁农村商业银行",
+            "compare-methods": True,
+        }
+        # "input" is the positional argument; every other key becomes a --flag.
+        argv = [sys.argv[0], default_config["input"]]
+        for key, value in default_config.items():
+            if key == "input":
+                continue
+            flag = f"--{key.replace('_', '-')}"
+            if isinstance(value, bool):
+                # Booleans are store_true switches: emit the bare flag only
+                # when enabled, and nothing when disabled.
+                if value:
+                    argv.append(flag)
+                continue
+            argv += [flag, str(value)]
+        sys.argv = argv
+
+    sys.exit(main())