cell_preprocess_lab.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. #!/usr/bin/env python3
  2. """
  3. 单元格裁剪图预处理实验:去水印 →(可选去噪/对比度)→ 放大 → OCR。
  4. 与 pipeline 二次 OCR 对齐,使用 ocr_tools.pytorch_models.PytorchPaddleOCR(非 paddleocr pip 包)。
  5. 用法:
  6. python cell_preprocess_lab.py cell219.png -o /tmp/cell_lab
  7. python cell_preprocess_lab.py /path/to/tablecell_ocr/ -o /tmp/batch --compare-methods
  8. python cell_preprocess_lab.py cell217.png -o /tmp/out --denoise --contrast
  9. """
  10. from __future__ import annotations
  11. import argparse
  12. import json
  13. import os
  14. import sys
  15. from pathlib import Path
  16. from typing import Any, Dict, List, Optional, Tuple
  17. import cv2
  18. import numpy as np
  19. import yaml
# Put the repository root and the universal_doc_parser directory at the front
# of sys.path so that the project imports that follow resolve when this file is
# executed directly as a script.
# NOTE(review): parents[2] assumes this file lives two directories below the
# repository root — confirm against the actual checkout layout.
_repo_root = Path(__file__).resolve().parents[2]
_parser_root = _repo_root / "ocr_tools" / "universal_doc_parser"
for _p in (_repo_root, _parser_root):
    # Only insert entries that are not already present.
    if str(_p) not in sys.path:
        sys.path.insert(0, str(_p))
  25. from ocr_utils.watermark import WatermarkProcessor, merge_watermark_config
  26. from ocr_utils.watermark.contrast import apply_contrast_enhancement_config
# Scene YAML used when no other config path has been selected; it is the
# default of the --config option and the fallback inside _get_ocr_engine().
_DEFAULT_CONFIG = (
    _repo_root
    / "ocr_tools/universal_doc_parser/config/bank_statement_yusys_local.yaml"
)
# Lower-case file suffixes accepted by collect_inputs() when scanning a directory.
_IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg", ".bmp", ".tif", ".tiff", ".webp"}
# Cached OCR engine instance; populated on first use by _get_ocr_engine().
_OCR_ENGINE: Any = None
# Scene YAML path chosen on the command line; assigned in main().
_CONFIG_PATH: Optional[Path] = None
  34. def _get_ocr_engine() -> Any:
  35. """与 main_v2 pipeline 相同:ModelFactory → MinerU OCR atom model。"""
  36. global _OCR_ENGINE
  37. if _OCR_ENGINE is not None:
  38. return _OCR_ENGINE
  39. cfg_path = _CONFIG_PATH or _DEFAULT_CONFIG
  40. if not cfg_path.is_file():
  41. raise FileNotFoundError(f"场景配置不存在: {cfg_path}")
  42. with open(cfg_path, encoding="utf-8") as f:
  43. raw = yaml.safe_load(f) or {}
  44. ocr_cfg = raw.get("ocr_recognition") or {}
  45. errors: List[str] = []
  46. try:
  47. from core.model_factory import ModelFactory
  48. recognizer = ModelFactory.create_ocr_recognizer(ocr_cfg)
  49. engine = getattr(recognizer, "ocr_model", recognizer)
  50. if engine is None:
  51. raise RuntimeError("ocr_model 未初始化")
  52. _OCR_ENGINE = engine
  53. return _OCR_ENGINE
  54. except Exception as e:
  55. errors.append(f"ModelFactory/MinerU: {e}")
  56. det_path = os.environ.get("OCR_DET_MODEL_PATH")
  57. rec_path = os.environ.get("OCR_REC_MODEL_PATH")
  58. if det_path or rec_path:
  59. try:
  60. from ocr_tools.pytorch_models.pytorch_paddle import PytorchPaddleOCR
  61. kw: Dict[str, Any] = {"lang": ocr_cfg.get("language", "ch")}
  62. if det_path:
  63. kw["det_model_path"] = det_path
  64. if rec_path:
  65. kw["rec_model_path"] = rec_path
  66. _OCR_ENGINE = PytorchPaddleOCR(**kw)
  67. return _OCR_ENGINE
  68. except Exception as e2:
  69. errors.append(f"PytorchPaddleOCR(env paths): {e2}")
  70. try:
  71. from paddleocr import PaddleOCR
  72. _OCR_ENGINE = PaddleOCR(use_angle_cls=False, lang="ch", show_log=False)
  73. return _OCR_ENGINE
  74. except Exception as e3:
  75. errors.append(
  76. f"paddleocr pip(可选 pip install paddleocr): {e3}"
  77. )
  78. raise ImportError(
  79. "无法加载 OCR 引擎。请在 mineru 环境中运行,并确保场景 YAML 中 ocr_recognition "
  80. f"可正常初始化(与 main_v2 相同)。详情:\n - " + "\n - ".join(errors)
  81. )
  82. def _parse_rec_item(rec_item: Any) -> Tuple[str, float]:
  83. if rec_item is None:
  84. return "", 0.0
  85. if isinstance(rec_item, tuple) and len(rec_item) >= 2:
  86. txt = str(rec_item[0] or "").strip()
  87. sc = float(rec_item[1] or 0.0)
  88. return txt, 0.0 if not txt else sc
  89. if isinstance(rec_item, list) and len(rec_item) >= 2:
  90. if isinstance(rec_item[0], (list, tuple)):
  91. parts: List[str] = []
  92. scores: List[float] = []
  93. for item in rec_item:
  94. t, s = _parse_rec_item(item)
  95. if t:
  96. parts.append(t)
  97. scores.append(s)
  98. if not parts:
  99. return "", 0.0
  100. combined = "".join(parts)
  101. n = sum(len(t) for t in parts)
  102. return combined, sum(len(t) * s for t, s in zip(parts, scores)) / max(n, 1)
  103. txt = str(rec_item[0] or "").strip()
  104. sc = float(rec_item[1] or 0.0)
  105. return txt, 0.0 if not txt else sc
  106. return "", 0.0
  107. def _ocr_cell(img_bgr: np.ndarray, *, det: bool = True, rec: bool = True) -> Dict[str, Any]:
  108. """整格 det+rec,与 TextFiller._recognize_whole_cell 类似。"""
  109. try:
  110. engine = _get_ocr_engine()
  111. # paddleocr.PaddleOCR 与 PytorchPaddleOCR / MinerU 接口略有差异
  112. if engine.__class__.__name__ == "PaddleOCR":
  113. rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
  114. res = engine.ocr(rgb, cls=False)
  115. lines = []
  116. if res and res[0]:
  117. for item in res[0]:
  118. if item and len(item) >= 2:
  119. text, score = str(item[1][0]), float(item[1][1])
  120. lines.append({"text": text, "score": score})
  121. text = "".join(ln["text"] for ln in lines)
  122. sc = (
  123. sum(len(ln["text"]) * ln["score"] for ln in lines) / max(len(text), 1)
  124. if lines
  125. else 0.0
  126. )
  127. return {"text": text, "score": sc, "lines": lines, "backend": "paddleocr"}
  128. res = engine.ocr(img_bgr, det=det, rec=rec)
  129. lines: List[Dict[str, Any]] = []
  130. if res and res[0]:
  131. for item in res[0]:
  132. if not item or len(item) < 2:
  133. continue
  134. box, rec_part = item[0], item[1]
  135. text, score = _parse_rec_item(rec_part)
  136. if text:
  137. lines.append({"text": text, "score": score, "box": box})
  138. text = "".join(ln["text"] for ln in lines)
  139. score = (
  140. sum(len(ln["text"]) * ln["score"] for ln in lines) / max(len(text), 1)
  141. if lines
  142. else 0.0
  143. )
  144. return {"text": text, "score": score, "lines": lines, "mode": f"det={det},rec={rec}"}
  145. except Exception as e:
  146. return {
  147. "text": "",
  148. "score": 0.0,
  149. "lines": [],
  150. "error": str(e),
  151. "hint": "使用: conda activate mineru && python cell_preprocess_lab.py ...",
  152. }
  153. def _median_denoise(img: np.ndarray) -> np.ndarray:
  154. return cv2.medianBlur(img, 3)
  155. def _upscale_min_side(img: np.ndarray, min_side: int = 64) -> np.ndarray:
  156. h, w = img.shape[:2]
  157. if h >= min_side and w >= min_side:
  158. return img
  159. scale = max(min_side / max(h, 1), min_side / max(w, 1), 1.0)
  160. return cv2.resize(img, None, fx=scale, fy=scale, interpolation=cv2.INTER_CUBIC)
  161. def run_cell_pipeline(
  162. raw_bgr: np.ndarray,
  163. *,
  164. wm_method: str = "masked_adaptive",
  165. apply_denoise: bool = False,
  166. apply_contrast: bool = False,
  167. upscale_min: int = 64,
  168. ) -> Tuple[Dict[str, np.ndarray], List[str]]:
  169. stages: Dict[str, np.ndarray] = {"00_raw": raw_bgr.copy()}
  170. order: List[str] = ["00_raw"]
  171. wm_cfg = merge_watermark_config("cell", {"enabled": True, "method": wm_method})
  172. proc = WatermarkProcessor(wm_cfg, scope="cell")
  173. img, _ = proc.process(raw_bgr, force=True)
  174. stages["01_wm"] = img
  175. order.append("01_wm")
  176. step = 2
  177. if apply_denoise:
  178. img = _median_denoise(img)
  179. key = f"{step:02d}_denoise"
  180. stages[key] = img.copy()
  181. order.append(key)
  182. step += 1
  183. if apply_contrast:
  184. gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) if img.ndim == 3 else img
  185. ce = dict(wm_cfg.get("contrast_enhancement") or {})
  186. ce["enabled"] = True
  187. gray = apply_contrast_enhancement_config(gray, ce)
  188. img = cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)
  189. key = f"{step:02d}_contrast"
  190. stages[key] = img.copy()
  191. order.append(key)
  192. step += 1
  193. img = _upscale_min_side(img, upscale_min)
  194. key = f"{step:02d}_upscale"
  195. stages[key] = img
  196. order.append(key)
  197. return stages, order
  198. def process_one(
  199. input_path: Path,
  200. output_dir: Path,
  201. *,
  202. compare_methods: bool = False,
  203. run_ocr: bool = True,
  204. apply_denoise: bool = False,
  205. apply_contrast: bool = False,
  206. ) -> Dict[str, Any]:
  207. output_dir.mkdir(parents=True, exist_ok=True)
  208. raw = cv2.imread(str(input_path))
  209. if raw is None:
  210. raise FileNotFoundError(f"无法读取: {input_path}")
  211. report: Dict[str, Any] = {
  212. "input": str(input_path),
  213. "pipeline_note": (
  214. "默认 01_wm→upscale,不做 median 去噪(小格易糊笔画)。"
  215. "可用 --denoise / --contrast 对比。"
  216. ),
  217. "stages": {},
  218. }
  219. methods = ["threshold", "masked_adaptive"] if compare_methods else ["masked_adaptive"]
  220. ocr_keys = {"00_raw", "01_wm"}
  221. # 总是 OCR 最终 upscale 阶段
  222. for method in methods:
  223. sub_dir = output_dir / method if compare_methods else output_dir
  224. sub_dir.mkdir(parents=True, exist_ok=True)
  225. stage_imgs, order = run_cell_pipeline(
  226. raw,
  227. wm_method=method,
  228. apply_denoise=apply_denoise,
  229. apply_contrast=apply_contrast,
  230. )
  231. method_report: Dict[str, Any] = {"files": {}, "ocr": {}}
  232. final_key = order[-1]
  233. for key in order:
  234. out_path = sub_dir / f"{input_path.stem}_{key}.png"
  235. cv2.imwrite(str(out_path), stage_imgs[key])
  236. method_report["files"][key] = str(out_path)
  237. if run_ocr and (key in ocr_keys or key == final_key):
  238. method_report["ocr"][key] = _ocr_cell(stage_imgs[key])
  239. if run_ocr:
  240. method_report["ocr_recommended"] = method_report["ocr"].get(
  241. "01_wm"
  242. ) or method_report["ocr"].get(final_key)
  243. report["stages"][method] = method_report
  244. report_path = output_dir / f"{input_path.stem}_lab_report.json"
  245. with open(report_path, "w", encoding="utf-8") as f:
  246. json.dump(report, f, ensure_ascii=False, indent=2)
  247. report["report_path"] = str(report_path)
  248. return report
  249. def collect_inputs(path: Path) -> List[Path]:
  250. if path.is_file():
  251. return [path]
  252. files: List[Path] = []
  253. for p in sorted(path.iterdir()):
  254. if p.suffix.lower() in _IMAGE_SUFFIXES and "cell" in p.name:
  255. files.append(p)
  256. return files
  257. def main() -> None:
  258. global _CONFIG_PATH
  259. parser = argparse.ArgumentParser(description="单元格预处理实验 lab")
  260. parser.add_argument("input", type=Path, help="单元格 PNG 或 tablecell_ocr 目录")
  261. parser.add_argument("-o", "--output", type=Path, required=True, help="输出目录")
  262. parser.add_argument(
  263. "-c",
  264. "--config",
  265. type=Path,
  266. default=_DEFAULT_CONFIG,
  267. help="场景 YAML(用于加载与 pipeline 相同的 OCR)",
  268. )
  269. parser.add_argument(
  270. "--compare-methods",
  271. action="store_true",
  272. help="对比 threshold 与 masked_adaptive",
  273. )
  274. parser.add_argument("--no-ocr", action="store_true", help="跳过 OCR 探测")
  275. parser.add_argument(
  276. "--denoise",
  277. action="store_true",
  278. help="在去水印后增加 median 去噪(默认关闭,小图易损笔画)",
  279. )
  280. parser.add_argument(
  281. "--contrast",
  282. action="store_true",
  283. help="在去噪/放大前增加 text_restore 对比度",
  284. )
  285. parser.add_argument(
  286. "--det-model-path",
  287. type=Path,
  288. default=None,
  289. help="覆盖检测模型 .pth(或设环境变量 OCR_DET_MODEL_PATH)",
  290. )
  291. parser.add_argument(
  292. "--rec-model-path",
  293. type=Path,
  294. default=None,
  295. help="覆盖识别模型 .pth(或设环境变量 OCR_REC_MODEL_PATH)",
  296. )
  297. args = parser.parse_args()
  298. _CONFIG_PATH = args.config
  299. if args.det_model_path:
  300. os.environ["OCR_DET_MODEL_PATH"] = str(args.det_model_path)
  301. if args.rec_model_path:
  302. os.environ["OCR_REC_MODEL_PATH"] = str(args.rec_model_path)
  303. inputs = collect_inputs(args.input)
  304. if not inputs:
  305. print(f"未找到输入: {args.input}")
  306. sys.exit(1)
  307. for inp in inputs:
  308. out = args.output / inp.stem if len(inputs) > 1 else args.output
  309. report = process_one(
  310. inp,
  311. out,
  312. compare_methods=args.compare_methods,
  313. run_ocr=not args.no_ocr,
  314. apply_denoise=args.denoise,
  315. apply_contrast=args.contrast,
  316. )
  317. print(json.dumps(report, ensure_ascii=False, indent=2))
  318. if __name__ == "__main__":
  319. if len(sys.argv) == 1:
  320. print("ℹ️ 未提供命令行参数,使用默认配置运行...")
  321. default_config = {
  322. # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell029_whole_78.0111.0111.078.0司.png",
  323. "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell121_empty_empty.png",
  324. # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell217_lines_取款.png",
  325. # "input": "/Users/zhch158/workspace/data/流水分析/彭_广东兴宁农村商业银行/bank_statement_yusys_local/debug/table_recognition_wired/tablecell_ocr/彭_广东兴宁农村商业银行_page_002_0/cell219_empty_empty.png",
  326. "output": "./output/彭_广东兴宁农村商业银行",
  327. "compare-methods": True,
  328. }
  329. sys.argv = [sys.argv[0], default_config["input"]]
  330. for key, value in default_config.items():
  331. if key == "input":
  332. continue
  333. flag = f"--{key.replace('_', '-')}"
  334. if isinstance(value, bool) and value:
  335. sys.argv.append(flag)
  336. elif not isinstance(value, bool):
  337. sys.argv.extend([flag, str(value)])
  338. sys.exit(main())