Răsfoiți Sursa

feat: 新增表格识别管线,支持将HTML表格转换为Markdown并保存结果

zhch158_admin 1 lună în urmă
părinte
comite
7b2c02ea44
1 fișier modificat, cu 12 adăugiri și 10 ștergeri
  1. 12 10
      zhch/table_recognition_v2_single_process.py

+ 12 - 10
zhch/ table_recognition_v2_single_process.py → zhch/table_recognition_v2_single_process.py

@@ -37,11 +37,11 @@ def html_table_to_markdown(html: str) -> str:
     html = re.sub(r'</?body>', '', html, flags=re.IGNORECASE)
     html = html.strip()
     return html
- 
-def save_json_tables(json_res: Dict[str, Any], output_dir: str, base_name: str
-                     ) -> Tuple[str, List[str]]:
+
+def save_json_tables(json_res: Dict[str, Any], output_dir: str, base_name: str,
+                     normalize_numbers: bool = True) -> Tuple[str, List[str], int]:
     if not json_res:
-        return "", []
+        return "", [], 0
 
     # 从 table_res_list 中取出 pred_html 转为 Markdown
     table_list = (json_res or {}).get("table_res_list", []) or []
@@ -52,7 +52,7 @@ def save_json_tables(json_res: Dict[str, Any], output_dir: str, base_name: str
         if not html:
             continue
                 # 2. 标准化 table_res_list 中的HTML表格
-        else:
+        elif normalize_numbers:
             normalized_html = normalize_markdown_table(html)
             
             if html != normalized_html:
@@ -67,7 +67,7 @@ def save_json_tables(json_res: Dict[str, Any], output_dir: str, base_name: str
     with open(json_fp, "w", encoding="utf-8") as f:
         json.dump(json_res, f, ensure_ascii=False, indent=2)
 
-    return json_fp.as_posix(), md_tables
+    return json_fp.as_posix(), md_tables, changes_count
 
 def save_markdown_tables(md_tables: List[str], output_dir: str, base_name: str,
                          normalize_numbers: bool = True) -> str:
@@ -121,7 +121,6 @@ def process_images_with_table_pipeline(
     with tqdm(total=total, desc="Processing images", unit="img",
               bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
         for img_path in image_paths:
-            img_path = str(img_path)
             start = time.time()
             try:
                 outputs = pipeline.predict(
@@ -144,7 +143,7 @@ def process_images_with_table_pipeline(
                     # 保存结构化JSON
                     json_res = res.json.get("res", res.json)
 
-                    saved_json, md_tables = save_json_tables(json_res, str(output_path), base_name)
+                    saved_json, md_tables, changes_count = save_json_tables(json_res, str(output_path), base_name, normalize_numbers=normalize_numbers)
     
                     saved_md = save_markdown_tables(md_tables, str(output_path), base_name,
                                                      normalize_numbers=normalize_numbers)
@@ -157,7 +156,10 @@ def process_images_with_table_pipeline(
                         "json_path": saved_json,
                         "markdown_path": saved_md,
                         "tables_detected": len(md_tables),
-                        "is_pdf_page": "_page_" in input_path.name
+                        "is_pdf_page": "_page_" in input_path.name,
+                        "normalize_numbers": normalize_numbers,
+                        "changes_applied": changes_count > 0,
+                        "character_changes_count": changes_count,
                     })
 
                 pbar.update(1)
@@ -285,7 +287,7 @@ if __name__ == "__main__":
     if len(sys.argv) == 1:
         # 演示默认参数(请按需修改)
         demo = {
-            "--input_file": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水.img",
+            "--input_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水.img",
             "--output_dir": "/Users/zhch158/workspace/data/流水分析/A用户_单元格扫描流水/table_recognition_v2_Results",
             "--pipeline": "./my_config/table_recognition_v2.yaml",
             "--device": "cpu",