table_mode_selector.py 34 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849
  1. # zhch/table_mode_selector.py
  2. import cv2
  3. import numpy as np
  4. from paddlex import create_pipeline, create_model
  5. class TableModeSelector:
  6. def __init__(self):
  7. # 使用配置中的layout模型
  8. self.layout_model = create_model(model_name="PP-DocLayout_plus-L")
  9. # 使用配置中的模型进行预分析
  10. self.table_cls_model = create_model(model_name="PP-LCNet_x1_0_table_cls")
  11. def analyze_table_features(self, table_image):
  12. """分析表格特征,返回特征字典"""
  13. features = {}
  14. # 1. 表格类型检测
  15. table_type = self.get_table_type(table_image)
  16. features['table_type'] = table_type
  17. # 2. 复杂度分析
  18. complexity = self.analyze_complexity(table_image)
  19. features.update(complexity)
  20. # 3. 结构规整度分析
  21. regularity = self.analyze_regularity(table_image)
  22. features.update(regularity)
  23. # 4. 边框清晰度分析
  24. border_clarity = self.analyze_border_clarity(table_image)
  25. features['border_clarity'] = border_clarity
  26. return features
  27. def get_table_type(self, image):
  28. """获取表格类型"""
  29. try:
  30. result = next(self.table_cls_model.predict(image))
  31. # 调试输出,查看实际的结果格式
  32. print(f"表格分类模型输出类型: {type(result).__name__}")
  33. # 根据实际输出格式调整
  34. if hasattr(result, 'keys') or isinstance(result, dict):
  35. # 处理TopkResult对象或字典
  36. # 标准的PaddleX输出格式
  37. if 'class_ids' in result and 'scores' in result and 'label_names' in result:
  38. scores = result['scores']
  39. label_names = result['label_names']
  40. # 找到最高分数的索引
  41. max_score_idx = np.argmax(scores)
  42. best_label = label_names[max_score_idx]
  43. best_score = scores[max_score_idx]
  44. print(f"分类结果: {best_label} (置信度: {best_score:.4f})")
  45. return best_label
  46. # 其他可能的格式处理...
  47. elif 'class_ids' in result:
  48. class_ids = result['class_ids']
  49. if hasattr(class_ids, '__len__') and len(class_ids) > 0:
  50. class_id = int(class_ids[0])
  51. else:
  52. class_id = int(class_ids)
  53. return 'wired_table' if class_id == 0 else 'wireless_table'
  54. elif 'label_names' in result:
  55. label_names = result['label_names']
  56. return label_names[0] if label_names else 'wired_table'
  57. # 传统的字段名
  58. elif 'label' in result:
  59. return result['label']
  60. elif 'class_name' in result:
  61. return result['class_name']
  62. elif 'prediction' in result:
  63. return result['prediction']
  64. else:
  65. # 默认返回第一个可用值
  66. first_key = list(result.keys())[0]
  67. return str(result[first_key])
  68. # 如果上述方法都失败,使用备用方法
  69. print("使用备用的线条检测方法判断表格类型")
  70. return self.detect_table_type_by_lines(image)
  71. except Exception as e:
  72. print(f"表格分类出错: {e},使用备用方法")
  73. return self.detect_table_type_by_lines(image)
  74. def detect_table_type_by_lines(self, image):
  75. """通过线条检测判断表格类型(备用方法)"""
  76. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  77. edges = cv2.Canny(gray, 50, 150)
  78. # 检测直线
  79. lines = cv2.HoughLines(edges, 1, np.pi/180, threshold=100)
  80. if lines is not None and len(lines) > 10:
  81. print("检测到较多直线,判断为有线表格")
  82. return 'wired_table'
  83. else:
  84. print("检测到较少直线,判断为无线表格")
  85. return 'wireless_table'
  86. def analyze_complexity(self, image):
  87. """分析表格复杂度"""
  88. h, w = image.shape[:2]
  89. # 检测线条密度
  90. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  91. edges = cv2.Canny(gray, 50, 150)
  92. line_density = np.sum(edges > 0) / (h * w)
  93. # 检测合并单元格(简化实现)
  94. merged_cells_ratio = self.detect_merged_cells(image)
  95. # 文本密度分析(简化实现)
  96. text_density = self.analyze_text_density(image)
  97. return {
  98. 'line_density': line_density,
  99. 'merged_cells_ratio': merged_cells_ratio,
  100. 'text_density': text_density,
  101. 'size_complexity': (h * w) / (1000 * 1000) # 图像尺寸复杂度
  102. }
  103. def detect_merged_cells(self, image):
  104. """检测合并单元格比例(简化实现)"""
  105. # 这里使用简化的启发式方法
  106. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  107. # 检测水平线
  108. horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
  109. horizontal_lines = cv2.morphologyEx(gray, cv2.MORPH_OPEN, horizontal_kernel)
  110. # 检测垂直线
  111. vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
  112. vertical_lines = cv2.morphologyEx(gray, cv2.MORPH_OPEN, vertical_kernel)
  113. # 计算线条覆盖率作为合并单元格的指标
  114. h_coverage = np.sum(horizontal_lines > 0) / horizontal_lines.size
  115. v_coverage = np.sum(vertical_lines > 0) / vertical_lines.size
  116. # 简化的合并单元格比例估算
  117. merged_ratio = 1.0 - min(h_coverage, v_coverage) * 2
  118. return max(0.0, min(1.0, merged_ratio))
  119. def analyze_text_density(self, image):
  120. """分析文本密度(简化实现)"""
  121. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  122. # 使用简单的阈值化来估算文本区域
  123. _, binary = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
  124. # 计算非空白像素比例作为文本密度
  125. text_pixels = np.sum(binary == 0) # 黑色像素(文本)
  126. total_pixels = binary.size
  127. return text_pixels / total_pixels
  128. def analyze_regularity(self, image):
  129. """分析表格结构规整度"""
  130. # 检测水平和垂直线条的规律性
  131. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  132. # 水平线检测
  133. horizontal_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (40, 1))
  134. horizontal_lines = cv2.morphologyEx(gray, cv2.MORPH_OPEN, horizontal_kernel)
  135. # 垂直线检测
  136. vertical_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (1, 40))
  137. vertical_lines = cv2.morphologyEx(gray, cv2.MORPH_OPEN, vertical_kernel)
  138. # 计算规整度分数
  139. h_regularity = self.calculate_line_regularity(horizontal_lines, axis='horizontal')
  140. v_regularity = self.calculate_line_regularity(vertical_lines, axis='vertical')
  141. return {
  142. 'horizontal_regularity': h_regularity,
  143. 'vertical_regularity': v_regularity,
  144. 'overall_regularity': (h_regularity + v_regularity) / 2
  145. }
  146. def calculate_line_regularity(self, lines_image, axis='horizontal'):
  147. """计算线条规整度"""
  148. if axis == 'horizontal':
  149. # 水平方向投影
  150. projection = np.sum(lines_image, axis=1)
  151. else:
  152. # 垂直方向投影
  153. projection = np.sum(lines_image, axis=0)
  154. # 找到投影峰值
  155. peaks = []
  156. threshold = np.max(projection) * 0.3
  157. for i in range(1, len(projection) - 1):
  158. if projection[i] > threshold and projection[i] > projection[i-1] and projection[i] > projection[i+1]:
  159. peaks.append(i)
  160. if len(peaks) < 2:
  161. return 0.5 # 默认中等规整度
  162. # 计算峰值间距的标准差
  163. intervals = [peaks[i+1] - peaks[i] for i in range(len(peaks)-1)]
  164. if len(intervals) == 0:
  165. return 0.5
  166. mean_interval = np.mean(intervals)
  167. std_interval = np.std(intervals)
  168. # 规整度 = 1 - (标准差 / 平均值),值越大越规整
  169. if mean_interval == 0:
  170. return 0.5
  171. regularity = 1.0 - min(1.0, std_interval / mean_interval)
  172. return max(0.0, regularity)
  173. def analyze_border_clarity(self, image):
  174. """分析边框清晰度"""
  175. gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
  176. # 使用Sobel算子检测边缘强度
  177. sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
  178. sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
  179. edge_magnitude = np.sqrt(sobelx**2 + sobely**2)
  180. # 计算边缘清晰度分数
  181. clarity_score = np.mean(edge_magnitude) / 255.0
  182. return clarity_score
  183. class TableModeDecisionEngine:
  184. def __init__(self):
  185. self.rules = self.load_decision_rules()
  186. def load_decision_rules(self):
  187. """加载决策规则"""
  188. return {
  189. 'wired_html_mode': {
  190. 'conditions': [
  191. ('table_type', 'in', ['wired_table', 'wired', '0']), # 支持多种格式
  192. ('border_clarity', '>', 0.6),
  193. ('merged_cells_ratio', '>', 0.3),
  194. ('overall_regularity', '<', 0.7),
  195. ('size_complexity', '>', 0.5)
  196. ],
  197. 'weight': 0.9,
  198. 'description': '复杂有线表格,几何匹配更准确'
  199. },
  200. 'wired_e2e_mode': {
  201. 'conditions': [
  202. ('table_type', 'in', ['wired_table', 'wired', '0']),
  203. ('overall_regularity', '>', 0.8),
  204. ('merged_cells_ratio', '<', 0.2),
  205. ('text_density', '>', 0.3)
  206. ],
  207. 'weight': 0.8,
  208. 'description': '规整有线表格,端到端效果好'
  209. },
  210. 'wireless_e2e_mode': {
  211. 'conditions': [
  212. ('table_type', 'in', ['wireless_table', 'wireless', '1']),
  213. ('line_density', '<', 0.1),
  214. ('text_density', '>', 0.2)
  215. ],
  216. 'weight': 0.85,
  217. 'description': '无线表格,端到端预测最适合'
  218. },
  219. 'regular_mode': {
  220. 'conditions': [
  221. ('size_complexity', '>', 1.0),
  222. ('OR', [
  223. ('border_clarity', '<', 0.4),
  224. ('overall_regularity', '<', 0.5)
  225. ])
  226. ],
  227. 'weight': 0.7,
  228. 'description': '复杂场景,需要多模型协同'
  229. }
  230. }
  231. def check_single_condition(self, features, condition):
  232. """检查单个条件"""
  233. feature_name, operator, threshold = condition
  234. if feature_name not in features:
  235. return False
  236. value = features[feature_name]
  237. if operator == '>':
  238. return value > threshold
  239. elif operator == '<':
  240. return value < threshold
  241. elif operator == '==':
  242. return value == threshold
  243. elif operator == '>=':
  244. return value >= threshold
  245. elif operator == '<=':
  246. return value <= threshold
  247. elif operator == 'in':
  248. return value in threshold # threshold 是一个列表
  249. return False
  250. def evaluate_conditions(self, features, conditions):
  251. """评估条件是否满足"""
  252. score = 0
  253. total_conditions = 0
  254. for condition in conditions:
  255. if condition[0] == 'OR':
  256. # 处理OR条件
  257. or_satisfied = any(
  258. self.check_single_condition(features, sub_cond)
  259. for sub_cond in condition[1]
  260. )
  261. if or_satisfied:
  262. score += 1
  263. total_conditions += 1
  264. else:
  265. # 处理单个条件
  266. if self.check_single_condition(features, condition):
  267. score += 1
  268. total_conditions += 1
  269. return score / total_conditions if total_conditions > 0 else 0
  270. def select_best_mode(self, features):
  271. """选择最佳模式"""
  272. mode_scores = {}
  273. for mode_name, rule in self.rules.items():
  274. conditions_score = self.evaluate_conditions(features, rule['conditions'])
  275. final_score = conditions_score * rule['weight']
  276. mode_scores[mode_name] = {
  277. 'score': final_score,
  278. 'description': rule['description']
  279. }
  280. # 选择得分最高的模式
  281. best_mode = max(mode_scores.items(), key=lambda x: x[1]['score'])
  282. return best_mode[0], best_mode[1]
  283. class IntelligentTableProcessor:
  284. def __init__(self, config_path="./PP-StructureV3-zhch.yaml"):
  285. self.selector = TableModeSelector()
  286. self.decision_engine = TableModeDecisionEngine()
  287. # 暂时不初始化完整的pipeline,避免配置问题
  288. self.config_path = config_path
  289. self.pp_structure = None
  290. def execute_with_mode(self, image_path, mode, optimized_config=None):
  291. """根据选择的模式执行表格识别"""
  292. try:
  293. print(f"正在使用 {mode} 模式处理表格...")
  294. print(f"优化配置: {optimized_config}")
  295. # 创建动态配置的pipeline
  296. result = self.create_and_run_pipeline(image_path, mode, optimized_config)
  297. return result
  298. except Exception as e:
  299. print(f"执行 {mode} 模式时出错: {e}")
  300. print("回退到基础处理模式")
  301. return self.fallback_processing(image_path)
  302. def create_and_run_pipeline(self, image_path, mode, optimized_config):
  303. """创建并运行特定模式的pipeline"""
  304. if mode == 'wired_html_mode':
  305. return self.run_wired_html_mode(image_path, optimized_config)
  306. elif mode == 'wired_e2e_mode':
  307. return self.run_wired_e2e_mode(image_path, optimized_config)
  308. elif mode == 'wireless_e2e_mode':
  309. return self.run_wireless_e2e_mode(image_path, optimized_config)
  310. elif mode == 'regular_mode':
  311. return self.run_regular_mode(image_path, optimized_config)
  312. else:
  313. print(f"未知模式: {mode},使用默认处理")
  314. return self.fallback_processing(image_path)
  315. def run_wired_html_mode(self, image_path, config):
  316. """运行有线表格转HTML模式"""
  317. print("执行有线表格转HTML模式...")
  318. try:
  319. # 使用表格识别pipeline,启用HTML模式
  320. from paddlex import create_pipeline
  321. # 创建表格识别pipeline
  322. table_pipeline = create_pipeline(
  323. pipeline=self.config_path,
  324. model_dir=None
  325. )
  326. # 模拟配置HTML模式的参数
  327. # 注意:这里需要根据实际的PaddleX API调整
  328. result = list(table_pipeline.predict(
  329. image_path,
  330. use_wired_table_html_mode=True,
  331. use_wired_table_e2e_mode=False
  332. ))
  333. return self.format_result(result, mode='wired_html_mode')
  334. except Exception as e:
  335. print(f"有线表格HTML模式执行失败: {e}")
  336. return self.create_mock_result(mode='wired_html_mode')
  337. def run_wired_e2e_mode(self, image_path, config):
  338. """运行有线表格端到端模式"""
  339. print("执行有线表格端到端模式...")
  340. try:
  341. from paddlex import create_pipeline
  342. table_pipeline = create_pipeline(
  343. pipeline=self.config_path,
  344. model_dir=None
  345. )
  346. result = list(table_pipeline.predict(
  347. image_path,
  348. use_wired_table_html_mode=False,
  349. use_wired_table_e2e_mode=True
  350. ))
  351. return self.format_result(result, mode='wired_e2e_mode')
  352. except Exception as e:
  353. print(f"有线表格端到端模式执行失败: {e}")
  354. return self.create_mock_result(mode='wired_e2e_mode')
  355. def run_wireless_e2e_mode(self, image_path, config):
  356. """运行无线表格端到端模式"""
  357. print("执行无线表格端到端模式...")
  358. try:
  359. from paddlex import create_pipeline
  360. table_pipeline = create_pipeline(
  361. pipeline=self.config_path,
  362. model_dir=None
  363. )
  364. result = list(table_pipeline.predict(
  365. image_path,
  366. use_wireless_table_e2e_mode=True
  367. ))
  368. return self.format_result(result, mode='wireless_e2e_mode')
  369. except Exception as e:
  370. print(f"无线表格端到端模式执行失败: {e}")
  371. return self.create_mock_result(mode='wireless_e2e_mode')
  372. def run_regular_mode(self, image_path, config):
  373. """运行常规模式"""
  374. print("执行常规模式...")
  375. try:
  376. # 使用完整的PP-StructureV3 pipeline
  377. if self.pp_structure is None:
  378. from paddlex import create_pipeline
  379. self.pp_structure = create_pipeline(
  380. pipeline=self.config_path
  381. )
  382. result = list(self.pp_structure.predict(image_path))
  383. return self.format_result(result, mode='regular_mode')
  384. except Exception as e:
  385. print(f"常规模式执行失败: {e}")
  386. return self.create_mock_result(mode='regular_mode')
  387. def format_result(self, raw_result, mode):
  388. """格式化结果"""
  389. try:
  390. if not raw_result:
  391. return self.create_mock_result(mode)
  392. formatted_result = {
  393. 'mode': mode,
  394. 'status': 'success',
  395. 'raw_output': raw_result,
  396. 'table_count': 0,
  397. 'tables': []
  398. }
  399. # 提取表格结果 - 根据实际的PP-StructureV3输出结构
  400. for item in raw_result:
  401. print(f"处理结果项: {type(item)}")
  402. # 检查是否有table_res_list字段(PP-StructureV3的实际结构)
  403. if hasattr(item, 'table_res_list') or 'table_res_list' in item:
  404. table_res_list = item.get('table_res_list', getattr(item, 'table_res_list', []))
  405. if table_res_list and len(table_res_list) > 0:
  406. formatted_result['table_count'] = len(table_res_list)
  407. for i, table_item in enumerate(table_res_list):
  408. # 提取HTML内容
  409. html_content = table_item.get('pred_html', 'HTML不可用')
  410. # 提取表格区域信息
  411. table_region_id = table_item.get('table_region_id', i)
  412. # 尝试从cell_box_list获取bbox信息
  413. bbox = [0, 0, 100, 100] # 默认值
  414. if 'cell_box_list' in table_item and table_item['cell_box_list']:
  415. # 从单元格列表计算整体边界框
  416. bbox = self.calculate_table_bbox_from_cells(table_item['cell_box_list'])
  417. formatted_result['tables'].append({
  418. 'table_id': i,
  419. 'table_region_id': table_region_id,
  420. 'html': html_content,
  421. 'bbox': bbox,
  422. 'cell_count': len(table_item.get('cell_box_list', [])),
  423. 'neighbor_texts': table_item.get('neighbor_texts', '')
  424. })
  425. print(f"提取表格 {i}: region_id={table_region_id}, cells={len(table_item.get('cell_box_list', []))}")
  426. # 检查parsing_res_list(可能包含额外的表格信息)
  427. elif hasattr(item, 'parsing_res_list') or 'parsing_res_list' in item:
  428. parsing_res_list = item.get('parsing_res_list', getattr(item, 'parsing_res_list', []))
  429. for parsing_item in parsing_res_list:
  430. if hasattr(parsing_item, 'label') and parsing_item.label == 'table':
  431. # 这是一个表格解析结果
  432. formatted_result['table_count'] += 1
  433. html_content = getattr(parsing_item, 'html', 'HTML不可用')
  434. bbox = getattr(parsing_item, 'bbox', [0, 0, 100, 100])
  435. formatted_result['tables'].append({
  436. 'table_id': len(formatted_result['tables']),
  437. 'html': html_content,
  438. 'bbox': bbox,
  439. 'source': 'parsing_res'
  440. })
  441. # 兼容旧版本的table_recognition_res结构
  442. elif hasattr(item, 'table_recognition_res') or 'table_recognition_res' in item:
  443. table_res = item.get('table_recognition_res', getattr(item, 'table_recognition_res', None))
  444. if table_res and len(table_res) > 0:
  445. formatted_result['table_count'] = len(table_res)
  446. for i, table in enumerate(table_res):
  447. formatted_result['tables'].append({
  448. 'table_id': i,
  449. 'html': getattr(table, 'html', 'HTML不可用'),
  450. 'bbox': getattr(table, 'bbox', [0, 0, 100, 100])
  451. })
  452. return formatted_result
  453. except Exception as e:
  454. print(f"结果格式化失败: {e}")
  455. import traceback
  456. traceback.print_exc()
  457. return self.create_mock_result(mode)
  458. def calculate_table_bbox_from_cells(self, cell_box_list):
  459. """从单元格列表计算表格的整体边界框"""
  460. try:
  461. if not cell_box_list:
  462. return [0, 0, 100, 100]
  463. min_x = float('inf')
  464. min_y = float('inf')
  465. max_x = float('-inf')
  466. max_y = float('-inf')
  467. for cell in cell_box_list:
  468. # cell格式可能是 [x1, y1, x2, y2] 或其他格式
  469. if isinstance(cell, (list, tuple)) and len(cell) >= 4:
  470. x1, y1, x2, y2 = cell[:4]
  471. min_x = min(min_x, x1, x2)
  472. min_y = min(min_y, y1, y2)
  473. max_x = max(max_x, x1, x2)
  474. max_y = max(max_y, y1, y2)
  475. elif hasattr(cell, 'bbox'):
  476. bbox = cell.bbox
  477. if len(bbox) >= 4:
  478. x1, y1, x2, y2 = bbox[:4]
  479. min_x = min(min_x, x1, x2)
  480. min_y = min(min_y, y1, y2)
  481. max_x = max(max_x, x1, x2)
  482. max_y = max(max_y, y1, y2)
  483. if min_x == float('inf'):
  484. return [0, 0, 100, 100]
  485. return [int(min_x), int(min_y), int(max_x), int(max_y)]
  486. except Exception as e:
  487. print(f"计算表格边界框失败: {e}")
  488. return [0, 0, 100, 100]
  489. def create_mock_result(self, mode):
  490. """创建模拟结果(用于测试和错误回退)"""
  491. return {
  492. 'mode': mode,
  493. 'status': 'mock',
  494. 'message': f'{mode} 模式执行完成(模拟结果)',
  495. 'table_count': 1,
  496. 'tables': [{
  497. 'table_id': 0,
  498. 'html': f'<table><tr><td>模拟{mode}结果</td></tr></table>',
  499. 'bbox': [237, 201, 1416, 2044]
  500. }]
  501. }
  502. def fallback_processing(self, image_path):
  503. """回退处理方法"""
  504. print("使用基础OCR处理...")
  505. try:
  506. from paddlex import create_pipeline
  507. # 使用基础OCR pipeline
  508. ocr_pipeline = create_pipeline(pipeline="OCR")
  509. result = list(ocr_pipeline.predict(image_path))
  510. return {
  511. 'mode': 'fallback_ocr',
  512. 'status': 'success',
  513. 'raw_output': result,
  514. 'message': '使用基础OCR处理'
  515. }
  516. except Exception as e:
  517. print(f"回退处理也失败: {e}")
  518. return {
  519. 'mode': 'error',
  520. 'status': 'failed',
  521. 'message': f'所有处理方法都失败: {e}'
  522. }
  523. def extract_all_table_regions(self, image_path):
  524. """提取所有表格区域(如果有多个表格)"""
  525. original_image = cv2.imread(image_path)
  526. layout_results = list(self.selector.layout_model.predict(image_path))
  527. all_tables = []
  528. for layout_result in layout_results:
  529. for i, box_info in enumerate(layout_result['boxes']):
  530. if box_info['label'] == 'table':
  531. coordinate = box_info['coordinate']
  532. x1, y1, x2, y2 = [int(coord) for coord in coordinate]
  533. table_image = original_image[y1:y2, x1:x2]
  534. table_info = {
  535. 'table_id': i,
  536. 'image': table_image,
  537. 'bbox': [x1, y1, x2, y2],
  538. 'score': box_info['score']
  539. }
  540. all_tables.append(table_info)
  541. # 保存每个表格区域
  542. cv2.imwrite(f'./debug_table_{i}.jpg', table_image)
  543. print(f"表格 {i}: bbox=[{x1}, {y1}, {x2}, {y2}], score={box_info['score']:.4f}")
  544. return all_tables
  545. def extract_table_region(self, image_path):
  546. """从图像中提取表格区域"""
  547. # 读取原图
  548. original_image = cv2.imread(image_path)
  549. # 使用layout模型检测版面
  550. layout_results = list(self.selector.layout_model.predict(image_path))
  551. table_regions = []
  552. for layout_result in layout_results:
  553. # 遍历检测到的所有区域
  554. for box_info in layout_result['boxes']:
  555. if box_info['label'] == 'table':
  556. # 提取表格坐标
  557. coordinate = box_info['coordinate']
  558. x1, y1, x2, y2 = [int(coord) for coord in coordinate]
  559. # 裁剪表格区域
  560. table_image = original_image[y1:y2, x1:x2]
  561. table_regions.append({
  562. 'image': table_image,
  563. 'bbox': [x1, y1, x2, y2],
  564. 'score': box_info['score']
  565. })
  566. print(f"检测到表格区域: bbox=[{x1}, {y1}, {x2}, {y2}], score={box_info['score']:.4f}")
  567. if len(table_regions) == 0:
  568. print("未检测到表格区域,使用整个图像")
  569. return original_image
  570. # 返回得分最高的表格区域
  571. best_table = max(table_regions, key=lambda x: x['score'])
  572. return best_table['image']
  573. def process_table_intelligently(self, image_path, use_layout_model=True):
  574. """智能处理表格"""
  575. try:
  576. # 1. 提取表格区域
  577. if use_layout_model:
  578. table_image = self.extract_table_region(image_path)
  579. else:
  580. table_image = cv2.imread(image_path)
  581. if table_image is None or table_image.size == 0:
  582. print("表格区域提取失败,使用原图")
  583. table_image = cv2.imread(image_path)
  584. # 保存表格区域用于调试
  585. cv2.imwrite('./debug_table_region.jpg', table_image)
  586. print(f"表格区域已保存到: ./debug_table_region.jpg")
  587. print(f"表格区域尺寸: {table_image.shape}")
  588. # 2. 分析表格特征
  589. features = self.selector.analyze_table_features(table_image)
  590. print(f"表格特征分析: {features}")
  591. # 3. 选择最佳模式
  592. best_mode, mode_info = self.decision_engine.select_best_mode(features)
  593. print(f"选择模式: {best_mode}, 分数: {mode_info['score']:.3f}")
  594. # # 4. 动态调整配置
  595. # optimized_config = self.optimize_config_for_mode(best_mode, features)
  596. # print(f"优化配置: {optimized_config}")
  597. # 5. 执行处理
  598. # result = self.execute_with_mode(image_path, best_mode, optimized_config=None)
  599. result = self.execute_with_mode(table_image, best_mode, optimized_config=None)
  600. return {
  601. 'result': result,
  602. 'selected_mode': best_mode,
  603. 'mode_description': mode_info['description'],
  604. 'confidence_score': mode_info['score'],
  605. 'table_features': features,
  606. 'table_region_shape': table_image.shape
  607. }
  608. except Exception as e:
  609. print(f"智能处理过程出错: {e}")
  610. import traceback
  611. traceback.print_exc()
  612. # 返回错误信息
  613. return {
  614. 'result': None,
  615. 'selected_mode': 'error',
  616. 'mode_description': f'处理失败: {e}',
  617. 'confidence_score': 0.0,
  618. 'table_features': {},
  619. 'error': str(e)
  620. }
  621. # 修改demo函数,更好地处理结果
  622. def demo_intelligent_table_processing():
  623. """演示智能表格处理"""
  624. try:
  625. processor = IntelligentTableProcessor("./PP-StructureV3-zhch.yaml")
  626. # 处理您之前的复杂财务表格
  627. result = processor.process_table_intelligently(
  628. "./sample_data/600916_中国黄金_2002年报_83_94_2.png",
  629. use_layout_model=True
  630. )
  631. print("\n" + "="*50)
  632. print("智能表格处理结果:")
  633. print("="*50)
  634. print(f"选择的模式: {result['selected_mode']}")
  635. print(f"选择原因: {result['mode_description']}")
  636. print(f"置信度分数: {result['confidence_score']:.3f}")
  637. if 'table_region_shape' in result:
  638. print(f"表格区域尺寸: {result['table_region_shape']}")
  639. print(f"\n表格特征分析:")
  640. for key, value in result.get('table_features', {}).items():
  641. if isinstance(value, float):
  642. print(f" {key}: {value:.4f}")
  643. else:
  644. print(f" {key}: {value}")
  645. # 处理结果
  646. if result['result']:
  647. process_result = result['result']
  648. print(f"\n处理结果:")
  649. print(f" 模式: {process_result.get('mode', 'unknown')}")
  650. print(f" 状态: {process_result.get('status', 'unknown')}")
  651. print(f" 表格数量: {process_result.get('table_count', 0)}")
  652. if process_result.get('tables'):
  653. for i, table in enumerate(process_result['tables']):
  654. print(f"\n 表格 {i}:")
  655. print(f" bbox: {table.get('bbox', 'N/A')}")
  656. print(f" 单元格数量: {table.get('cell_count', 'N/A')}")
  657. print(f" 区域ID: {table.get('table_region_id', 'N/A')}")
  658. html_content = table.get('html', '')
  659. if len(html_content) > 200:
  660. html_preview = html_content[:200] + "..."
  661. else:
  662. html_preview = html_content
  663. print(f" HTML预览: {html_preview}")
  664. # 保存完整HTML到文件
  665. html_filename = f"./table_{i}_result.html"
  666. try:
  667. with open(html_filename, 'w', encoding='utf-8') as f:
  668. f.write(html_content)
  669. print(f" 完整HTML已保存到: {html_filename}")
  670. except Exception as e:
  671. print(f" 保存HTML失败: {e}")
  672. # 根据置信度给出建议
  673. if result['confidence_score'] > 0.8:
  674. print("\n✅ 高置信度,推荐使用该模式")
  675. elif result['confidence_score'] > 0.6:
  676. print("\n⚠️ 中等置信度,可能需要人工验证")
  677. else:
  678. print("\n❌ 低置信度,建议尝试其他模式或人工处理")
  679. return result
  680. except Exception as e:
  681. print(f"演示程序出错: {e}")
  682. import traceback
  683. traceback.print_exc()
  684. return None
# Script entry point: run the demo only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    demo_intelligent_table_processing()