@@ -28,7 +28,9 @@ from utils import (
     get_image_files_from_list,
     get_image_files_from_csv,
     collect_pid_files,
-    load_images_from_pdf
+    load_images_from_pdf,
+    normalize_financial_numbers,
+    normalize_markdown_table
 )


 def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
@@ -144,12 +146,138 @@ def get_input_files(args) -> List[str]:

     return input_files

+def normalize_pipeline_result(result: Dict[str, Any], normalize_numbers: bool = True) -> Dict[str, Any]:
+    """
+    Apply number normalization to a pipeline result.
+
+    Args:
+        result: the result object returned by the pipeline
+        normalize_numbers: whether number normalization is enabled
+
+    Returns:
+        A dict describing the normalization that was applied.
+    """
+    if not normalize_numbers:
+        return {
+            "normalize_numbers": False,
+            "changes_applied": False,
+            "character_changes_count": 0,
+            "parsing_res_tables_count": 0,
+            "table_res_list_count": 0,
+            "table_consistency_fixed": False
+        }
+
+    changes_count = 0
+    original_data = {}
+
+    # Back up the original data (the backup is not used further in this function)
+    if 'parsing_res_list' in result:
+        original_data['parsing_res_list'] = [item.copy() if hasattr(item, 'copy') else dict(item) for item in result['parsing_res_list']]
+
+    if 'table_res_list' in result:
+        original_data['table_res_list'] = [item.copy() if hasattr(item, 'copy') else dict(item) for item in result['table_res_list']]
+
+    try:
+        # 1. Normalize the text content in parsing_res_list
+        if 'parsing_res_list' in result:
+            for item in result['parsing_res_list']:
+                if 'block_content' in item and item['block_content']:
+                    original_content = str(item['block_content'])
+                    normalized_content = original_content
+
+                    # Pick the normalization method based on block_label
+                    if 'block_label' in item and item['block_label'] == 'table':
+                        normalized_content = normalize_markdown_table(original_content)
+
+                    if original_content != normalized_content:
+                        item['block_content'] = normalized_content
+                        # zip truncates to the shorter string, so insertions/deletions are under-counted
+                        changes_count += sum(1 for o, n in zip(original_content, normalized_content) if o != n)
+
+        # 2. Normalize the HTML tables in table_res_list
+        if 'table_res_list' in result:
+            for table_item in result['table_res_list']:
+                if 'pred_html' in table_item and table_item['pred_html']:
+                    original_html = str(table_item['pred_html'])
+                    normalized_html = normalize_markdown_table(original_html)
+
+                    if original_html != normalized_html:
+                        table_item['pred_html'] = normalized_html
+                        changes_count += sum(1 for o, n in zip(original_html, normalized_html) if o != n)
+
+        # Count the tables on both sides
+        parsing_res_tables_count = 0
+        table_res_list_count = 0
+        if 'parsing_res_list' in result:
+            parsing_res_tables_count = len([item for item in result['parsing_res_list']
+                                            if 'block_label' in item and item['block_label'] == 'table'])
+        if 'table_res_list' in result:
+            table_res_list_count = len(result['table_res_list'])
+
+        # Check table consistency (detection only; an actual repair would need more elaborate rules)
+        table_consistency_fixed = False
+        if parsing_res_tables_count != table_res_list_count:
+            warnings.warn(f"⚠️ Warning: Table count mismatch - parsing_res_list has {parsing_res_tables_count} tables, "
+                          f"but table_res_list has {table_res_list_count} tables.")
+            table_consistency_fixed = True
+            # Actual repair logic (adding or dropping table items) could go here,
+            # but without concrete rules we only record the mismatch and warn.
+
+        return {
+            "normalize_numbers": normalize_numbers,
+            "changes_applied": changes_count > 0,
+            "character_changes_count": changes_count,
+            "parsing_res_tables_count": parsing_res_tables_count,
+            "table_res_list_count": table_res_list_count,
+            "table_consistency_fixed": table_consistency_fixed
+        }
+
+    except Exception as e:
+        print(f"⚠️ Warning: Error during normalization: {e}")
+        return {
+            "normalize_numbers": normalize_numbers,
+            "changes_applied": False,
+            "character_changes_count": 0,
+            "normalization_error": str(e)
+        }
+
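For reference, a minimal smoke test of normalize_pipeline_result on a plain dict. The sample blocks are invented, and the exact rewriting done by normalize_markdown_table (defined in utils, not shown in this diff) is an assumption; only the bookkeeping fields are asserted.

    # Hypothetical sample; block contents and normalizer behavior are assumptions.
    sample = {
        "parsing_res_list": [
            {"block_label": "table", "block_content": "| Q1 | Q2 |\n| １，０ | ２，５ |"},
            {"block_label": "text", "block_content": "Revenue grew in 2023."},
        ],
        "table_res_list": [
            {"pred_html": "<table><tr><td>１，０</td></tr></table>"},
        ],
    }

    info = normalize_pipeline_result(sample, normalize_numbers=True)
    print(info["changes_applied"], info["character_changes_count"])
    # One table block on each side, so no mismatch warning is expected:
    assert info["parsing_res_tables_count"] == info["table_res_list_count"] == 1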
+def save_normalized_files(result, output_dir: str, filename: str,
+                          processing_info: Dict[str, Any], normalize_numbers: bool = True):
+    """
+    Save the JSON and Markdown outputs and, when normalization changed
+    anything, embed the processing info in the JSON file.
+    """
+    output_path = Path(output_dir)
+
+    # Save the normalized version
+    json_output_path = str(output_path / f"{filename}.json")
+    md_output_path = str(output_path / f"{filename}.md")
+
+    result.save_to_json(json_output_path)
+    result.save_to_markdown(md_output_path)
+
+    # If normalization changed anything, record the processing info in the JSON
+    if normalize_numbers and processing_info.get('changes_applied', False):
+        try:
+            # Read the generated JSON file back and attach the processing info
+            with open(json_output_path, 'r', encoding='utf-8') as f:
+                json_data = json.load(f)
+
+            json_data['processing_info'] = processing_info
+
+            # Re-save the JSON with the processing info included
+            with open(json_output_path, 'w', encoding='utf-8') as f:
+                json.dump(json_data, f, ensure_ascii=False, indent=2)
+
+        except Exception as e:
+            print(f"⚠️ Warning: Could not add processing info to JSON: {e}")
+
+    return json_output_path, md_output_path
+
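normalize_financial_numbers and normalize_markdown_table live in utils and are not part of this diff. As a rough mental model only, and not the real implementation, such normalizers typically map full-width digits and punctuation to their ASCII equivalents:

    import re

    # Assumed behavior; the actual utils implementations are not shown in this patch.
    _FULLWIDTH = str.maketrans("０１２３４５６７８９．，％", "0123456789.,%")

    def normalize_financial_numbers(text: str) -> str:
        # Map full-width digits/punctuation to ASCII, e.g. '１，２３４' -> '1,234'.
        return text.translate(_FULLWIDTH)

    def normalize_markdown_table(table_text: str) -> str:
        # Apply the number normalization to every numeric run in a table body.
        return re.sub(r"[０-９．，％]+",
                      lambda m: normalize_financial_numbers(m.group()), table_text)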

 def process_images_unified(image_paths: List[str],
                            pipeline_name: str = "PP-StructureV3",
                            device: str = "gpu:0",
-                           output_dir: str = "./output") -> List[Dict[str, Any]]:
+                           output_dir: str = "./output",
+                           normalize_numbers: bool = True) -> List[Dict[str, Any]]:
     """
-    Unified image-processing function (adapted from ppstructurev3_single_process.py)
+    Unified image-processing function with number-normalization support
     """
     # Create the output directory
     output_path = Path(output_dir)
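A usage sketch of the extended signature; the image paths are placeholders:

    # Hypothetical call with normalization disabled for this run.
    results = process_images_unified(
        ["./samples/page_001.png", "./samples/page_002.png"],
        pipeline_name="PP-StructureV3",
        device="cpu",
        output_dir="./output",
        normalize_numbers=False,
    )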
@@ -174,6 +302,7 @@ def process_images_unified(image_paths: List[str],
     total_images = len(image_paths)

     print(f"Processing {total_images} images one by one")
+    print(f"🔧 Number normalization: {'enabled' if normalize_numbers else 'disabled'}")

     # Show progress with tqdm
     with tqdm(total=total_images, desc="Processing images", unit="img",
@@ -208,12 +337,17 @@ def process_images_unified(image_paths: List[str],
             else:
                 output_filename = f"{input_path.stem}"

-            # Save the JSON and Markdown files
-            json_output_path = str(Path(output_dir, f"{output_filename}.json"))
-            md_output_path = str(Path(output_dir, f"{output_filename}.md"))
+            # Apply number normalization
+            processing_info = normalize_pipeline_result(result, normalize_numbers)
+
+            # Save the JSON and Markdown files (with normalization applied)
+            json_output_path, md_output_path = save_normalized_files(
+                result, output_dir, output_filename, processing_info, normalize_numbers
+            )

-            result.save_to_json(json_output_path)
-            result.save_to_markdown(md_output_path)
+            # Report table-count mismatches (detection only; see normalize_pipeline_result)
+            if processing_info.get('table_consistency_fixed', False):
+                print(f"🔧 Table count mismatch detected in {input_path.name}")

             # Record the processing result
             all_results.append({
@@ -223,7 +357,8 @@ def process_images_unified(image_paths: List[str],
                 "device": device,
                 "output_json": json_output_path,
                 "output_md": md_output_path,
-                "is_pdf_page": "_page_" in input_path.name  # mark PDF-derived pages
+                "is_pdf_page": "_page_" in input_path.name,  # mark PDF-derived pages
+                "processing_info": processing_info
             })

         except Exception as e:

@@ -278,11 +413,14 @@ def main():
     parser.add_argument("--pipeline", type=str, default="PP-StructureV3", help="Pipeline name")
     parser.add_argument("--device", type=str, default="gpu:0", help="Device string (e.g., 'gpu:0', 'cpu')")
     parser.add_argument("--pdf_dpi", type=int, default=200, help="DPI for PDF to image conversion")
+    parser.add_argument("--no-normalize", action="store_true", help="Disable number normalization")
     parser.add_argument("--test_mode", action="store_true", help="Test mode (process only 20 files)")
     parser.add_argument("--collect_results", type=str, help="Collect processing results into the given CSV file")

     args = parser.parse_args()

+    normalize_numbers = not args.no_normalize
+
     try:
         # Fetch and preprocess the input files
         print("🔄 Preprocessing input files...")
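The inverted store_true flag works as written. On Python 3.9+ the same effect is available without the manual negation via argparse.BooleanOptionalAction, which generates the --normalize/--no-normalize pair automatically. A sketch of that alternative, not what this patch does:

    import argparse

    parser = argparse.ArgumentParser()
    # Generates both --normalize and --no-normalize, enabled by default (Python 3.9+).
    parser.add_argument("--normalize", action=argparse.BooleanOptionalAction, default=True,
                        help="Enable/disable number normalization")

    args = parser.parse_args(["--no-normalize"])
    assert args.normalize is False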
@@ -304,7 +442,8 @@ def main():
             input_files,
             args.pipeline,
             args.device,
-            args.output_dir
+            args.output_dir,
+            normalize_numbers=normalize_numbers
         )
         total_time = time.time() - start_time

@@ -312,6 +451,7 @@ def main():
         success_count = sum(1 for r in results if r.get('success', False))
         error_count = len(results) - success_count
         pdf_page_count = sum(1 for r in results if r.get('is_pdf_page', False))
+        total_changes = sum(r.get('processing_info', {}).get('character_changes_count', 0) for r in results if 'processing_info' in r)

         print(f"\n" + "="*60)
         print(f"✅ Processing completed!")
@@ -323,6 +463,8 @@ def main():
         print(f"   Failed: {error_count}")
         if len(input_files) > 0:
             print(f"   Success rate: {success_count / len(input_files) * 100:.2f}%")
+        if normalize_numbers:
+            print(f"   Total characters normalized: {total_changes}")
         print(f"⏱️ Performance:")
         print(f"   Total time: {total_time:.2f} seconds")
         if total_time > 0:
@@ -343,6 +485,8 @@ def main():
             "device": args.device,
             "pipeline": args.pipeline,
             "pdf_dpi": args.pdf_dpi,
+            "normalize_numbers": normalize_numbers,
+            "total_character_changes": total_changes,
             "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
         }

@@ -359,7 +503,7 @@ def main():

         print(f"💾 Results saved to: {output_file}")

-        # If no collect_results path is given, use a default filename alongside output_dir
+        # If no collect_results path is given, use a default filename alongside output_dir
         if not args.collect_results:
             output_file_processed = Path(args.output_dir) / f"processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv"
         else:
@@ -387,7 +531,7 @@ if __name__ == "__main__":
         # If no command-line arguments are given, run with the default configuration
         print("ℹ️ No command line arguments provided. Running with default configuration...")

-        # Default configuration (batch_size removed)
+        # Default configuration
         default_config = {
             "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
             "output_dir": "./OmniDocBench_PPStructureV3_Results",
@@ -396,18 +540,14 @@ if __name__ == "__main__":
             "collect_results": f"./OmniDocBench_PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
         }

-        # default_config = {
-        #     "input_csv": "./OmniDocBench_PPStructureV3_Results/processed_files.csv",
-        #     "output_dir": "./OmniDocBench_PPStructureV3_Results",
-        #     "pipeline": "./my_config/PP-StructureV3.yaml",
-        #     "device": "gpu:0",
-        #     "collect_results": f"./OmniDocBench_PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
-        # }
         # Build the argument list
         sys.argv = [sys.argv[0]]
         for key, value in default_config.items():
             sys.argv.extend([f"--{key}", str(value)])

+        # Uncomment to disable number normalization:
+        # sys.argv.append("--no-normalize")
+
         # Test mode
         # sys.argv.append("--test_mode")
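One caveat on the argv-rebuilding loop above: it emits "--key value" pairs, so it only suits options that take a value. store_true flags such as --no-normalize or --test_mode must be appended bare, as the commented lines show:

    import sys

    # Value-taking options become "--key value" pairs:
    sys.argv.extend(["--output_dir", "./OmniDocBench_PPStructureV3_Results"])
    # store_true flags take no value and are appended bare:
    sys.argv.append("--no-normalize")
    # sys.argv.extend(["--no-normalize", "True"])  # wrong: "True" becomes a stray
    # positional token that argparse rejects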