|
|
@@ -25,6 +25,9 @@ from ppstructurev3_utils import (
|
|
|
)
|
|
|
from utils import normalize_markdown_table, get_input_files
|
|
|
|
|
|
+# 🎯 新增:导入适配器
|
|
|
+from adapters.table_recognition_adapter import apply_table_recognition_adapter, restore_original_function
|
|
|
+
|
|
|
def html_table_to_markdown(html: str) -> str:
|
|
|
"""
|
|
|
将简单HTML表格转换为Markdown表格。
|
|
|
@@ -94,7 +97,8 @@ def process_images_with_table_pipeline(
|
|
|
pipeline_cfg: str = "./my_config/table_recognition_v2.yaml",
|
|
|
device: str = "gpu:0",
|
|
|
output_dir: str = "./output",
|
|
|
- normalize_numbers: bool = True
|
|
|
+ normalize_numbers: bool = True,
|
|
|
+ use_enhanced_adapter: bool = True # 🎯 新增参数
|
|
|
) -> List[Dict[str, Any]]:
|
|
|
"""
|
|
|
运行 table_recognition_v2 管线,输出 JSON、可视化图,且将每个表格HTML转为Markdown保存。
|
|
|
@@ -102,6 +106,15 @@ def process_images_with_table_pipeline(
|
|
|
output_path = Path(output_dir).resolve()
|
|
|
output_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
+ # 🎯 应用适配器
|
|
|
+ adapter_applied = False
|
|
|
+ if use_enhanced_adapter:
|
|
|
+ adapter_applied = apply_table_recognition_adapter()
|
|
|
+ if adapter_applied:
|
|
|
+ print("🎯 Enhanced table recognition adapter activated")
|
|
|
+ else:
|
|
|
+ print("⚠️ Failed to apply adapter, using original implementation")
|
|
|
+
|
|
|
print(f"Initializing pipeline '{pipeline_cfg}' on device '{device}'...")
|
|
|
try:
|
|
|
os.environ['PYTHONWARNINGS'] = 'ignore::UserWarning'
|
|
|
@@ -109,82 +122,89 @@ def process_images_with_table_pipeline(
|
|
|
print(f"Pipeline initialized successfully on {device}")
|
|
|
except Exception as e:
|
|
|
print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
|
|
|
- traceback.print_exc()
|
|
|
+ if adapter_applied:
|
|
|
+ restore_original_function()
|
|
|
return []
|
|
|
|
|
|
- results_all: List[Dict[str, Any]] = []
|
|
|
- total = len(image_paths)
|
|
|
- print(f"Processing {total} images with table_recognition_v2")
|
|
|
- print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
|
|
|
-
|
|
|
- with tqdm(total=total, desc="Processing images", unit="img",
|
|
|
- bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
|
|
|
- for img_path in image_paths:
|
|
|
- start = time.time()
|
|
|
- try:
|
|
|
- outputs = pipeline.predict(
|
|
|
- img_path,
|
|
|
- use_doc_orientation_classify=True,
|
|
|
- use_doc_unwarping=False,
|
|
|
- use_layout_detection=True,
|
|
|
- use_ocr_results_with_table_cells=True,
|
|
|
- use_table_orientation_classify=True,
|
|
|
- use_wired_table_cells_trans_to_html=True,
|
|
|
- # 新增:关闭单元格内拆分,整格识别以保留折行文本,
|
|
|
- # 修改paddlex/inference/pipelines/table_recognition/pipeline_v2.py
|
|
|
- # get_table_recognition_res传入参数self.cells_split_ocr=False,保证单元格内换行不被拆分
|
|
|
- use_table_cells_split_ocr=False,
|
|
|
- )
|
|
|
- cost = time.time() - start
|
|
|
-
|
|
|
- # 一般每张图片只返回一个结果
|
|
|
- for idx, res in enumerate(outputs):
|
|
|
- if idx > 0:
|
|
|
- raise ValueError("Multiple results found for a single image")
|
|
|
-
|
|
|
- input_path = Path(res["input_path"])
|
|
|
- base_name = input_path.stem
|
|
|
-
|
|
|
- res.save_all(save_path=output_path.as_posix()) # 保存所有结果到指定路径
|
|
|
- # 保存结构化JSON
|
|
|
- json_res = res.json.get("res", res.json)
|
|
|
-
|
|
|
- saved_json, md_tables, changes_count = save_json_tables(json_res, str(output_path), base_name, normalize_numbers=normalize_numbers)
|
|
|
+ try:
|
|
|
+ results_all: List[Dict[str, Any]] = []
|
|
|
+ total = len(image_paths)
|
|
|
+ print(f"Processing {total} images with table_recognition_v2")
|
|
|
+ print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
|
|
|
+ print(f"🎯 增强适配器: {'启用' if adapter_applied else '禁用'}")
|
|
|
+
|
|
|
+ with tqdm(total=total, desc="Processing images", unit="img",
|
|
|
+ bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
|
|
|
+ for img_path in image_paths:
|
|
|
+ start = time.time()
|
|
|
+ try:
|
|
|
+ outputs = pipeline.predict(
|
|
|
+ img_path,
|
|
|
+ use_doc_orientation_classify=True,
|
|
|
+ use_doc_unwarping=False,
|
|
|
+ use_layout_detection=True,
|
|
|
+ use_ocr_results_with_table_cells=True,
|
|
|
+ use_table_orientation_classify=True,
|
|
|
+ use_wired_table_cells_trans_to_html=True,
|
|
|
+ # 🎯 注意:适配器模式下不需要这个参数
|
|
|
+ # use_table_cells_split_ocr=False,
|
|
|
+ )
|
|
|
+ cost = time.time() - start
|
|
|
+
|
|
|
+ # 一般每张图片只返回一个结果
|
|
|
+ for idx, res in enumerate(outputs):
|
|
|
+ if idx > 0:
|
|
|
+ raise ValueError("Multiple results found for a single image")
|
|
|
+
|
|
|
+ input_path = Path(res["input_path"])
|
|
|
+ base_name = input_path.stem
|
|
|
+
|
|
|
+ res.save_all(save_path=output_path.as_posix()) # 保存所有结果到指定路径
|
|
|
+ # 保存结构化JSON
|
|
|
+ json_res = res.json.get("res", res.json)
|
|
|
+
|
|
|
+ saved_json, md_tables, changes_count = save_json_tables(json_res, str(output_path), base_name, normalize_numbers=normalize_numbers)
|
|
|
|
|
|
- saved_md = save_markdown_tables(md_tables, str(output_path), base_name,
|
|
|
+ saved_md = save_markdown_tables(md_tables, str(output_path), base_name,
|
|
|
normalize_numbers=normalize_numbers)
|
|
|
|
|
|
+ results_all.append({
|
|
|
+ "image_path": str(input_path),
|
|
|
+ "success": True,
|
|
|
+ "time_sec": cost,
|
|
|
+ "device": device,
|
|
|
+ "json_path": saved_json,
|
|
|
+ "markdown_path": saved_md,
|
|
|
+ "tables_detected": len(md_tables),
|
|
|
+ "is_pdf_page": "_page_" in input_path.name,
|
|
|
+ "normalize_numbers": normalize_numbers,
|
|
|
+ "changes_applied": changes_count > 0,
|
|
|
+ "character_changes_count": changes_count,
|
|
|
+ })
|
|
|
+
|
|
|
+ pbar.update(1)
|
|
|
+ ok = sum(1 for r in results_all if r.get("success"))
|
|
|
+ pbar.set_postfix(time=f"{cost:.2f}s", ok=ok)
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ traceback.print_exc()
|
|
|
results_all.append({
|
|
|
- "image_path": str(input_path),
|
|
|
- "success": True,
|
|
|
- "time_sec": cost,
|
|
|
+ "image_path": str(img_path),
|
|
|
+ "success": False,
|
|
|
+ "time_sec": 0,
|
|
|
"device": device,
|
|
|
- "json_path": saved_json,
|
|
|
- "markdown_path": saved_md,
|
|
|
- "tables_detected": len(md_tables),
|
|
|
- "is_pdf_page": "_page_" in input_path.name,
|
|
|
- "normalize_numbers": normalize_numbers,
|
|
|
- "changes_applied": changes_count > 0,
|
|
|
- "character_changes_count": changes_count,
|
|
|
+ "error": str(e)
|
|
|
})
|
|
|
+ pbar.update(1)
|
|
|
+ pbar.set_postfix_str("error")
|
|
|
|
|
|
- pbar.update(1)
|
|
|
- ok = sum(1 for r in results_all if r.get("success"))
|
|
|
- pbar.set_postfix(time=f"{cost:.2f}s", ok=ok)
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- traceback.print_exc()
|
|
|
- results_all.append({
|
|
|
- "image_path": str(img_path),
|
|
|
- "success": False,
|
|
|
- "time_sec": 0,
|
|
|
- "device": device,
|
|
|
- "error": str(e)
|
|
|
- })
|
|
|
- pbar.update(1)
|
|
|
- pbar.set_postfix_str("error")
|
|
|
-
|
|
|
- return results_all
|
|
|
+ return results_all
|
|
|
+
|
|
|
+ finally:
|
|
|
+ # 🎯 清理:恢复原始函数
|
|
|
+ if adapter_applied:
|
|
|
+ restore_original_function()
|
|
|
+ print("🔄 Original function restored")
|
|
|
|
|
|
def main():
|
|
|
parser = argparse.ArgumentParser(description="table_recognition_v2 单管线运行(输出Markdown表格)")
|
|
|
@@ -202,9 +222,11 @@ def main():
|
|
|
parser.add_argument("--no-normalize", action="store_true", help="禁用数字标准化(仅对Markdown内容生效)")
|
|
|
parser.add_argument("--test_mode", action="store_true", help="仅处理前20个文件")
|
|
|
parser.add_argument("--collect_results", type=str, help="将处理状态收集到指定CSV")
|
|
|
-
|
|
|
+ parser.add_argument("--no-adapter", action="store_true", help="禁用增强适配器")
|
|
|
+
|
|
|
args = parser.parse_args()
|
|
|
normalize_numbers = not args.no_normalize
|
|
|
+ use_enhanced_adapter = not args.no_adapter
|
|
|
|
|
|
# 复用 ppstructurev3_utils 的输入收集逻辑
|
|
|
input_files = get_input_files(args)
|
|
|
@@ -222,7 +244,8 @@ def main():
|
|
|
args.pipeline,
|
|
|
args.device,
|
|
|
args.output_dir,
|
|
|
- normalize_numbers=normalize_numbers
|
|
|
+ normalize_numbers=normalize_numbers,
|
|
|
+ use_enhanced_adapter=use_enhanced_adapter # 🎯 传递参数
|
|
|
)
|
|
|
total_time = time.time() - start
|
|
|
|