# pipeline_manager.py — unified processing pipeline for financial documents.


  from io import BytesIO
  from pathlib import Path
  from typing import Dict, List, Any, Optional, Union

  import fitz  # PyMuPDF
  import numpy as np
  from loguru import logger
  from PIL import Image

  from models.adapters import BaseAdapter

  from .config_manager import ConfigManager
  from .model_factory import ModelFactory
  class FinancialDocPipeline:
      """Unified processing pipeline for financial documents.

      Every page of a document (PDF or raster image) is converted to an RGB
      numpy array and passed through four components that ``ModelFactory``
      builds from the configuration:

      1. preprocessor    -- orientation classification, image correction, etc.
      2. layout detector -- locates layout elements on the page
      3. VL recognizer   -- recognizes tables and formulas
      4. OCR recognizer  -- recognizes text regions

      Usable as a context manager: ``cleanup()`` is called on exit.
      """

      # Raster-image extensions accepted by _load_document_images
      # (compared against the lower-cased file suffix).
      _IMAGE_SUFFIXES = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif')

      # Layout categories routed to each element handler in
      # _process_single_page; any other category is passed through as-is.
      _TABLE_CATEGORIES = ('table_body', 'table')
      _TEXT_CATEGORIES = ('text', 'title', 'ocr_text')
      _FORMULA_CATEGORIES = ('interline_equation', 'inline_equation')

      # Component attributes created by _init_components and released by cleanup.
      _COMPONENT_ATTRS = (
          'preprocessor',
          'layout_detector',
          'vl_recognizer',
          'ocr_recognizer',
      )

      def __init__(self, config_path: str):
          """Load the configuration and build all pipeline components.

          Args:
              config_path: Path to the pipeline configuration, loaded with
                  ``ConfigManager.load_config``.

          Raises:
              Any exception raised while loading the configuration or creating
              a component is propagated unchanged.
          """
          self.config = ConfigManager.load_config(config_path)
          self.scene_name = self.config.get('scene_name', 'unknown')
          # Build the processing components
          self._init_components()

      def _init_components(self) -> None:
          """Create the four pipeline components from their config sections.

          Reads the ``preprocessor``, ``layout_detection``, ``vl_recognition``
          and ``ocr_recognition`` config sections (a missing section raises
          ``KeyError``).

          If any component fails, the components already created are cleaned
          up before the error is re-raised. The caller never receives a
          pipeline object in that case, so it could not release them itself.
          """
          try:
              # 1. Preprocessor (orientation classification, image correction, ...)
              self.preprocessor = ModelFactory.create_preprocessor(
                  self.config['preprocessor']
              )
              # 2. Layout detector
              self.layout_detector = ModelFactory.create_layout_detector(
                  self.config['layout_detection']
              )
              # 3. VL recognizer (tables, formulas, ...)
              self.vl_recognizer = ModelFactory.create_vl_recognizer(
                  self.config['vl_recognition']
              )
              # 4. OCR recognizer
              self.ocr_recognizer = ModelFactory.create_ocr_recognizer(
                  self.config['ocr_recognition']
              )
              logger.info(f"✅ Initialized pipeline for scene: {self.scene_name}")
          except Exception as e:
              logger.error(f"❌ Failed to initialize pipeline components: {e}")
              # cleanup() never raises and skips components that were not created
              self.cleanup()
              raise

      def process_document(self, document_path: str) -> Dict[str, Any]:
          """Run the full pipeline on a document.

          Args:
              document_path: Path to a PDF or a supported raster image.

          Returns:
              A dict with ``scene``, ``document_path``, ``pages`` (one
              ``_process_single_page`` result per page, holding the coordinates
              and content of every element) and ``metadata``.

          Raises:
              FileNotFoundError: If the document does not exist.
              ValueError: If the file extension is not supported.
          """
          results = {
              'scene': self.scene_name,
              'document_path': document_path,
              'pages': [],
              'metadata': self._extract_metadata(document_path)
          }
          try:
              # Render the document to page images
              images = self._load_document_images(document_path)
              logger.info(f"📄 Loaded {len(images)} pages from document")
              for page_idx, image in enumerate(images):
                  logger.info(f"🔍 Processing page {page_idx + 1}/{len(images)}")
                  page_result = self._process_single_page(image, page_idx)
                  results['pages'].append(page_result)
          except Exception as e:
              logger.error(f"❌ Failed to process document: {e}")
              raise
          return results

      def _load_document_images(self, document_path: str) -> List[np.ndarray]:
          """Render a document to a list of RGB page images.

          PDFs produce one image per page (see ``_render_pdf_pages``). Supported
          raster images produce a single page, converted to RGB if needed.

          Raises:
              FileNotFoundError: If ``document_path`` does not exist.
              ValueError: If the extension is neither ``.pdf`` nor one of
                  ``_IMAGE_SUFFIXES``.
          """
          document_path = Path(document_path)
          if not document_path.exists():
              raise FileNotFoundError(f"Document not found: {document_path}")

          suffix = document_path.suffix.lower()
          if suffix == '.pdf':
              return self._render_pdf_pages(document_path)
          if suffix in self._IMAGE_SUFFIXES:
              # The context manager closes the file handle PIL keeps open
              with Image.open(document_path) as img:
                  if img.mode != 'RGB':
                      img = img.convert('RGB')
                  return [np.array(img)]
          raise ValueError(f"Unsupported file format: {document_path.suffix}")

      def _render_pdf_pages(self, pdf_path: Path) -> List[np.ndarray]:
          """Rasterize every PDF page to an RGB numpy array at the configured DPI."""
          # config['input']['dpi'] (default 200); a missing/None section means defaults
          dpi = (self.config.get('input') or {}).get('dpi', 200)
          # PDF user space has 72 units per inch, so scale by dpi / 72
          matrix = fitz.Matrix(dpi / 72, dpi / 72)
          images = []
          doc = fitz.open(pdf_path)
          try:
              for page in doc:
                  pix = page.get_pixmap(matrix=matrix)
                  with Image.open(BytesIO(pix.tobytes("ppm"))) as img:
                      if img.mode != 'RGB':
                          img = img.convert('RGB')
                      images.append(np.array(img))
          finally:
              doc.close()
          return images

      def _extract_metadata(self, document_path: str) -> Dict[str, Any]:
          """Collect basic file metadata, plus document info for PDFs.

          Always returns ``filename``, ``size`` (bytes, from ``stat``) and
          ``format`` (the lower-cased suffix). For PDFs it also tries to add
          ``page_count``, ``title``, ``author``, ``subject`` and ``creator``.
          Reading PDF info is best-effort: failures are logged and the basic
          metadata is still returned.

          Raises:
              FileNotFoundError: If the file does not exist (from ``stat``).
          """
          document_path = Path(document_path)
          metadata = {
              'filename': document_path.name,
              'size': document_path.stat().st_size,
              'format': document_path.suffix.lower()
          }
          # For PDFs, extract additional document info
          if metadata['format'] == '.pdf':
              doc = None
              try:
                  doc = fitz.open(document_path)
                  info = doc.metadata or {}
                  metadata.update({
                      'page_count': len(doc),
                      'title': info.get('title', ''),
                      'author': info.get('author', ''),
                      'subject': info.get('subject', ''),
                      'creator': info.get('creator', '')
                  })
              except Exception as e:
                  logger.warning(f"⚠️ Failed to read PDF metadata from {document_path}: {e}")
              finally:
                  # Close even when reading the info failed
                  if doc is not None:
                      doc.close()
          return metadata

      def _process_single_page(self, image: np.ndarray, page_idx: int) -> Dict[str, Any]:
          """Preprocess one page, detect its layout and recognize each element.

          Failures are contained per stage:
          - failed preprocessing falls back to the input image with angle 0;
          - failed layout detection yields no elements;
          - a failed element is kept with ``type='error'`` and an ``error`` message.

          Args:
              image: Page image as produced by ``_load_document_images``.
              page_idx: Zero-based page index, used in logs and in the result.

          Returns:
              Dict with ``page_idx``, ``elements`` (processed elements),
              ``layout_raw`` (raw detector output), ``image_shape``,
              ``processed_image`` and ``angle`` (the rotation returned by the
              preprocessor, or 0 when preprocessing failed).
          """
          # 1. Preprocessing (orientation correction, etc.)
          try:
              preprocessed_image, rotate_angle = self.preprocessor.process(image)
          except Exception as e:
              logger.warning(f"⚠️ Preprocessing failed for page {page_idx}: {e}")
              # Fall back to the untouched input. rotate_angle must also be bound
              # here, or building the result dict below raises UnboundLocalError.
              preprocessed_image = image
              rotate_angle = 0

          # 2. Layout detection
          try:
              layout_results = self.layout_detector.detect(preprocessed_image)
              logger.info(f"📋 Detected {len(layout_results)} layout elements on page {page_idx}")
          except Exception as e:
              logger.error(f"❌ Layout detection failed for page {page_idx}: {e}")
              layout_results = []

          # 3. Route each element to the handler for its category
          page_elements = []
          for layout_item in layout_results:
              # Bound before the try so the except clause can always log it,
              # even when the item has no 'category' key.
              element_type = None
              try:
                  element_type = layout_item['category']
                  if element_type in self._TABLE_CATEGORIES:
                      # Tables are recognized by the VL model
                      element_result = self._process_table_element(
                          preprocessed_image, layout_item
                      )
                  elif element_type in self._TEXT_CATEGORIES:
                      # Text is recognized by OCR
                      element_result = self._process_text_element(
                          preprocessed_image, layout_item
                      )
                  elif element_type in self._FORMULA_CATEGORIES:
                      # Formulas are recognized by the VL model
                      element_result = self._process_formula_element(
                          preprocessed_image, layout_item
                      )
                  else:
                      # Other elements are kept as detected
                      element_result = layout_item.copy()
                      element_result['type'] = element_type
                  page_elements.append(element_result)
              except Exception as e:
                  logger.warning(f"⚠️ Failed to process element {element_type}: {e}")
                  # Keep the failed element, marked as an error
                  error_element = layout_item.copy()
                  error_element['type'] = 'error'
                  error_element['error'] = str(e)
                  page_elements.append(error_element)

          return {
              'page_idx': page_idx,
              'elements': page_elements,
              'layout_raw': layout_results,
              'image_shape': preprocessed_image.shape,
              'processed_image': preprocessed_image,
              'angle': rotate_angle
          }

      def _process_table_element(self, image: np.ndarray, layout_item: Dict[str, Any]) -> Dict[str, Any]:
          """Recognize a table region with the VL model.

          Args:
              image: Full page image.
              layout_item: Layout element. ``bbox`` is required and
                  ``confidence`` is optional.

          Returns:
              Dict with ``type='table'``, ``bbox``, ``confidence``, ``content``
              (the recognizer output) and ``scene_specific`` (see
              ``_add_scene_specific_info``). Every cell that has a ``bbox`` also
              gets an ``absolute_bbox`` in page coordinates. On failure,
              ``content`` is an empty table and ``error`` holds the message.
          """
          try:
              # Crop the table region
              cropped_table = self._crop_region(image, layout_item['bbox'])
              # Recognize the table with the VL model
              table_result = self.vl_recognizer.recognize_table(
                  cropped_table,
                  return_cells_coordinate=True  # key: request per-cell coordinates
              )
              # Map cell boxes from crop coordinates back to page coordinates.
              # NOTE(review): the offset uses the raw layout bbox, while
              # _crop_region truncates/clamps it, so results can be off by the
              # truncated/clamped amount.
              if 'cells' in table_result:
                  for cell in table_result['cells']:
                      # A cell without a box should not fail the whole table
                      if 'bbox' in cell:
                          cell['absolute_bbox'] = self._convert_to_absolute_coords(
                              cell['bbox'], layout_item['bbox']
                          )
              result = {
                  'type': 'table',
                  'bbox': layout_item['bbox'],
                  'confidence': layout_item.get('confidence', 0.0),
                  'content': table_result,
                  'scene_specific': self._add_scene_specific_info(table_result)
              }
              logger.info(f"✅ Table processed with {len(table_result.get('cells', []))} cells")
              return result
          except Exception as e:
              logger.error(f"❌ Table processing failed: {e}")
              return {
                  'type': 'table',
                  'bbox': layout_item.get('bbox'),
                  'confidence': layout_item.get('confidence', 0.0),
                  'content': {'html': '', 'markdown': '', 'cells': []},
                  'error': str(e)
              }

      def _process_text_element(
          self,
          image: np.ndarray,
          layout_item: Dict[str, Any],
          min_confidence: float = 0.5,
      ) -> Dict[str, Any]:
          """Recognize a text region with OCR.

          Args:
              image: Full page image.
              layout_item: Layout element. ``bbox`` and ``category`` are
                  required and ``confidence`` is optional.
              min_confidence: OCR items whose ``confidence`` is not greater than
                  this value are dropped (a missing confidence counts as 0.0).

          Returns:
              Dict with ``type`` (the layout category), ``bbox``,
              ``confidence`` and ``content`` holding ``text`` (the kept
              non-empty item texts joined with single spaces) and
              ``ocr_details`` (the raw OCR output). On failure, ``content`` is
              empty and ``error`` holds the message.
          """
          try:
              # Crop the text region
              cropped_text = self._crop_region(image, layout_item['bbox'])
              # Recognize text with OCR
              text_results = self.ocr_recognizer.recognize_text(cropped_text)
              # Merge the confident recognition results
              text_parts = [
                  item.get('text', '')
                  for item in (text_results or [])
                  if item.get('confidence', 0.0) > min_confidence
              ]
              combined_text = " ".join(part for part in text_parts if part)
              result = {
                  'type': layout_item['category'],
                  'bbox': layout_item['bbox'],
                  'confidence': layout_item.get('confidence', 0.0),
                  'content': {
                      'text': combined_text,
                      'ocr_details': text_results
                  }
              }
              logger.info(f"✅ Text processed: '{combined_text[:50]}...'")
              return result
          except Exception as e:
              logger.error(f"❌ Text processing failed: {e}")
              return {
                  'type': layout_item['category'],
                  'bbox': layout_item.get('bbox'),
                  'confidence': layout_item.get('confidence', 0.0),
                  'content': {'text': '', 'ocr_details': []},
                  'error': str(e)
              }

      def _process_formula_element(self, image: np.ndarray, layout_item: Dict[str, Any]) -> Dict[str, Any]:
          """Recognize a formula region with the VL model.

          Returns:
              Dict with ``type='formula'``, ``bbox``, ``confidence`` and
              ``content`` (the recognizer output). On failure, ``content`` is
              ``{'latex': '', 'confidence': 0.0}`` and ``error`` holds the
              message.
          """
          try:
              # Crop the formula region
              cropped_formula = self._crop_region(image, layout_item['bbox'])
              # Recognize the formula with the VL model
              formula_result = self.vl_recognizer.recognize_formula(cropped_formula)
              result = {
                  'type': 'formula',
                  'bbox': layout_item['bbox'],
                  'confidence': layout_item.get('confidence', 0.0),
                  'content': formula_result
              }
              logger.info(f"✅ Formula processed: {formula_result.get('latex', '')[:50]}...")
              return result
          except Exception as e:
              logger.error(f"❌ Formula processing failed: {e}")
              return {
                  'type': 'formula',
                  'bbox': layout_item.get('bbox'),
                  'confidence': layout_item.get('confidence', 0.0),
                  'content': {'latex': '', 'confidence': 0.0},
                  'error': str(e)
              }

      def _crop_region(self, image: np.ndarray, bbox: Optional[List[float]]) -> np.ndarray:
          """Crop ``image`` to ``bbox`` = [x1, y1, x2, y2] in pixel coordinates.

          Coordinates are truncated to int and clamped to the image bounds, so
          an out-of-range box gives an in-bounds (possibly empty) crop. A
          ``None`` bbox, or one with fewer than 4 values, returns ``image``
          itself. A bbox with more than 4 values raises ``ValueError``.
          """
          if bbox is None or len(bbox) < 4:
              return image
          x1, y1, x2, y2 = map(int, bbox)
          # Bounds check: clamp to the image and keep x2 >= x1, y2 >= y1
          h, w = image.shape[:2]
          x1 = max(0, min(x1, w))
          y1 = max(0, min(y1, h))
          x2 = max(x1, min(x2, w))
          y2 = max(y1, min(y2, h))
          return image[y1:y2, x1:x2]

      def _convert_to_absolute_coords(
          self,
          relative_bbox: Optional[List[float]],
          region_bbox: Optional[List[float]],
      ) -> Optional[List[float]]:
          """Translate a box relative to ``region_bbox`` into page coordinates.

          Only the region's top-left corner (x1, y1) is used as the offset. If
          either box is ``None`` or has fewer than 4 values, ``relative_bbox``
          is returned unchanged.
          """
          if relative_bbox is None or region_bbox is None:
              return relative_bbox
          if len(relative_bbox) < 4 or len(region_bbox) < 4:
              return relative_bbox
          rx1, ry1, rx2, ry2 = relative_bbox
          bx1, by1, _bx2, _by2 = region_bbox
          # Compute absolute coordinates by shifting by the region origin
          return [bx1 + rx1, by1 + ry1, bx1 + rx2, by1 + ry2]

      def _add_scene_specific_info(self, content: Dict[str, Any]) -> Dict[str, Any]:
          """Return scene-specific table info for the configured scene.

          ``bank_statement`` and ``financial_report`` scenes get their own
          info dicts. Any other scene gets ``{}``.
          """
          if self.scene_name == 'bank_statement':
              return self._process_bank_statement_table(content)
          elif self.scene_name == 'financial_report':
              return self._process_financial_report_table(content)
          return {}

      def _process_bank_statement_table(self, content: Dict[str, Any]) -> Dict[str, Any]:
          """Return bank-statement table info (expected columns, validation flags).

          ``content`` is not modified. HTML post-processing is a placeholder.
          """
          scene_info = {
              'table_type': 'bank_statement',
              'expected_columns': ['日期', '摘要', '收入', '支出', '余额'],
              'validation_rules': {
                  'amount_format': True,
                  'date_format': True,
                  'balance_consistency': True
              }
          }
          # Bank-statement-specific validation and processing
          if 'html' in content and content['html']:
              # TODO: bank-statement-specific HTML post-processing goes here
              pass
          return scene_info

      def _process_financial_report_table(self, content: Dict[str, Any]) -> Dict[str, Any]:
          """Return financial-report table info (header/merge flags, validation flags).

          ``content`` is not modified. HTML post-processing is a placeholder.
          """
          scene_info = {
              'table_type': 'financial_report',
              'complex_headers': True,
              'merged_cells': True,
              'validation_rules': {
                  'accounting_format': True,
                  'sum_validation': True
              }
          }
          # Financial-report-specific validation and processing
          if 'html' in content and content['html']:
              # TODO: financial-report-specific HTML post-processing goes here
              pass
          return scene_info

      def cleanup(self) -> None:
          """Release resources held by the pipeline components.

          Each component in ``_COMPONENT_ATTRS`` that exists and is not None is
          cleaned up independently, so one failure does not stop the others.
          Errors are logged, never raised. The completion message is logged
          only when every cleanup succeeded.
          """
          failed = False
          for attr in self._COMPONENT_ATTRS:
              component = getattr(self, attr, None)
              if component is None:
                  continue
              try:
                  component.cleanup()
              except Exception as e:
                  failed = True
                  logger.warning(f"⚠️ Cleanup failed for {attr}: {e}")
          if not failed:
              logger.info("✅ Pipeline cleanup completed")

      def __enter__(self):
          """Return the pipeline itself for use in a ``with`` block."""
          return self

      def __exit__(self, exc_type, exc_val, exc_tb):
          """Release component resources. Exceptions are not suppressed (returns None)."""
          self.cleanup()