| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385 |
- #!/usr/bin/env python3
- """
- 单元格裁剪图预处理实验:去水印 →(可选去噪/对比度)→ 放大 → OCR。
- 与 pipeline 二次 OCR 对齐,使用 ocr_tools.pytorch_models.PytorchPaddleOCR(非 paddleocr pip 包)。
- 用法:
- python cell_preprocess_lab.py cell219.png -o /tmp/cell_lab
- python cell_preprocess_lab.py /path/to/tablecell_ocr/ -o /tmp/batch --compare-methods
- python cell_preprocess_lab.py cell217.png -o /tmp/out --denoise --contrast
- """
- from __future__ import annotations
- import argparse
- import json
- import os
- import sys
- from pathlib import Path
- from typing import Any, Dict, List, Optional, Tuple
- import cv2
- import numpy as np
- import yaml
# Bootstrap sys.path before the project-local imports below: both the
# repository root (two directories above this file's folder) and the
# universal_doc_parser package root are pushed to the front of sys.path.
_repo_root = Path(__file__).resolve().parents[2]
_parser_root = _repo_root / "ocr_tools" / "universal_doc_parser"
for _p in (_repo_root, _parser_root):
    _p_str = str(_p)
    if _p_str in sys.path:
        continue
    sys.path.insert(0, _p_str)
- from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config
- from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
# Scene YAML used when no --config is given on the command line.
_DEFAULT_CONFIG = _repo_root / (
    "ocr_tools/universal_doc_parser/config/bank_statement_yusys_local.yaml"
)

# Lower-cased file suffixes accepted as image inputs by collect_inputs().
_IMAGE_SUFFIXES = {
    ".png",
    ".jpg",
    ".jpeg",
    ".bmp",
    ".tif",
    ".tiff",
    ".webp",
}

# Lazily created OCR engine, cached by _get_ocr_engine().
_OCR_ENGINE: Any = None

# Scene config chosen via the CLI; None means "fall back to _DEFAULT_CONFIG".
_CONFIG_PATH: Optional[Path] = None
def _get_ocr_engine() -> Any:
    """Return the process-wide OCR engine, creating it on first use.

    Per the original note this mirrors the main_v2 pipeline
    (ModelFactory -> MinerU OCR atom model). Loaders are tried in order:

    1. ``core.model_factory.ModelFactory`` configured from the scene YAML's
       ``ocr_recognition`` section;
    2. ``PytorchPaddleOCR`` with explicit model paths, only when
       ``OCR_DET_MODEL_PATH`` and/or ``OCR_REC_MODEL_PATH`` is set;
    3. the pip ``paddleocr`` package.

    The first loader that succeeds is cached in ``_OCR_ENGINE``.

    Raises:
        FileNotFoundError: the scene YAML does not exist.
        ImportError: every loader failed; the message lists each failure.
    """
    global _OCR_ENGINE
    if _OCR_ENGINE is not None:
        return _OCR_ENGINE

    config_file = _CONFIG_PATH or _DEFAULT_CONFIG
    if not config_file.is_file():
        raise FileNotFoundError(f"场景配置不存在: {config_file}")
    with open(config_file, encoding="utf-8") as handle:
        scene = yaml.safe_load(handle) or {}
    ocr_cfg = scene.get("ocr_recognition") or {}

    def _via_model_factory() -> Any:
        from core.model_factory import ModelFactory

        recognizer = ModelFactory.create_ocr_recognizer(ocr_cfg)
        # Prefer the wrapped model when the recognizer exposes one.
        candidate = getattr(recognizer, "ocr_model", recognizer)
        if candidate is None:
            raise RuntimeError("ocr_model 未初始化")
        return candidate

    def _via_env_paths() -> Any:
        det_path = os.environ.get("OCR_DET_MODEL_PATH")
        rec_path = os.environ.get("OCR_REC_MODEL_PATH")
        if not (det_path or rec_path):
            # No override requested: skip this loader without recording an error.
            return None
        from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR

        options: Dict[str, Any] = {"lang": ocr_cfg.get("language", "ch")}
        if det_path:
            options["det_model_path"] = det_path
        if rec_path:
            options["rec_model_path"] = rec_path
        return PytorchPaddleOCR(**options)

    def _via_paddleocr_pip() -> Any:
        from paddleocr import PaddleOCR

        return PaddleOCR(use_angle_cls=False, lang="ch", show_log=False)

    loaders = (
        ("ModelFactory/MinerU", _via_model_factory),
        ("PytorchPaddleOCR(env paths)", _via_env_paths),
        ("paddleocr pip(可选 pip install paddleocr)", _via_paddleocr_pip),
    )
    failures: List[str] = []
    for label, loader in loaders:
        try:
            loaded = loader()
        except Exception as exc:  # deliberate: fall through to the next loader
            failures.append(f"{label}: {exc}")
            continue
        if loaded is not None:
            _OCR_ENGINE = loaded
            return _OCR_ENGINE

    raise ImportError(
        "无法加载 OCR 引擎。请在 mineru 环境中运行,并确保场景 YAML 中 ocr_recognition "
        f"可正常初始化(与 main_v2 相同)。详情:\n - " + "\n - ".join(failures)
    )
- def _parse_rec_item(rec_item: Any) -> Tuple[str, float]:
- if rec_item is None:
- return "", 0.0
- if isinstance(rec_item, tuple) and len(rec_item) >= 2:
- txt = str(rec_item[0] or "").strip()
- sc = float(rec_item[1] or 0.0)
- return txt, 0.0 if not txt else sc
- if isinstance(rec_item, list) and len(rec_item) >= 2:
- if isinstance(rec_item[0], (list, tuple)):
- parts: List[str] = []
- scores: List[float] = []
- for item in rec_item:
- t, s = _parse_rec_item(item)
- if t:
- parts.append(t)
- scores.append(s)
- if not parts:
- return "", 0.0
- combined = "".join(parts)
- n = sum(len(t) for t in parts)
- return combined, sum(len(t) * s for t, s in zip(parts, scores)) / max(n, 1)
- txt = str(rec_item[0] or "").strip()
- sc = float(rec_item[1] or 0.0)
- return txt, 0.0 if not txt else sc
- return "", 0.0
def _ocr_cell(img_bgr: np.ndarray, *, det: bool = True, rec: bool = True) -> Dict[str, Any]:
    """Run OCR over a whole cell image and summarise the result.

    Per the original note this is similar to
    ``TextFiller._recognize_whole_cell`` (det + rec on the full cell).

    Args:
        img_bgr: cell image; converted BGR->RGB before being handed to the
            pip ``PaddleOCR`` backend, passed as-is to other engines.
        det: forwarded to ``engine.ocr`` (ignored by the pip PaddleOCR path).
        rec: forwarded to ``engine.ocr`` (ignored by the pip PaddleOCR path).

    Returns:
        A dict with ``text`` (all line texts concatenated), ``score``
        (length-weighted mean of line scores, ``0.0`` when nothing was
        found) and ``lines``. Successful calls add ``backend`` or ``mode``;
        on any exception the dict carries ``error`` and ``hint`` instead
        and the function does not raise.
    """

    def _summarise(entries: List[Dict[str, Any]]) -> Tuple[str, float]:
        # Concatenated text plus the length-weighted mean score.
        joined = "".join(entry["text"] for entry in entries)
        if not entries:
            return joined, 0.0
        weighted = sum(len(entry["text"]) * entry["score"] for entry in entries)
        return joined, weighted / max(len(joined), 1)

    try:
        engine = _get_ocr_engine()

        # The pip paddleocr.PaddleOCR interface differs slightly from
        # PytorchPaddleOCR / MinerU, so it is handled separately.
        if engine.__class__.__name__ == "PaddleOCR":
            rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
            paddle_res = engine.ocr(rgb, cls=False)
            paddle_lines = []
            if paddle_res and paddle_res[0]:
                for hit in paddle_res[0]:
                    if hit and len(hit) >= 2:
                        hit_text = str(hit[1][0])
                        hit_score = float(hit[1][1])
                        paddle_lines.append({"text": hit_text, "score": hit_score})
            full_text, mean_score = _summarise(paddle_lines)
            return {
                "text": full_text,
                "score": mean_score,
                "lines": paddle_lines,
                "backend": "paddleocr",
            }

        result = engine.ocr(img_bgr, det=det, rec=rec)
        found: List[Dict[str, Any]] = []
        if result and result[0]:
            for hit in result[0]:
                if not hit or len(hit) < 2:
                    continue
                line_text, line_score = _parse_rec_item(hit[1])
                if not line_text:
                    continue
                found.append({"text": line_text, "score": line_score, "box": hit[0]})
        full_text, mean_score = _summarise(found)
        return {
            "text": full_text,
            "score": mean_score,
            "lines": found,
            "mode": f"det={det},rec={rec}",
        }
    except Exception as exc:
        # Best effort: report the failure inside the result instead of raising.
        return {
            "text": "",
            "score": 0.0,
            "lines": [],
            "error": str(exc),
            "hint": "使用: conda activate mineru && python cell_preprocess_lab.py ...",
        }
def _median_denoise(img: np.ndarray) -> np.ndarray:
    """Return ``img`` filtered with a median blur of aperture size 3."""
    aperture = 3
    blurred = cv2.medianBlur(img, aperture)
    return blurred
- def _upscale_min_side(img: np.ndarray, min_side: int = 64) -> np.ndarray:
- h, w = img.shape[:2]
- if h >= min_side and w >= min_side:
- return img
- scale = max(min_side / max(h, 1), min_side / max(w, 1), 1.0)
- return cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
def run_cell_pipeline(
    raw_bgr: np.ndarray,
    *,
    wm_method: str = "masked_adaptive",
    apply_denoise: bool = False,
    apply_contrast: bool = False,
    upscale_min: int = 64,
) -> Tuple[Dict[str, np.ndarray], List[str]]:
    """Run the cell preprocessing chain and keep every intermediate image.

    Stages, in order: raw copy -> watermark removal -> optional median
    denoise -> optional contrast enhancement -> upscale to ``upscale_min``.
    Stage keys are ``NN_name`` where ``NN`` is the stage's position, so the
    numbering stays contiguous when optional stages are skipped.

    Args:
        raw_bgr: input cell image (copied, not modified, for ``00_raw``).
        wm_method: watermark removal method passed to the watermark config.
        apply_denoise: insert the median-denoise stage.
        apply_contrast: insert the contrast-enhancement stage.
        upscale_min: minimum side length for the final upscale stage.

    Returns:
        ``(stages, order)``: a mapping from stage key to image, and the
        list of stage keys in execution order.
    """
    stages: Dict[str, np.ndarray] = {"00_raw": raw_bgr.copy()}
    order: List[str] = ["00_raw"]

    def _record(label: str, image: np.ndarray) -> None:
        stages[label] = image
        order.append(label)

    # Watermark removal is always run (force=True) with the cell-scope config.
    wm_cfg = merge_watermark_config("cell", {"enabled": True, "method": wm_method})
    processor = WatermarkProcessor(wm_cfg, scope="cell")
    current, _ = processor.process(raw_bgr, force=True)
    _record("01_wm", current)

    if apply_denoise:
        current = _median_denoise(current)
        # len(order) is the index of the stage about to be recorded.
        _record(f"{len(order):02d}_denoise", current.copy())

    if apply_contrast:
        if current.ndim == 3:
            gray = cv2.cvtColor(current, cv2.COLOR_BGR2GRAY)
        else:
            gray = current
        contrast_cfg = dict(wm_cfg.get("contrast_enhancement") or {})
        contrast_cfg["enabled"] = True
        gray = apply_contrast_enhancement_config(gray, contrast_cfg)
        current = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
        _record(f"{len(order):02d}_contrast", current.copy())

    current = _upscale_min_side(current, upscale_min)
    _record(f"{len(order):02d}_upscale", current)
    return stages, order
def process_one(
    input_path: Path,
    output_dir: Path,
    *,
    compare_methods: bool = False,
    run_ocr: bool = True,
    apply_denoise: bool = False,
    apply_contrast: bool = False,
) -> Dict[str, Any]:
    """Preprocess one cell image, dump every stage and write a JSON report.

    Args:
        input_path: image file to read.
        output_dir: directory for stage PNGs and the report (created if
            missing). With ``compare_methods`` each watermark method gets
            its own sub-directory.
        compare_methods: run both ``threshold`` and ``masked_adaptive``
            instead of only ``masked_adaptive``.
        run_ocr: OCR the ``00_raw``, ``01_wm`` and final stages.
        apply_denoise: forwarded to ``run_cell_pipeline``.
        apply_contrast: forwarded to ``run_cell_pipeline``.

    Returns:
        The report dict, including ``report_path`` of the written JSON file.

    Raises:
        FileNotFoundError: the image could not be read.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    raw = cv2.imread(str(input_path))
    if raw is None:
        raise FileNotFoundError(f"无法读取: {input_path}")

    report: Dict[str, Any] = {
        "input": str(input_path),
        "pipeline_note": (
            "默认 01_wm→upscale,不做 median 去噪(小格易糊笔画)。"
            "可用 --denoise / --contrast 对比。"
        ),
        "stages": {},
    }

    if compare_methods:
        methods = ["threshold", "masked_adaptive"]
    else:
        methods = ["masked_adaptive"]
    # Stages OCR'd besides the final (upscale) stage, which is always OCR'd.
    ocr_keys = {"00_raw", "01_wm"}

    for method in methods:
        target_dir = output_dir / method if compare_methods else output_dir
        target_dir.mkdir(parents=True, exist_ok=True)
        stage_imgs, order = run_cell_pipeline(
            raw,
            wm_method=method,
            apply_denoise=apply_denoise,
            apply_contrast=apply_contrast,
        )
        final_key = order[-1]
        files: Dict[str, str] = {}
        ocr_results: Dict[str, Any] = {}
        method_report: Dict[str, Any] = {"files": files, "ocr": ocr_results}

        for stage_key in order:
            stage_path = target_dir / f"{input_path.stem}_{stage_key}.png"
            cv2.imwrite(str(stage_path), stage_imgs[stage_key])
            files[stage_key] = str(stage_path)
            wants_ocr = stage_key in ocr_keys or stage_key == final_key
            if run_ocr and wants_ocr:
                ocr_results[stage_key] = _ocr_cell(stage_imgs[stage_key])

        if run_ocr:
            # Prefer the watermark-removed stage, else the final stage.
            recommended = ocr_results.get("01_wm") or ocr_results.get(final_key)
            method_report["ocr_recommended"] = recommended
        report["stages"][method] = method_report

    report_path = output_dir / f"{input_path.stem}_lab_report.json"
    with open(report_path, "w", encoding="utf-8") as out_file:
        json.dump(report, out_file, ensure_ascii=False, indent=2)
    # Added after dumping, so the JSON file itself does not contain it.
    report["report_path"] = str(report_path)
    return report
def collect_inputs(path: Path) -> List[Path]:
    """Resolve the CLI input into a list of image files to process.

    Args:
        path: a single file, or a directory to scan (non-recursively).

    Returns:
        ``[path]`` when ``path`` is a file. For a directory, the regular
        files whose name contains ``"cell"`` and whose suffix is in
        ``_IMAGE_SUFFIXES``, sorted by path. An empty list when ``path``
        is neither a file nor a directory, so the caller can report
        "no input found" instead of crashing in ``iterdir()``.
    """
    if path.is_file():
        return [path]
    if not path.is_dir():
        return []
    return [
        entry
        for entry in sorted(path.iterdir())
        # is_file() first: a sub-directory named like an image is not an input.
        if entry.is_file()
        and "cell" in entry.name
        and entry.suffix.lower() in _IMAGE_SUFFIXES
    ]
def main() -> None:
    """Command-line entry point: parse arguments and process every input.

    Side effects: stores the chosen scene config in ``_CONFIG_PATH``,
    exports ``OCR_DET_MODEL_PATH`` / ``OCR_REC_MODEL_PATH`` when the
    matching options are given, prints each report as JSON, and exits with
    status 1 when no input image is found.
    """
    global _CONFIG_PATH

    parser = argparse.ArgumentParser(description="单元格预处理实验 lab")
    add = parser.add_argument
    add("input", type=Path, help="单元格 PNG 或 tablecell_ocr 目录")
    add("-o", "--output", type=Path, required=True, help="输出目录")
    add(
        "-c",
        "--config",
        type=Path,
        default=_DEFAULT_CONFIG,
        help="场景 YAML(用于加载与 pipeline 相同的 OCR)",
    )
    add(
        "--compare-methods",
        action="store_true",
        help="对比 threshold 与 masked_adaptive",
    )
    add("--no-ocr", action="store_true", help="跳过 OCR 探测")
    add(
        "--denoise",
        action="store_true",
        help="在去水印后增加 median 去噪(默认关闭,小图易损笔画)",
    )
    add(
        "--contrast",
        action="store_true",
        help="在去噪/放大前增加 text_restore 对比度",
    )
    add(
        "--det-model-path",
        type=Path,
        default=None,
        help="覆盖检测模型 .pth(或设环境变量 OCR_DET_MODEL_PATH)",
    )
    add(
        "--rec-model-path",
        type=Path,
        default=None,
        help="覆盖识别模型 .pth(或设环境变量 OCR_REC_MODEL_PATH)",
    )
    args = parser.parse_args()

    _CONFIG_PATH = args.config
    # Model-path overrides are handed to _get_ocr_engine() via the environment.
    env_overrides = (
        ("OCR_DET_MODEL_PATH", args.det_model_path),
        ("OCR_REC_MODEL_PATH", args.rec_model_path),
    )
    for env_name, override in env_overrides:
        if override:
            os.environ[env_name] = str(override)

    inputs = collect_inputs(args.input)
    if not inputs:
        print(f"未找到输入: {args.input}")
        sys.exit(1)

    # In batch mode each image gets its own sub-directory named after its stem.
    batch_mode = len(inputs) > 1
    for image_path in inputs:
        destination = args.output / image_path.stem if batch_mode else args.output
        report = process_one(
            image_path,
            destination,
            compare_methods=args.compare_methods,
            run_ocr=not args.no_ocr,
            apply_denoise=args.denoise,
            apply_contrast=args.contrast,
        )
        print(json.dumps(report, ensure_ascii=False, indent=2))
if __name__ == "__main__":
    # Developer convenience: when launched without any arguments, synthesise
    # a command line from the defaults below before handing over to main().
    if len(sys.argv) == 1:
        print("ℹ️ 未提供命令行参数,使用默认配置运行...")
        # Keys other than "input" become --flags (underscores -> dashes);
        # True booleans become bare switches, False booleans are omitted.
        default_config = {
            # Alternative sample inputs kept for quick switching:
            # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell029_whole_78.0111.0111.078.0司.png",
            "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell121_empty_empty.png",
            # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell217_lines_取款.png",
            # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell219_empty_empty.png",
            "output": "./output/彭_广东兴宁农村商业银行",
            "compare-methods": True,
        }
        fallback_argv = [sys.argv[0], default_config["input"]]
        for option, setting in default_config.items():
            if option == "input":
                continue
            flag = "--" + option.replace("_", "-")
            if isinstance(setting, bool):
                if setting:
                    fallback_argv.append(flag)
            else:
                fallback_argv.extend([flag, str(setting)])
        sys.argv = fallback_argv
    sys.exit(main())
|