batch_service.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. """
  2. 批量处理服务
  3. 基于首页模板应用到多个文件,适用于多页银行流水等场景
  4. """
  5. import json
  6. import sys
  7. from pathlib import Path
  8. from typing import List, Dict, Optional, Tuple, Any, Callable
  9. from concurrent.futures import ThreadPoolExecutor, as_completed
  10. from loguru import logger
  11. from PIL import Image
  12. import io
  13. # 添加 ocr_platform 根目录到 Python 路径(用于导入 core 和 ocr_utils)
  14. _file_path = Path(__file__).resolve()
  15. ocr_platform_root = _file_path.parents[2] # batch_service.py -> services -> backend -> table_line_generator -> ocr_platform
  16. if str(ocr_platform_root) not in sys.path:
  17. sys.path.insert(0, str(ocr_platform_root))
  18. from table_line_generator.core.table_analyzer import TableAnalyzer
  19. from table_line_generator.core.ocr_parser import OcrParser
  20. from table_line_generator.core.drawing_service import DrawingService
  21. from table_line_generator.backend.services.template_service import TemplateService
  22. class BatchProcessor:
  23. """
  24. 批量表格处理器
  25. 工作流程:
  26. 1. 从首页学习模板结构(主要是竖线/列边界)
  27. 2. 将模板应用到所有页(竖线复用,横线自适应)
  28. 3. 支持并行处理提升效率
  29. """
  30. def __init__(self, max_workers: int = 4):
  31. """
  32. 初始化批量处理器
  33. Args:
  34. max_workers: 并行处理的最大线程数
  35. """
  36. self.max_workers = max_workers
  37. def learn_template_from_structure(
  38. self,
  39. first_page_structure: Dict
  40. ) -> Dict:
  41. """
  42. 从首页结构中学习模板
  43. Args:
  44. first_page_structure: 首页的表格结构
  45. Returns:
  46. 模板字典,包含可复用的部分(主要是列边界)
  47. """
  48. template = {
  49. 'vertical_lines': first_page_structure.get('vertical_lines', []).copy(),
  50. 'table_bbox': first_page_structure.get('table_bbox'),
  51. 'col_widths': first_page_structure.get('col_widths'),
  52. 'total_cols': first_page_structure.get('total_cols'),
  53. 'mode': first_page_structure.get('mode', 'hybrid')
  54. }
  55. logger.info(f"从首页学习模板: {len(template['vertical_lines'])} 条竖线, {template['total_cols']} 列")
  56. return template
  57. def apply_template_to_structure(
  58. self,
  59. template: Dict,
  60. target_ocr_data: Dict,
  61. adjust_rows: bool = True,
  62. y_tolerance: int = 5,
  63. min_row_height: int = 20
  64. ) -> Dict:
  65. """
  66. 将模板应用到目标页面的 OCR 数据
  67. Args:
  68. template: 从首页学习的模板
  69. target_ocr_data: 目标页面的 OCR 数据
  70. adjust_rows: 是否自适应调整行分割
  71. y_tolerance: Y轴聚类容差
  72. min_row_height: 最小行高
  73. Returns:
  74. 应用模板后的结构
  75. """
  76. # 创建分析器
  77. analyzer = TableAnalyzer(None, target_ocr_data)
  78. if adjust_rows:
  79. # 重新分析行结构(自适应)
  80. analyzed = analyzer.analyze(
  81. y_tolerance=y_tolerance,
  82. min_row_height=min_row_height,
  83. method=template.get('mode', 'auto')
  84. )
  85. # 复用模板的列信息
  86. new_structure = {
  87. 'horizontal_lines': analyzed['horizontal_lines'],
  88. 'vertical_lines': template['vertical_lines'].copy(),
  89. 'table_bbox': template.get('table_bbox') or analyzed.get('table_bbox'),
  90. 'row_height': analyzed.get('row_height'),
  91. 'col_widths': template.get('col_widths'),
  92. 'total_rows': analyzed.get('total_rows'),
  93. 'total_cols': template.get('total_cols'),
  94. 'mode': template.get('mode'),
  95. 'modified_h_lines': [],
  96. 'modified_v_lines': [],
  97. 'image_rotation_angle': target_ocr_data.get('image_rotation_angle', 0.0),
  98. 'skew_angle': target_ocr_data.get('skew_angle', 0.0),
  99. 'is_skew_corrected': target_ocr_data.get('is_skew_corrected', False)
  100. }
  101. else:
  102. # 完全复用模板(包括横线)
  103. new_structure = template.copy()
  104. new_structure['image_rotation_angle'] = target_ocr_data.get('image_rotation_angle', 0.0)
  105. new_structure['skew_angle'] = target_ocr_data.get('skew_angle', 0.0)
  106. new_structure['is_skew_corrected'] = target_ocr_data.get('is_skew_corrected', False)
  107. return new_structure
  108. def process_batch_from_data_source(
  109. self,
  110. template_name: str,
  111. file_pairs: List[Dict],
  112. output_dir: str,
  113. parallel: bool = True,
  114. adjust_rows: bool = True,
  115. structure_suffix: str = "_structure.json",
  116. image_suffix: str = "_with_lines.png",
  117. progress_callback: Optional[Callable[[int, int], None]] = None
  118. ) -> Dict:
  119. """
  120. 批量处理数据源中的文件
  121. Args:
  122. template_name: 模板名称(从 TemplateService 加载)
  123. file_pairs: 文件对列表 [{'json_path': ..., 'image_path': ...}, ...]
  124. output_dir: 输出目录
  125. parallel: 是否并行处理
  126. adjust_rows: 是否自适应调整行分割
  127. structure_suffix: 结构文件后缀
  128. image_suffix: 输出图片后缀
  129. progress_callback: 进度回调 callback(index, total)
  130. Returns:
  131. 处理结果摘要
  132. """
  133. total = len(file_pairs)
  134. results = []
  135. # 加载模板
  136. template_service = TemplateService()
  137. logger.info(f"开始批量处理: {total} 个文件, 使用模板: {template_name}, 并行={parallel}")
  138. if parallel and total > 1:
  139. # 并行处理
  140. with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
  141. futures = {
  142. executor.submit(
  143. self._process_single_file,
  144. pair, template_name, output_dir, adjust_rows, structure_suffix
  145. ): idx
  146. for idx, pair in enumerate(file_pairs)
  147. }
  148. for future in as_completed(futures):
  149. idx = futures[future]
  150. try:
  151. result = future.result()
  152. results.append(result)
  153. if progress_callback:
  154. progress_callback(idx, total)
  155. status = "✅" if result['success'] else "❌"
  156. logger.info(f"{status} [{idx+1}/{total}] {result.get('filename', 'unknown')}")
  157. except Exception as e:
  158. logger.error(f"❌ 处理失败 [{idx+1}/{total}]: {e}")
  159. results.append({
  160. 'success': False,
  161. 'error': str(e),
  162. 'index': idx
  163. })
  164. else:
  165. # 串行处理
  166. for idx, pair in enumerate(file_pairs):
  167. try:
  168. result = self._process_single_file(
  169. pair, template_name, output_dir, adjust_rows, structure_suffix
  170. )
  171. results.append(result)
  172. if progress_callback:
  173. progress_callback(idx, total)
  174. status = "✅" if result['success'] else "❌"
  175. logger.info(f"{status} [{idx+1}/{total}] {result.get('filename', 'unknown')}")
  176. except Exception as e:
  177. logger.error(f"❌ 处理失败 [{idx+1}/{total}]: {e}")
  178. results.append({
  179. 'success': False,
  180. 'error': str(e),
  181. 'index': idx
  182. })
  183. # 统计结果
  184. success_count = sum(1 for r in results if r.get('success'))
  185. failed_count = total - success_count
  186. summary = {
  187. 'total': total,
  188. 'success': success_count,
  189. 'failed': failed_count,
  190. 'results': results
  191. }
  192. logger.info(f"📊 批量处理完成: 成功 {success_count}/{total}, 失败 {failed_count}")
  193. return summary
  194. def _process_single_file(
  195. self,
  196. file_pair: Dict,
  197. template_name: str,
  198. output_dir: str,
  199. adjust_rows: bool,
  200. structure_suffix: str = "_structure.json"
  201. ) -> Dict:
  202. """
  203. 处理单个文件
  204. Args:
  205. file_pair: 文件对 {'json_path': ..., 'image_path': ...}
  206. template_name: 模板名称(从 TemplateService 加载)
  207. output_dir: 输出目录
  208. adjust_rows: 是否调整行
  209. structure_suffix: 结构文件后缀
  210. Returns:
  211. 处理结果
  212. """
  213. json_path = Path(file_pair['json_path'])
  214. image_path = Path(file_pair['image_path'])
  215. try:
  216. # 1. 读取 OCR 数据
  217. with open(json_path, 'r', encoding='utf-8') as f:
  218. ocr_result = json.load(f)
  219. # 2. 解析 OCR 数据(获取目标页面的 table_bbox 和 ocr_data)
  220. target_table_bbox, ocr_data = OcrParser.parse(ocr_result)
  221. target_image_size = ocr_data.get('image_size', {'width': 1, 'height': 1})
  222. # 3. 使用 TemplateService.preview_apply() 应用模板到目标页面
  223. # 这会自动处理坐标映射,适配不同尺寸的图片
  224. template_service = TemplateService()
  225. applied_template = template_service.preview_apply(
  226. template_name=template_name,
  227. target_image_size=target_image_size,
  228. target_table_bbox=target_table_bbox,
  229. mode='relative' # 使用相对坐标映射,适应不同尺寸图片
  230. )
  231. # 4. 构建最终结构(合并应用的模板和目标页面的元数据)
  232. if adjust_rows:
  233. # 如果启用自适应行,则自动分析目标页面的行结构
  234. analyzer = TableAnalyzer(None, ocr_data)
  235. analyzed = analyzer.analyze()
  236. # 关键:使用模板的竖线(已通过 preview_apply 映射),结合目标的横线
  237. new_structure = {
  238. 'horizontal_lines': analyzed['horizontal_lines'], # 自适应调整
  239. 'vertical_lines': applied_template['vertical_lines'], # 来自模板,已映射
  240. 'table_bbox': applied_template['table_bbox'], # 目标页面的 bbox
  241. 'row_height': analyzed.get('row_height'),
  242. 'col_widths': applied_template.get('col_widths'),
  243. 'total_rows': analyzed.get('total_rows'),
  244. 'total_cols': applied_template.get('total_cols'),
  245. 'mode': applied_template.get('mode', 'hybrid'),
  246. 'modified_h_lines': [],
  247. 'modified_v_lines': [],
  248. # 各页使用自己的旋转角度
  249. 'image_rotation_angle': ocr_data.get('image_rotation_angle', 0.0),
  250. 'skew_angle': ocr_data.get('skew_angle', 0.0),
  251. 'is_skew_corrected': ocr_data.get('is_skew_corrected', False)
  252. }
  253. else:
  254. # 完全复用应用的模板(包括竖线和横线)
  255. new_structure = applied_template.copy()
  256. new_structure['image_rotation_angle'] = ocr_data.get('image_rotation_angle', 0.0)
  257. new_structure['skew_angle'] = ocr_data.get('skew_angle', 0.0)
  258. new_structure['is_skew_corrected'] = ocr_data.get('is_skew_corrected', False)
  259. # 5. 保存结构文件
  260. output_path = Path(output_dir)
  261. output_path.mkdir(parents=True, exist_ok=True)
  262. structure_filename = f"{image_path.stem}{structure_suffix}"
  263. structure_path = output_path / structure_filename
  264. # 准备保存的数据(移除不需要的字段)
  265. structure_to_save = new_structure.copy()
  266. for key in ['rows', 'columns']:
  267. structure_to_save.pop(key, None)
  268. with open(structure_path, 'w', encoding='utf-8') as f:
  269. json.dump(structure_to_save, f, ensure_ascii=False, indent=2)
  270. return {
  271. 'success': True,
  272. 'json_path': str(json_path),
  273. 'image_path': str(image_path),
  274. 'structure_path': str(structure_path),
  275. 'filename': image_path.name,
  276. 'rows': new_structure.get('total_rows', 0),
  277. 'cols': new_structure.get('total_cols', 0)
  278. }
  279. except Exception as e:
  280. logger.exception(f"处理文件失败: {json_path}")
  281. return {
  282. 'success': False,
  283. 'json_path': str(json_path),
  284. 'image_path': str(image_path),
  285. 'error': str(e),
  286. 'filename': image_path.name if image_path else 'unknown'
  287. }
  288. def draw_batch_images(
  289. self,
  290. results: List[Dict],
  291. line_width: int = 2,
  292. line_color: Tuple[int, int, int] = (0, 0, 0)
  293. ) -> List[Dict]:
  294. """
  295. 批量绘制表格线到图片上
  296. Args:
  297. results: process_batch_from_data_source 的返回结果中的 results 列表
  298. line_width: 线条宽度
  299. line_color: 线条颜色 RGB
  300. Returns:
  301. 绘制结果列表
  302. """
  303. draw_results = []
  304. for result in results:
  305. if not result.get('success'):
  306. continue
  307. try:
  308. image_path = Path(result['image_path'])
  309. structure_path = Path(result['structure_path'])
  310. # 读取图片
  311. image = Image.open(image_path)
  312. if image.mode != 'RGB':
  313. image = image.convert('RGB')
  314. # 读取结构
  315. with open(structure_path, 'r', encoding='utf-8') as f:
  316. structure = json.load(f)
  317. # 绘制线条
  318. image_with_lines = DrawingService.draw_clean_lines(
  319. image, structure, line_width=line_width, line_color=line_color
  320. )
  321. # 保存
  322. output_path = structure_path.parent / f"{image_path.stem}.png"
  323. image_with_lines.save(str(output_path), 'PNG')
  324. draw_results.append({
  325. 'success': True,
  326. 'image_path': str(output_path),
  327. 'filename': image_path.name
  328. })
  329. except Exception as e:
  330. logger.error(f"绘制图片失败 {result.get('filename')}: {e}")
  331. draw_results.append({
  332. 'success': False,
  333. 'error': str(e),
  334. 'filename': result.get('filename')
  335. })
  336. success_count = sum(1 for r in draw_results if r.get('success'))
  337. logger.info(f"🖼️ 绘制完成: {success_count}/{len(results)} 张图片")
  338. return draw_results