cell_preprocess_lab.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388
  1. #!/usr/bin/env python3
  2. """
  3. 单元格裁剪图预处理实验:去水印 →(可选去噪/对比度)→ 放大 → OCR。
  4. 与 pipeline 二次 OCR 对齐:优先使用与 main_v2 相同的 OCR 引擎(ModelFactory / ocr_tools.pytorch_models.PytorchPaddleOCR),paddleocr pip 包仅作最后兜底。
  5. 用法:
  6. python cell_preprocess_lab.py cell219.png -o /tmp/cell_lab
  7. python cell_preprocess_lab.py /path/to/tablecell_ocr/ -o /tmp/batch --compare-methods
  8. python cell_preprocess_lab.py cell217.png -o /tmp/out --denoise --contrast
  9. 参数网格扫描见 cell_sweep.py:
  10. python cell_sweep.py cell219_empty_empty_raw.png -o ./out -t "ATM存折取款"
  11. """
  12. from __future__ import annotations
  13. import argparse
  14. import json
  15. import os
  16. import sys
  17. from pathlib import Path
  18. from typing import Any, Dict, List, Optional, Tuple
  19. import cv2
  20. import numpy as np
  21. import yaml
  22. _repo_root = Path(__file__).resolve().parents[2]
  23. _parser_root = _repo_root / "ocr_tools" / "universal_doc_parser"
  24. for _p in (_repo_root, _parser_root):
  25. if str(_p) not in sys.path:
  26. sys.path.insert(0, str(_p))
  27. from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config
  28. from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
  29. _DEFAULT_CONFIG = (
  30. _repo_root
  31. / "ocr_tools/universal_doc_parser/config/bank_statement_yusys_local.yaml"
  32. )
  33. _IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
  34. _OCR_ENGINE: Any = None
  35. _CONFIG_PATH: Optional[Path] = None
  36. def _get_ocr_engine() -> Any:
  37. """与 main_v2 pipeline 相同:ModelFactory → MinerU OCR atom model。"""
  38. global _OCR_ENGINE
  39. if _OCR_ENGINE is not None:
  40. return _OCR_ENGINE
  41. cfg_path = _CONFIG_PATH or _DEFAULT_CONFIG
  42. if not cfg_path.is_file():
  43. raise FileNotFoundError(f"场景配置不存在: {cfg_path}")
  44. with open(cfg_path, encoding="utf-8") as f:
  45. raw = yaml.safe_load(f) or {}
  46. ocr_cfg = raw.get("ocr_recognition") or {}
  47. errors: List[str] = []
  48. try:
  49. from core.model_factory import ModelFactory
  50. recognizer = ModelFactory.create_ocr_recognizer(ocr_cfg)
  51. engine = getattr(recognizer, "ocr_model", recognizer)
  52. if engine is None:
  53. raise RuntimeError("ocr_model 未初始化")
  54. _OCR_ENGINE = engine
  55. return _OCR_ENGINE
  56. except Exception as e:
  57. errors.append(f"ModelFactory/MinerU: {e}")
  58. det_path = os.environ.get("OCR_DET_MODEL_PATH")
  59. rec_path = os.environ.get("OCR_REC_MODEL_PATH")
  60. if det_path or rec_path:
  61. try:
  62. from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR
  63. kw: Dict[str, Any] = {"lang": ocr_cfg.get("language", "ch")}
  64. if det_path:
  65. kw["det_model_path"] = det_path
  66. if rec_path:
  67. kw["rec_model_path"] = rec_path
  68. _OCR_ENGINE = PytorchPaddleOCR(**kw)
  69. return _OCR_ENGINE
  70. except Exception as e2:
  71. errors.append(f"PytorchPaddleOCR(env paths): {e2}")
  72. try:
  73. from paddleocr import PaddleOCR
  74. _OCR_ENGINE = PaddleOCR(use_angle_cls=False, lang="ch", show_log=False)
  75. return _OCR_ENGINE
  76. except Exception as e3:
  77. errors.append(
  78. f"paddleocr pip(可选 pip install paddleocr): {e3}"
  79. )
  80. raise ImportError(
  81. "无法加载 OCR 引擎。请在 mineru 环境中运行,并确保场景 YAML 中 ocr_recognition "
  82. f"可正常初始化(与 main_v2 相同)。详情:\n - " + "\n - ".join(errors)
  83. )
  84. def _parse_rec_item(rec_item: Any) -> Tuple[str, float]:
  85. if rec_item is None:
  86. return "", 0.0
  87. if isinstance(rec_item, tuple) and len(rec_item) >= 2:
  88. txt = str(rec_item[0] or "").strip()
  89. sc = float(rec_item[1] or 0.0)
  90. return txt, 0.0 if not txt else sc
  91. if isinstance(rec_item, list) and len(rec_item) >= 2:
  92. if isinstance(rec_item[0], (list, tuple)):
  93. parts: List[str] = []
  94. scores: List[float] = []
  95. for item in rec_item:
  96. t, s = _parse_rec_item(item)
  97. if t:
  98. parts.append(t)
  99. scores.append(s)
  100. if not parts:
  101. return "", 0.0
  102. combined = "".join(parts)
  103. n = sum(len(t) for t in parts)
  104. return combined, sum(len(t) * s for t, s in zip(parts, scores)) / max(n, 1)
  105. txt = str(rec_item[0] or "").strip()
  106. sc = float(rec_item[1] or 0.0)
  107. return txt, 0.0 if not txt else sc
  108. return "", 0.0
  109. def _ocr_cell(img_bgr: np.ndarray, *, det: bool = True, rec: bool = True) -> Dict[str, Any]:
  110. """整格 det+rec,与 TextFiller._recognize_whole_cell 类似。"""
  111. try:
  112. engine = _get_ocr_engine()
  113. # paddleocr.PaddleOCR 与 PytorchPaddleOCR / MinerU 接口略有差异
  114. if engine.__class__.__name__ == "PaddleOCR":
  115. rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
  116. res = engine.ocr(rgb, cls=False)
  117. lines = []
  118. if res and res[0]:
  119. for item in res[0]:
  120. if item and len(item) >= 2:
  121. text, score = str(item[1][0]), float(item[1][1])
  122. lines.append({"text": text, "score": score})
  123. text = "".join(ln["text"] for ln in lines)
  124. sc = (
  125. sum(len(ln["text"]) * ln["score"] for ln in lines) / max(len(text), 1)
  126. if lines
  127. else 0.0
  128. )
  129. return {"text": text, "score": sc, "lines": lines, "backend": "paddleocr"}
  130. res = engine.ocr(img_bgr, det=det, rec=rec)
  131. lines: List[Dict[str, Any]] = []
  132. if res and res[0]:
  133. for item in res[0]:
  134. if not item or len(item) < 2:
  135. continue
  136. box, rec_part = item[0], item[1]
  137. text, score = _parse_rec_item(rec_part)
  138. if text:
  139. lines.append({"text": text, "score": score, "box": box})
  140. text = "".join(ln["text"] for ln in lines)
  141. score = (
  142. sum(len(ln["text"]) * ln["score"] for ln in lines) / max(len(text), 1)
  143. if lines
  144. else 0.0
  145. )
  146. return {"text": text, "score": score, "lines": lines, "mode": f"det={det},rec={rec}"}
  147. except Exception as e:
  148. return {
  149. "text": "",
  150. "score": 0.0,
  151. "lines": [],
  152. "error": str(e),
  153. "hint": "使用: conda activate mineru && python cell_preprocess_lab.py ...",
  154. }
  155. def _median_denoise(img: np.ndarray) -> np.ndarray:
  156. return cv2.medianBlur(img, 3)
  157. def _upscale_min_side(img: np.ndarray, min_side: int = 64) -> np.ndarray:
  158. h, w = img.shape[:2]
  159. if h >= min_side and w >= min_side:
  160. return img
  161. scale = max(min_side / max(h, 1), min_side / max(w, 1), 1.0)
  162. return cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
  163. def run_cell_pipeline(
  164. raw_bgr: np.ndarray,
  165. *,
  166. wm_method: str = "masked_adaptive",
  167. apply_denoise: bool = False,
  168. apply_contrast: bool = False,
  169. upscale_min: int = 64,
  170. ) -> Tuple[Dict[str, np.ndarray], List[str]]:
  171. stages: Dict[str, np.ndarray] = {"00_raw": raw_bgr.copy()}
  172. order: List[str] = ["00_raw"]
  173. wm_cfg = merge_watermark_config("cell", {"enabled": True, "method": wm_method})
  174. proc = WatermarkProcessor(wm_cfg, scope="cell")
  175. img, _ = proc.process(raw_bgr, force=True)
  176. stages["01_wm"] = img
  177. order.append("01_wm")
  178. step = 2
  179. if apply_denoise:
  180. img = _median_denoise(img)
  181. key = f"{step:02d}_denoise"
  182. stages[key] = img.copy()
  183. order.append(key)
  184. step += 1
  185. if apply_contrast:
  186. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img
  187. ce = dict(wm_cfg.get("contrast_enhancement") or {})
  188. ce["enabled"] = True
  189. gray = apply_contrast_enhancement_config(gray, ce)
  190. img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
  191. key = f"{step:02d}_contrast"
  192. stages[key] = img.copy()
  193. order.append(key)
  194. step += 1
  195. img = _upscale_min_side(img, upscale_min)
  196. key = f"{step:02d}_upscale"
  197. stages[key] = img
  198. order.append(key)
  199. return stages, order
  200. def process_one(
  201. input_path: Path,
  202. output_dir: Path,
  203. *,
  204. compare_methods: bool = False,
  205. run_ocr: bool = True,
  206. apply_denoise: bool = False,
  207. apply_contrast: bool = False,
  208. ) -> Dict[str, Any]:
  209. output_dir.mkdir(parents=True, exist_ok=True)
  210. raw = cv2.imread(str(input_path))
  211. if raw is None:
  212. raise FileNotFoundError(f"无法读取: {input_path}")
  213. report: Dict[str, Any] = {
  214. "input": str(input_path),
  215. "pipeline_note": (
  216. "默认 01_wm→upscale,不做 median 去噪(小格易糊笔画)。"
  217. "可用 --denoise / --contrast 对比。"
  218. ),
  219. "stages": {},
  220. }
  221. methods = ["threshold", "masked_adaptive"] if compare_methods else ["masked_adaptive"]
  222. ocr_keys = {"00_raw", "01_wm"}
  223. # 总是 OCR 最终 upscale 阶段
  224. for method in methods:
  225. sub_dir = output_dir / method if compare_methods else output_dir
  226. sub_dir.mkdir(parents=True, exist_ok=True)
  227. stage_imgs, order = run_cell_pipeline(
  228. raw,
  229. wm_method=method,
  230. apply_denoise=apply_denoise,
  231. apply_contrast=apply_contrast,
  232. )
  233. method_report: Dict[str, Any] = {"files": {}, "ocr": {}}
  234. final_key = order[-1]
  235. for key in order:
  236. out_path = sub_dir / f"{input_path.stem}_{key}.png"
  237. cv2.imwrite(str(out_path), stage_imgs[key])
  238. method_report["files"][key] = str(out_path)
  239. if run_ocr and (key in ocr_keys or key == final_key):
  240. method_report["ocr"][key] = _ocr_cell(stage_imgs[key])
  241. if run_ocr:
  242. method_report["ocr_recommended"] = method_report["ocr"].get(
  243. "01_wm"
  244. ) or method_report["ocr"].get(final_key)
  245. report["stages"][method] = method_report
  246. report_path = output_dir / f"{input_path.stem}_lab_report.json"
  247. with open(report_path, "w", encoding="utf-8") as f:
  248. json.dump(report, f, ensure_ascii=False, indent=2)
  249. report["report_path"] = str(report_path)
  250. return report
  251. def collect_inputs(path: Path) -> List[Path]:
  252. if path.is_file():
  253. return [path]
  254. files: List[Path] = []
  255. for p in sorted(path.iterdir()):
  256. if p.suffix.lower() in _IMAGE_SUFFIXES and "cell" in p.name:
  257. files.append(p)
  258. return files
  259. def main() -> None:
  260. global _CONFIG_PATH
  261. parser = argparse.ArgumentParser(description="单元格预处理实验 lab")
  262. parser.add_argument("input", type=Path, help="单元格 PNG 或 tablecell_ocr 目录")
  263. parser.add_argument("-o", "--output", type=Path, required=True, help="输出目录")
  264. parser.add_argument(
  265. "-c",
  266. "--config",
  267. type=Path,
  268. default=_DEFAULT_CONFIG,
  269. help="场景 YAML(用于加载与 pipeline 相同的 OCR)",
  270. )
  271. parser.add_argument(
  272. "--compare-methods",
  273. action="store_true",
  274. help="对比 threshold 与 masked_adaptive",
  275. )
  276. parser.add_argument("--no-ocr", action="store_true", help="跳过 OCR 探测")
  277. parser.add_argument(
  278. "--denoise",
  279. action="store_true",
  280. help="在去水印后增加 median 去噪(默认关闭,小图易损笔画)",
  281. )
  282. parser.add_argument(
  283. "--contrast",
  284. action="store_true",
  285. help="在去噪/放大前增加 text_restore 对比度",
  286. )
  287. parser.add_argument(
  288. "--det-model-path",
  289. type=Path,
  290. default=None,
  291. help="覆盖检测模型 .pth(或设环境变量 OCR_DET_MODEL_PATH)",
  292. )
  293. parser.add_argument(
  294. "--rec-model-path",
  295. type=Path,
  296. default=None,
  297. help="覆盖识别模型 .pth(或设环境变量 OCR_REC_MODEL_PATH)",
  298. )
  299. args = parser.parse_args()
  300. _CONFIG_PATH = args.config
  301. if args.det_model_path:
  302. os.environ["OCR_DET_MODEL_PATH"] = str(args.det_model_path)
  303. if args.rec_model_path:
  304. os.environ["OCR_REC_MODEL_PATH"] = str(args.rec_model_path)
  305. inputs = collect_inputs(args.input)
  306. if not inputs:
  307. print(f"未找到输入: {args.input}")
  308. sys.exit(1)
  309. for inp in inputs:
  310. out = args.output / inp.stem if len(inputs) > 1 else args.output
  311. report = process_one(
  312. inp,
  313. out,
  314. compare_methods=args.compare_methods,
  315. run_ocr=not args.no_ocr,
  316. apply_denoise=args.denoise,
  317. apply_contrast=args.contrast,
  318. )
  319. print(json.dumps(report, ensure_ascii=False, indent=2))
  320. if __name__ == "__main__":
  321. if len(sys.argv) == 1:
  322. print("ℹ️ 未提供命令行参数,使用默认配置运行...")
  323. default_config = {
  324. # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell029_whole_78.0111.0111.078.0司.png",
  325. "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell121_empty_empty.png",
  326. # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell217_lines_取款.png",
  327. # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell219_empty_empty.png",
  328. "output": "./output/彭_广东兴宁农村商业银行",
  329. "compare-methods": True,
  330. }
  331. sys.argv = [sys.argv[0], default_config["input"]]
  332. for key, value in default_config.items():
  333. if key == "input":
  334. continue
  335. flag = f"--{key.replace('_', '-')}"
  336. if isinstance(value, bool) and value:
  337. sys.argv.append(flag)
  338. elif not isinstance(value, bool):
  339. sys.argv.extend([flag, str(value)])
  340. sys.exit(main())