|
|
@@ -33,7 +33,9 @@ from utils import (
|
|
|
get_image_files_from_dir,
|
|
|
get_image_files_from_list,
|
|
|
get_image_files_from_csv,
|
|
|
- collect_pid_files
|
|
|
+ collect_pid_files,
|
|
|
+ normalize_markdown_table,
|
|
|
+ normalize_json_table
|
|
|
)
|
|
|
|
|
|
def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
|
|
|
@@ -159,7 +161,8 @@ class DotsOCRProcessor:
|
|
|
prompt_mode: str = "prompt_layout_all_en",
|
|
|
dpi: int = 200,
|
|
|
min_pixels: int = MIN_PIXELS,
|
|
|
- max_pixels: int = MAX_PIXELS):
|
|
|
+ max_pixels: int = MAX_PIXELS,
|
|
|
+ normalize_numbers: bool = False):
|
|
|
"""
|
|
|
初始化处理器
|
|
|
|
|
|
@@ -179,7 +182,8 @@ class DotsOCRProcessor:
|
|
|
self.dpi = dpi
|
|
|
self.min_pixels = min_pixels
|
|
|
self.max_pixels = max_pixels
|
|
|
-
|
|
|
+ self.normalize_numbers = normalize_numbers
|
|
|
+
|
|
|
# 初始化解析器
|
|
|
self.parser = DotsOCRParser(
|
|
|
ip=ip,
|
|
|
@@ -232,9 +236,27 @@ class DotsOCRProcessor:
|
|
|
else:
|
|
|
md_content = "# 解析失败\n\n未能提取到有效的文档内容。"
|
|
|
|
|
|
+ # 如果启用数字标准化,处理 Markdown 内容
|
|
|
+ original_text = md_content
|
|
|
+ if self.normalize_numbers:
|
|
|
+ # generated_text = normalize_financial_numbers(generated_text)
|
|
|
+ # 只对Markdown表格进行数字标准化
|
|
|
+ generated_text = normalize_markdown_table(md_content)
|
|
|
+
|
|
|
+ # 统计标准化的变化
|
|
|
+ changes_count = len([1 for o, n in zip(original_text, generated_text) if o != n])
|
|
|
+ if changes_count > 0:
|
|
|
+ saved_files['md_normalized'] = f"✅ 已标准化 {changes_count} 个字符(全角→半角)"
|
|
|
+ else:
|
|
|
+ saved_files['md_normalized'] = "ℹ️ 无需标准化(已是标准格式)"
|
|
|
with open(output_md_path, 'w', encoding='utf-8') as f:
|
|
|
- f.write(md_content)
|
|
|
+ f.write(generated_text)
|
|
|
saved_files['md'] = output_md_path
|
|
|
+ # 如果启用了标准化,也保存原始版本用于对比
|
|
|
+ if self.normalize_numbers and original_text != generated_text:
|
|
|
+ original_markdown_path = Path(output_dir) / f"{Path(image_name).stem}_original.md"
|
|
|
+ with open(original_markdown_path, 'w', encoding='utf-8') as f:
|
|
|
+ f.write(original_text)
|
|
|
|
|
|
# 2. 保存 JSON 文件
|
|
|
output_json_path = os.path.join(output_dir, f"{image_name}.json")
|
|
|
@@ -242,16 +264,29 @@ class DotsOCRProcessor:
|
|
|
|
|
|
if 'layout_info_path' in result and os.path.exists(result['layout_info_path']):
|
|
|
with open(result['layout_info_path'], 'r', encoding='utf-8') as f:
|
|
|
- json_data = json.load(f)
|
|
|
+ json_content = f.read()
|
|
|
else:
|
|
|
- json_data = {
|
|
|
- "error": "未能提取到有效的布局信息",
|
|
|
- "cells": []
|
|
|
- }
|
|
|
+ json_content = f'{{"error": "未能提取到有效的布局信息"}}'
|
|
|
|
|
|
+ # 对json中的表格内容进行数字标准化,
|
|
|
+ original_json_text = json_content
|
|
|
+ if self.normalize_numbers:
|
|
|
+ json_content = normalize_json_table(json_content)
|
|
|
+
|
|
|
+ # 统计标准化的变化
|
|
|
+ changes_count = len([1 for o, n in zip(original_json_text, json_content) if o != n])
|
|
|
+ if changes_count > 0:
|
|
|
+ saved_files['json_normalized'] = f"✅ 已标准化 {changes_count} 个字符(全角→半角)"
|
|
|
+ else:
|
|
|
+ saved_files['json_normalized'] = "ℹ️ 无需标准化(已是标准格式)"
|
|
|
with open(output_json_path, 'w', encoding='utf-8') as f:
|
|
|
- json.dump(json_data, f, ensure_ascii=False, indent=2)
|
|
|
+ f.write(json_content)
|
|
|
saved_files['json'] = output_json_path
|
|
|
+ # 如果启用了标准化,也保存原始版本用于对比
|
|
|
+ if self.normalize_numbers and original_json_text != json_content:
|
|
|
+ original_json_path = Path(output_dir) / f"{Path(image_name).stem}_original.json"
|
|
|
+ with open(original_json_path, 'w', encoding='utf-8') as f:
|
|
|
+ f.write(original_json_text)
|
|
|
|
|
|
# 3. 保存带标注的布局图片
|
|
|
output_layout_image_path = os.path.join(output_dir, f"{image_name}_layout.jpg")
|
|
|
@@ -553,6 +588,7 @@ def main():
|
|
|
parser.add_argument("--min_pixels", type=int, default=MIN_PIXELS, help="Minimum pixels")
|
|
|
parser.add_argument("--max_pixels", type=int, default=MAX_PIXELS, help="Maximum pixels")
|
|
|
parser.add_argument("--dpi", type=int, default=200, help="PDF processing DPI")
|
|
|
+ parser.add_argument('--no-normalize', action='store_true', help='禁用数字标准化')
|
|
|
|
|
|
# 处理参数
|
|
|
parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
|
|
|
@@ -597,7 +633,8 @@ def main():
|
|
|
prompt_mode=args.prompt_mode,
|
|
|
dpi=args.dpi,
|
|
|
min_pixels=args.min_pixels,
|
|
|
- max_pixels=args.max_pixels
|
|
|
+ max_pixels=args.max_pixels,
|
|
|
+ normalize_numbers=not args.no_normalize
|
|
|
)
|
|
|
|
|
|
# 开始处理
|
|
|
@@ -709,17 +746,20 @@ if __name__ == "__main__":
|
|
|
|
|
|
# 默认配置
|
|
|
default_config = {
|
|
|
- "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
|
|
|
- "output_dir": "./OmniDocBench_DotsOCR_Results",
|
|
|
- # "ip": "10.192.72.11",
|
|
|
- "ip": "127.0.0.1",
|
|
|
+ "input_file": "./sample_data/2023年度报告母公司_page_003.png",
|
|
|
+ "output_dir": "./sample_data",
|
|
|
+ "collect_results": "./sample_data/processed_files.csv",
|
|
|
+ # "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
|
|
|
+ # "output_dir": "./OmniDocBench_DotsOCR_Results",
|
|
|
+ # "collect_results": "./OmniDocBench_DotsOCR_Results/processed_files.csv",
|
|
|
+ "ip": "10.192.72.11",
|
|
|
+ # "ip": "127.0.0.1",
|
|
|
"port": "8101",
|
|
|
"model_name": "DotsOCR",
|
|
|
"prompt_mode": "prompt_layout_all_en",
|
|
|
"batch_size": "1",
|
|
|
"max_workers": "3",
|
|
|
"dpi": "200",
|
|
|
- "collect_results": "./OmniDocBench_DotsOCR_Results/processed_files.csv",
|
|
|
}
|
|
|
|
|
|
# 如果需要处理失败的文件,可以使用这个配置
|