|
|
@@ -34,11 +34,16 @@ from ppstructurev3_utils import (
|
|
|
save_markdown_content
|
|
|
)
|
|
|
|
|
|
+# 🎯 新增:导入适配器
|
|
|
+from adapters import apply_table_recognition_adapter, restore_original_function
|
|
|
+
|
|
|
+
|
|
|
def process_images_unified(image_paths: List[str],
|
|
|
pipeline_name: str = "PP-StructureV3",
|
|
|
device: str = "gpu:0",
|
|
|
output_dir: str = "./output",
|
|
|
- normalize_numbers: bool = True) -> List[Dict[str, Any]]:
|
|
|
+ normalize_numbers: bool = True,
|
|
|
+ use_enhanced_adapter: bool = True) -> List[Dict[str, Any]]: # 🎯 新增参数
|
|
|
"""
|
|
|
统一的图像处理函数,支持数字标准化
|
|
|
"""
|
|
|
@@ -46,6 +51,15 @@ def process_images_unified(image_paths: List[str],
|
|
|
output_path = Path(output_dir)
|
|
|
output_path.mkdir(parents=True, exist_ok=True)
|
|
|
|
|
|
+ # 🎯 应用适配器
|
|
|
+ adapter_applied = False
|
|
|
+ if use_enhanced_adapter:
|
|
|
+ adapter_applied = apply_table_recognition_adapter()
|
|
|
+ if adapter_applied:
|
|
|
+ print("🎯 Enhanced table recognition adapter activated")
|
|
|
+ else:
|
|
|
+ print("⚠️ Failed to apply adapter, using original implementation")
|
|
|
+
|
|
|
print(f"Initializing pipeline '{pipeline_name}' on device '{device}'...")
|
|
|
|
|
|
try:
|
|
|
@@ -59,122 +73,138 @@ def process_images_unified(image_paths: List[str],
|
|
|
except Exception as e:
|
|
|
print(f"Failed to initialize pipeline: {e}", file=sys.stderr)
|
|
|
traceback.print_exc()
|
|
|
+ if adapter_applied:
|
|
|
+ restore_original_function()
|
|
|
return []
|
|
|
|
|
|
- all_results = []
|
|
|
- total_images = len(image_paths)
|
|
|
-
|
|
|
- print(f"Processing {total_images} images one by one")
|
|
|
- print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
|
|
|
-
|
|
|
- # 使用tqdm显示进度
|
|
|
- with tqdm(total=total_images, desc="Processing images", unit="img",
|
|
|
- bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
|
|
|
+ try:
|
|
|
+ all_results = []
|
|
|
+ total_images = len(image_paths)
|
|
|
+
|
|
|
+ print(f"Processing {total_images} images one by one")
|
|
|
+ print(f"🔧 数字标准化: {'启用' if normalize_numbers else '禁用'}")
|
|
|
+ print(f"🎯 增强适配器: {'启用' if adapter_applied else '禁用'}")
|
|
|
|
|
|
- # 逐个处理图像
|
|
|
- for img_path in image_paths:
|
|
|
- start_time = time.time()
|
|
|
+ # 使用tqdm显示进度
|
|
|
+ with tqdm(total=total_images, desc="Processing images", unit="img",
|
|
|
+ bar_format='{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]') as pbar:
|
|
|
|
|
|
- try:
|
|
|
- # 使用pipeline预测单个图像
|
|
|
- results = pipeline.predict(
|
|
|
- img_path,
|
|
|
- use_doc_orientation_classify=True,
|
|
|
- use_doc_unwarping=False,
|
|
|
- use_seal_recognition=True,
|
|
|
- use_table_recognition=True,
|
|
|
- use_formula_recognition=False,
|
|
|
- use_chart_recognition=True,
|
|
|
- )
|
|
|
+ # 逐个处理图像
|
|
|
+ for img_path in image_paths:
|
|
|
+ start_time = time.time()
|
|
|
|
|
|
- processing_time = time.time() - start_time
|
|
|
-
|
|
|
- # 处理结果
|
|
|
- for idx, result in enumerate(results):
|
|
|
- if idx > 0:
|
|
|
- raise ValueError("Multiple results found for a single image")
|
|
|
- try:
|
|
|
- input_path = Path(result["input_path"])
|
|
|
-
|
|
|
- # 生成输出文件名
|
|
|
- if result.get("page_index") is not None:
|
|
|
- output_filename = f"{input_path.stem}_{result['page_index']}"
|
|
|
- else:
|
|
|
- output_filename = f"{input_path.stem}"
|
|
|
-
|
|
|
- # 转换并保存标准JSON格式
|
|
|
- json_content = result.json['res']
|
|
|
- json_output_path, converted_json = convert_pruned_result_to_json(
|
|
|
- json_content,
|
|
|
- str(input_path),
|
|
|
- output_dir,
|
|
|
- output_filename,
|
|
|
- normalize_numbers=normalize_numbers
|
|
|
- )
|
|
|
+ try:
|
|
|
+ # 使用pipeline预测单个图像
|
|
|
+ results = pipeline.predict(
|
|
|
+ img_path,
|
|
|
+ use_doc_orientation_classify=False,
|
|
|
+ use_doc_unwarping=False,
|
|
|
+ use_layout_detection=True,
|
|
|
+ use_seal_recognition=True,
|
|
|
+ use_table_recognition=True,
|
|
|
+ use_formula_recognition=False,
|
|
|
+ use_chart_recognition=True,
|
|
|
+ use_ocr_results_with_table_cells=True,
|
|
|
+ use_table_orientation_classify=True,
|
|
|
+ use_wired_table_cells_trans_to_html=True,
|
|
|
+ use_wireless_table_cells_trans_to_html=True,
|
|
|
+ )
|
|
|
+
|
|
|
+ processing_time = time.time() - start_time
|
|
|
+
|
|
|
+ # 处理结果
|
|
|
+ for idx, result in enumerate(results):
|
|
|
+ if idx > 0:
|
|
|
+ raise ValueError("Multiple results found for a single image")
|
|
|
+ try:
|
|
|
+ input_path = Path(result["input_path"])
|
|
|
+
|
|
|
+ # 生成输出文件名
|
|
|
+ if result.get("page_index") is not None:
|
|
|
+ output_filename = f"{input_path.stem}_{result['page_index']}"
|
|
|
+ else:
|
|
|
+ output_filename = f"{input_path.stem}"
|
|
|
+
|
|
|
+ # 转换并保存标准JSON格式
|
|
|
+ json_content = result.json['res']
|
|
|
+ json_output_path, converted_json = convert_pruned_result_to_json(
|
|
|
+ json_content,
|
|
|
+ str(input_path),
|
|
|
+ output_dir,
|
|
|
+ output_filename,
|
|
|
+ normalize_numbers=normalize_numbers
|
|
|
+ )
|
|
|
|
|
|
- # 保存输出图像
|
|
|
- img_content = result.img
|
|
|
- saved_images = save_output_images(img_content, str(output_dir), output_filename)
|
|
|
+ # 保存输出图像
|
|
|
+ img_content = result.img
|
|
|
+ saved_images = save_output_images(img_content, str(output_dir), output_filename)
|
|
|
|
|
|
- # 保存Markdown内容
|
|
|
- markdown_content = result.markdown
|
|
|
- md_output_path = save_markdown_content(
|
|
|
- markdown_content,
|
|
|
- output_dir,
|
|
|
- output_filename,
|
|
|
- normalize_numbers=normalize_numbers,
|
|
|
- key_text='markdown_texts',
|
|
|
- key_images='markdown_images'
|
|
|
- )
|
|
|
-
|
|
|
- # 记录处理结果
|
|
|
- all_results.append({
|
|
|
- "image_path": str(input_path),
|
|
|
- "processing_time": processing_time,
|
|
|
- "success": True,
|
|
|
- "device": device,
|
|
|
- "output_json": json_output_path,
|
|
|
- "output_md": md_output_path,
|
|
|
- "is_pdf_page": "_page_" in input_path.name, # 标记是否为PDF页面
|
|
|
- "processing_info": converted_json.get('processing_info', {})
|
|
|
- })
|
|
|
+ # 保存Markdown内容
|
|
|
+ markdown_content = result.markdown
|
|
|
+ md_output_path = save_markdown_content(
|
|
|
+ markdown_content,
|
|
|
+ output_dir,
|
|
|
+ output_filename,
|
|
|
+ normalize_numbers=normalize_numbers,
|
|
|
+ key_text='markdown_texts',
|
|
|
+ key_images='markdown_images'
|
|
|
+ )
|
|
|
+
|
|
|
+ # 记录处理结果
|
|
|
+ all_results.append({
|
|
|
+ "image_path": str(input_path),
|
|
|
+ "processing_time": processing_time,
|
|
|
+ "success": True,
|
|
|
+ "device": device,
|
|
|
+ "output_json": json_output_path,
|
|
|
+ "output_md": md_output_path,
|
|
|
+ "is_pdf_page": "_page_" in input_path.name,
|
|
|
+ "processing_info": converted_json.get('processing_info', {})
|
|
|
+ })
|
|
|
|
|
|
- except Exception as e:
|
|
|
- print(f"Error saving result for {result.get('input_path', 'unknown')}: {e}", file=sys.stderr)
|
|
|
- traceback.print_exc()
|
|
|
- all_results.append({
|
|
|
- "image_path": str(img_path),
|
|
|
- "processing_time": 0,
|
|
|
- "success": False,
|
|
|
- "device": device,
|
|
|
- "error": str(e)
|
|
|
- })
|
|
|
-
|
|
|
- # 更新进度条
|
|
|
- success_count = sum(1 for r in all_results if r.get('success', False))
|
|
|
-
|
|
|
- pbar.update(1)
|
|
|
- pbar.set_postfix({
|
|
|
- 'time': f"{processing_time:.2f}s",
|
|
|
- 'success': f"{success_count}/{len(all_results)}",
|
|
|
- 'rate': f"{success_count/len(all_results)*100:.1f}%"
|
|
|
- })
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- print(f"Error processing {Path(img_path).name}: {e}", file=sys.stderr)
|
|
|
- traceback.print_exc()
|
|
|
-
|
|
|
- # 添加错误结果
|
|
|
- all_results.append({
|
|
|
- "image_path": str(img_path),
|
|
|
- "processing_time": 0,
|
|
|
- "success": False,
|
|
|
- "device": device,
|
|
|
- "error": str(e)
|
|
|
- })
|
|
|
- pbar.update(1)
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error saving result for {result.get('input_path', 'unknown')}: {e}", file=sys.stderr)
|
|
|
+ traceback.print_exc()
|
|
|
+ all_results.append({
|
|
|
+ "image_path": str(img_path),
|
|
|
+ "processing_time": 0,
|
|
|
+ "success": False,
|
|
|
+ "device": device,
|
|
|
+ "error": str(e)
|
|
|
+ })
|
|
|
+
|
|
|
+ # 更新进度条
|
|
|
+ success_count = sum(1 for r in all_results if r.get('success', False))
|
|
|
+
|
|
|
+ pbar.update(1)
|
|
|
+ pbar.set_postfix({
|
|
|
+ 'time': f"{processing_time:.2f}s",
|
|
|
+ 'success': f"{success_count}/{len(all_results)}",
|
|
|
+ 'rate': f"{success_count/len(all_results)*100:.1f}%"
|
|
|
+ })
|
|
|
+
|
|
|
+ except Exception as e:
|
|
|
+ print(f"Error processing {Path(img_path).name}: {e}", file=sys.stderr)
|
|
|
+ traceback.print_exc()
|
|
|
+
|
|
|
+ # 添加错误结果
|
|
|
+ all_results.append({
|
|
|
+ "image_path": str(img_path),
|
|
|
+ "processing_time": 0,
|
|
|
+ "success": False,
|
|
|
+ "device": device,
|
|
|
+ "error": str(e)
|
|
|
+ })
|
|
|
+ pbar.update(1)
|
|
|
+
|
|
|
+ return all_results
|
|
|
|
|
|
- return all_results
|
|
|
+ finally:
|
|
|
+ # 🎯 清理:恢复原始函数
|
|
|
+ if adapter_applied:
|
|
|
+ restore_original_function()
|
|
|
+ print("🔄 Original function restored")
|
|
|
+
|
|
|
|
|
|
def main():
|
|
|
"""主函数"""
|
|
|
@@ -194,10 +224,12 @@ def main():
|
|
|
parser.add_argument("--no-normalize", action="store_true", help="禁用数字标准化")
|
|
|
parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 files)")
|
|
|
parser.add_argument("--collect_results", type=str, help="收集处理结果到指定CSV文件")
|
|
|
+ parser.add_argument("--no-adapter", action="store_true", help="禁用增强适配器") # 🎯 新增参数
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
|
normalize_numbers = not args.no_normalize
|
|
|
+ use_enhanced_adapter = not args.no_adapter # 🎯 新增
|
|
|
|
|
|
try:
|
|
|
# 获取并预处理输入文件
|
|
|
@@ -221,7 +253,8 @@ def main():
|
|
|
args.pipeline,
|
|
|
args.device,
|
|
|
args.output_dir,
|
|
|
- normalize_numbers=normalize_numbers
|
|
|
+ normalize_numbers=normalize_numbers,
|
|
|
+ use_enhanced_adapter=use_enhanced_adapter # 🎯 传递参数
|
|
|
)
|
|
|
total_time = time.time() - start_time
|
|
|
|