#!/usr/bin/env python3 """ 单元格裁剪图预处理实验:去水印 →(可选去噪/对比度)→ 放大 → OCR。 与 pipeline 二次 OCR 对齐,使用 ocr_tools.pytorch_models.PytorchPaddleOCR(非 paddleocr pip 包)。 用法: python cell_preprocess_lab.py cell219.png -o /tmp/cell_lab python cell_preprocess_lab.py /path/to/tablecell_ocr/ -o /tmp/batch --compare-methods python cell_preprocess_lab.py cell217.png -o /tmp/out --denoise --contrast """ from __future__ import annotations import argparse import json import os import sys from pathlib import Path from typing import Any, Dict, List, Optional, Tuple import cv2 import numpy as np import yaml _repo_root = Path(__file__).resolve().parents[2] _parser_root = _repo_root / "ocr_tools" / "universal_doc_parser" for _p in (_repo_root, _parser_root): if str(_p) not in sys.path: sys.path.insert(0, str(_p)) from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config from ocr_utils.watermark.contrast import apply_contrast_enhancement_config _DEFAULT_CONFIG = ( _repo_root / "ocr_tools/universal_doc_parser/config/bank_statement_yusys_local.yaml" ) _IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"} _OCR_ENGINE: Any = None _CONFIG_PATH: Optional[Path] = None def _get_ocr_engine() -> Any: """与 main_v2 pipeline 相同:ModelFactory → MinerU OCR atom model。""" global _OCR_ENGINE if _OCR_ENGINE is not None: return _OCR_ENGINE cfg_path = _CONFIG_PATH or _DEFAULT_CONFIG if not cfg_path.is_file(): raise FileNotFoundError(f"场景配置不存在: {cfg_path}") with open(cfg_path, encoding="utf-8") as f: raw = yaml.safe_load(f) or {} ocr_cfg = raw.get("ocr_recognition") or {} errors: List[str] = [] try: from core.model_factory import ModelFactory recognizer = ModelFactory.create_ocr_recognizer(ocr_cfg) engine = getattr(recognizer, "ocr_model", recognizer) if engine is None: raise RuntimeError("ocr_model 未初始化") _OCR_ENGINE = engine return _OCR_ENGINE except Exception as e: errors.append(f"ModelFactory/MinerU: {e}") det_path = os.environ.get("OCR_DET_MODEL_PATH") rec_path = os.environ.get("OCR_REC_MODEL_PATH") if det_path or rec_path: try: from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR kw: Dict[str, Any] = {"lang": ocr_cfg.get("language", "ch")} if det_path: kw["det_model_path"] = det_path if rec_path: kw["rec_model_path"] = rec_path _OCR_ENGINE = PytorchPaddleOCR(**kw) return _OCR_ENGINE except Exception as e2: errors.append(f"PytorchPaddleOCR(env paths): {e2}") try: from paddleocr import PaddleOCR _OCR_ENGINE = PaddleOCR(use_angle_cls=False, lang="ch", show_log=False) return _OCR_ENGINE except Exception as e3: errors.append( f"paddleocr pip(可选 pip install paddleocr): {e3}" ) raise ImportError( "无法加载 OCR 引擎。请在 mineru 环境中运行,并确保场景 YAML 中 ocr_recognition " f"可正常初始化(与 main_v2 相同)。详情:\n - " + "\n - ".join(errors) ) def _parse_rec_item(rec_item: Any) -> Tuple[str, float]: if rec_item is None: return "", 0.0 if isinstance(rec_item, tuple) and len(rec_item) >= 2: txt = str(rec_item[0] or "").strip() sc = float(rec_item[1] or 0.0) return txt, 0.0 if not txt else sc if isinstance(rec_item, list) and len(rec_item) >= 2: if isinstance(rec_item[0], (list, tuple)): parts: List[str] = [] scores: List[float] = [] for item in rec_item: t, s = _parse_rec_item(item) if t: parts.append(t) scores.append(s) if not parts: return "", 0.0 combined = "".join(parts) n = sum(len(t) for t in parts) return combined, sum(len(t) * s for t, s in zip(parts, scores)) / max(n, 1) txt = str(rec_item[0] or "").strip() sc = float(rec_item[1] or 0.0) return txt, 0.0 if not txt else sc return "", 0.0 def _ocr_cell(img_bgr: np.ndarray, *, det: bool = True, rec: bool = True) -> Dict[str, Any]: """整格 det+rec,与 TextFiller._recognize_whole_cell 类似。""" try: engine = _get_ocr_engine() # paddleocr.PaddleOCR 与 PytorchPaddleOCR / MinerU 接口略有差异 if engine.__class__.__name__ == "PaddleOCR": rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) res = engine.ocr(rgb, cls=False) lines = [] if res and res[0]: for item in res[0]: if item and len(item) >= 2: text, score = str(item[1][0]), float(item[1][1]) lines.append({"text": text, "score": score}) text = "".join(ln["text"] for ln in lines) sc = ( sum(len(ln["text"]) * ln["score"] for ln in lines) / max(len(text), 1) if lines else 0.0 ) return {"text": text, "score": sc, "lines": lines, "backend": "paddleocr"} res = engine.ocr(img_bgr, det=det, rec=rec) lines: List[Dict[str, Any]] = [] if res and res[0]: for item in res[0]: if not item or len(item) < 2: continue box, rec_part = item[0], item[1] text, score = _parse_rec_item(rec_part) if text: lines.append({"text": text, "score": score, "box": box}) text = "".join(ln["text"] for ln in lines) score = ( sum(len(ln["text"]) * ln["score"] for ln in lines) / max(len(text), 1) if lines else 0.0 ) return {"text": text, "score": score, "lines": lines, "mode": f"det={det},rec={rec}"} except Exception as e: return { "text": "", "score": 0.0, "lines": [], "error": str(e), "hint": "使用: conda activate mineru && python cell_preprocess_lab.py ...", } def _median_denoise(img: np.ndarray) -> np.ndarray: return cv2.medianBlur(img, 3) def _upscale_min_side(img: np.ndarray, min_side: int = 64) -> np.ndarray: h, w = img.shape[:2] if h >= min_side and w >= min_side: return img scale = max(min_side / max(h, 1), min_side / max(w, 1), 1.0) return cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC) def run_cell_pipeline( raw_bgr: np.ndarray, *, wm_method: str = "masked_adaptive", apply_denoise: bool = False, apply_contrast: bool = False, upscale_min: int = 64, ) -> Tuple[Dict[str, np.ndarray], List[str]]: stages: Dict[str, np.ndarray] = {"00_raw": raw_bgr.copy()} order: List[str] = ["00_raw"] wm_cfg = merge_watermark_config("cell", {"enabled": True, "method": wm_method}) proc = WatermarkProcessor(wm_cfg, scope="cell") img, _ = proc.process(raw_bgr, force=True) stages["01_wm"] = img order.append("01_wm") step = 2 if apply_denoise: img = _median_denoise(img) key = f"{step:02d}_denoise" stages[key] = img.copy() order.append(key) step += 1 if apply_contrast: gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img ce = dict(wm_cfg.get("contrast_enhancement") or {}) ce["enabled"] = True gray = apply_contrast_enhancement_config(gray, ce) img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR) key = f"{step:02d}_contrast" stages[key] = img.copy() order.append(key) step += 1 img = _upscale_min_side(img, upscale_min) key = f"{step:02d}_upscale" stages[key] = img order.append(key) return stages, order def process_one( input_path: Path, output_dir: Path, *, compare_methods: bool = False, run_ocr: bool = True, apply_denoise: bool = False, apply_contrast: bool = False, ) -> Dict[str, Any]: output_dir.mkdir(parents=True, exist_ok=True) raw = cv2.imread(str(input_path)) if raw is None: raise FileNotFoundError(f"无法读取: {input_path}") report: Dict[str, Any] = { "input": str(input_path), "pipeline_note": ( "默认 01_wm→upscale,不做 median 去噪(小格易糊笔画)。" "可用 --denoise / --contrast 对比。" ), "stages": {}, } methods = ["threshold", "masked_adaptive"] if compare_methods else ["masked_adaptive"] ocr_keys = {"00_raw", "01_wm"} # 总是 OCR 最终 upscale 阶段 for method in methods: sub_dir = output_dir / method if compare_methods else output_dir sub_dir.mkdir(parents=True, exist_ok=True) stage_imgs, order = run_cell_pipeline( raw, wm_method=method, apply_denoise=apply_denoise, apply_contrast=apply_contrast, ) method_report: Dict[str, Any] = {"files": {}, "ocr": {}} final_key = order[-1] for key in order: out_path = sub_dir / f"{input_path.stem}_{key}.png" cv2.imwrite(str(out_path), stage_imgs[key]) method_report["files"][key] = str(out_path) if run_ocr and (key in ocr_keys or key == final_key): method_report["ocr"][key] = _ocr_cell(stage_imgs[key]) if run_ocr: method_report["ocr_recommended"] = method_report["ocr"].get( "01_wm" ) or method_report["ocr"].get(final_key) report["stages"][method] = method_report report_path = output_dir / f"{input_path.stem}_lab_report.json" with open(report_path, "w", encoding="utf-8") as f: json.dump(report, f, ensure_ascii=False, indent=2) report["report_path"] = str(report_path) return report def collect_inputs(path: Path) -> List[Path]: if path.is_file(): return [path] files: List[Path] = [] for p in sorted(path.iterdir()): if p.suffix.lower() in _IMAGE_SUFFIXES and "cell" in p.name: files.append(p) return files def main() -> None: global _CONFIG_PATH parser = argparse.ArgumentParser(description="单元格预处理实验 lab") parser.add_argument("input", type=Path, help="单元格 PNG 或 tablecell_ocr 目录") parser.add_argument("-o", "--output", type=Path, required=True, help="输出目录") parser.add_argument( "-c", "--config", type=Path, default=_DEFAULT_CONFIG, help="场景 YAML(用于加载与 pipeline 相同的 OCR)", ) parser.add_argument( "--compare-methods", action="store_true", help="对比 threshold 与 masked_adaptive", ) parser.add_argument("--no-ocr", action="store_true", help="跳过 OCR 探测") parser.add_argument( "--denoise", action="store_true", help="在去水印后增加 median 去噪(默认关闭,小图易损笔画)", ) parser.add_argument( "--contrast", action="store_true", help="在去噪/放大前增加 text_restore 对比度", ) parser.add_argument( "--det-model-path", type=Path, default=None, help="覆盖检测模型 .pth(或设环境变量 OCR_DET_MODEL_PATH)", ) parser.add_argument( "--rec-model-path", type=Path, default=None, help="覆盖识别模型 .pth(或设环境变量 OCR_REC_MODEL_PATH)", ) args = parser.parse_args() _CONFIG_PATH = args.config if args.det_model_path: os.environ["OCR_DET_MODEL_PATH"] = str(args.det_model_path) if args.rec_model_path: os.environ["OCR_REC_MODEL_PATH"] = str(args.rec_model_path) inputs = collect_inputs(args.input) if not inputs: print(f"未找到输入: {args.input}") sys.exit(1) for inp in inputs: out = args.output / inp.stem if len(inputs) > 1 else args.output report = process_one( inp, out, compare_methods=args.compare_methods, run_ocr=not args.no_ocr, apply_denoise=args.denoise, apply_contrast=args.contrast, ) print(json.dumps(report, ensure_ascii=False, indent=2)) if __name__ == "__main__": if len(sys.argv) == 1: print("ℹ️ 未提供命令行参数,使用默认配置运行...") default_config = { # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell029_whole_78.0111.0111.078.0司.png", "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell121_empty_empty.png", # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell217_lines_取款.png", # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell219_empty_empty.png", "output": "./output/彭_广东兴宁农村商业银行", "compare-methods": True, } sys.argv = [sys.argv[0], default_config["input"]] for key, value in default_config.items(): if key == "input": continue flag = f"--{key.replace('_', '-')}" if isinstance(value, bool) and value: sys.argv.append(flag) elif not isinstance(value, bool): sys.argv.extend([flag, str(value)]) sys.exit(main())