|
|
@@ -24,252 +24,15 @@ from dotenv import load_dotenv
|
|
|
load_dotenv(override=True)
|
|
|
|
|
|
from utils import (
|
|
|
- get_image_files_from_dir,
|
|
|
- get_image_files_from_list,
|
|
|
- get_image_files_from_csv,
|
|
|
collect_pid_files,
|
|
|
- load_images_from_pdf,
|
|
|
- normalize_financial_numbers,
|
|
|
- normalize_markdown_table
|
|
|
)
|
|
|
|
|
|
-def convert_pdf_to_images(pdf_file: str, output_dir: str | None = None, dpi: int = 200) -> List[str]:
|
|
|
- """
|
|
|
- 将PDF转换为图像文件
|
|
|
-
|
|
|
- Args:
|
|
|
- pdf_file: PDF文件路径
|
|
|
- output_dir: 输出目录
|
|
|
- dpi: 图像分辨率
|
|
|
-
|
|
|
- Returns:
|
|
|
- 生成的图像文件路径列表
|
|
|
- """
|
|
|
- pdf_path = Path(pdf_file)
|
|
|
- if not pdf_path.exists() or pdf_path.suffix.lower() != '.pdf':
|
|
|
- print(f"❌ Invalid PDF file: {pdf_path}")
|
|
|
- return []
|
|
|
-
|
|
|
- # 如果没有指定输出目录,使用PDF同名目录
|
|
|
- if output_dir is None:
|
|
|
- output_path = pdf_path.parent / f"{pdf_path.stem}"
|
|
|
- else:
|
|
|
- output_path = Path(output_dir) / f"{pdf_path.stem}"
|
|
|
- output_path = output_path.resolve()
|
|
|
- output_path.mkdir(parents=True, exist_ok=True)
|
|
|
-
|
|
|
- try:
|
|
|
- # 使用doc_utils中的函数加载PDF图像
|
|
|
- images = load_images_from_pdf(str(pdf_path), dpi=dpi)
|
|
|
-
|
|
|
- image_paths = []
|
|
|
- for i, image in enumerate(images):
|
|
|
- # 生成图像文件名
|
|
|
- image_filename = f"{pdf_path.stem}_page_{i+1:03d}.png"
|
|
|
- image_path = output_path / image_filename
|
|
|
-
|
|
|
- # 保存图像
|
|
|
- image.save(str(image_path))
|
|
|
- image_paths.append(str(image_path))
|
|
|
-
|
|
|
- print(f"✅ Converted {len(images)} pages from {pdf_path.name} to images")
|
|
|
- return image_paths
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- print(f"❌ Error converting PDF {pdf_path}: {e}")
|
|
|
- traceback.print_exc()
|
|
|
- return []
|
|
|
-
|
|
|
-def get_input_files(args) -> List[str]:
|
|
|
- """
|
|
|
- 获取输入文件列表,统一处理PDF和图像文件
|
|
|
-
|
|
|
- Args:
|
|
|
- args: 命令行参数
|
|
|
-
|
|
|
- Returns:
|
|
|
- 处理后的图像文件路径列表
|
|
|
- """
|
|
|
- input_files = []
|
|
|
-
|
|
|
- # 获取原始输入文件
|
|
|
- if args.input_csv:
|
|
|
- raw_files = get_image_files_from_csv(args.input_csv, "fail")
|
|
|
- elif args.input_file_list:
|
|
|
- raw_files = get_image_files_from_list(args.input_file_list)
|
|
|
- elif args.input_file:
|
|
|
- raw_files = [Path(args.input_file).resolve()]
|
|
|
- else:
|
|
|
- input_dir = Path(args.input_dir).resolve()
|
|
|
- if not input_dir.exists():
|
|
|
- print(f"❌ Input directory does not exist: {input_dir}")
|
|
|
- return []
|
|
|
-
|
|
|
- # 获取所有支持的文件(图像和PDF)
|
|
|
- image_extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif']
|
|
|
- pdf_extensions = ['.pdf']
|
|
|
-
|
|
|
- raw_files = []
|
|
|
- for ext in image_extensions + pdf_extensions:
|
|
|
- raw_files.extend(list(input_dir.glob(f"*{ext}")))
|
|
|
- raw_files.extend(list(input_dir.glob(f"*{ext.upper()}")))
|
|
|
-
|
|
|
- raw_files = [str(f) for f in raw_files]
|
|
|
-
|
|
|
- # 分别处理PDF和图像文件
|
|
|
- pdf_count = 0
|
|
|
- image_count = 0
|
|
|
-
|
|
|
- for file_path in raw_files:
|
|
|
- file_path = Path(file_path)
|
|
|
-
|
|
|
- if file_path.suffix.lower() == '.pdf':
|
|
|
- # 转换PDF为图像
|
|
|
- print(f"📄 Processing PDF: {file_path.name}")
|
|
|
- pdf_images = convert_pdf_to_images(
|
|
|
- str(file_path),
|
|
|
- args.output_dir,
|
|
|
- dpi=args.pdf_dpi
|
|
|
- )
|
|
|
- input_files.extend(pdf_images)
|
|
|
- pdf_count += 1
|
|
|
- else:
|
|
|
- # 直接添加图像文件
|
|
|
- if file_path.exists():
|
|
|
- input_files.append(str(file_path))
|
|
|
- image_count += 1
|
|
|
-
|
|
|
- print(f"📊 Input summary:")
|
|
|
- print(f" PDF files processed: {pdf_count}")
|
|
|
- print(f" Image files found: {image_count}")
|
|
|
- print(f" Total image files to process: {len(input_files)}")
|
|
|
-
|
|
|
- return input_files
|
|
|
-
|
|
|
-def normalize_pipeline_result(result: Dict[str, Any], normalize_numbers: bool = True) -> Dict[str, Any]:
|
|
|
- """
|
|
|
- 对pipeline结果进行数字标准化处理
|
|
|
-
|
|
|
- Args:
|
|
|
- result: pipeline返回的结果对象
|
|
|
- normalize_numbers: 是否启用数字标准化
|
|
|
-
|
|
|
- Returns:
|
|
|
- 包含标准化信息的字典
|
|
|
- """
|
|
|
- if not normalize_numbers:
|
|
|
- return {
|
|
|
- "normalize_numbers": False,
|
|
|
- "changes_applied": False,
|
|
|
- "character_changes_count": 0,
|
|
|
- "parsing_res_tables_count": 0,
|
|
|
- "table_res_list_count": 0,
|
|
|
- "table_consistency_fixed": False
|
|
|
- }
|
|
|
-
|
|
|
- changes_count = 0
|
|
|
- original_data = {}
|
|
|
-
|
|
|
- # 获取原始数据进行备份
|
|
|
- if 'parsing_res_list' in result:
|
|
|
- original_data['parsing_res_list'] = [item.copy() if hasattr(item, 'copy') else dict(item) for item in result['parsing_res_list']]
|
|
|
-
|
|
|
- if 'table_res_list' in result:
|
|
|
- original_data['table_res_list'] = [item.copy() if hasattr(item, 'copy') else dict(item) for item in result['table_res_list']]
|
|
|
-
|
|
|
- try:
|
|
|
- # 1. 标准化 parsing_res_list 中的文本内容
|
|
|
- if 'parsing_res_list' in result:
|
|
|
- for item in result['parsing_res_list']:
|
|
|
- if 'block_content' in item and item['block_content']:
|
|
|
- original_content = str(item['block_content'])
|
|
|
- normalized_content = original_content
|
|
|
-
|
|
|
- # 根据block_label类型选择标准化方法
|
|
|
- if 'block_label' in item and item['block_label'] == 'table':
|
|
|
- normalized_content = normalize_markdown_table(original_content)
|
|
|
-
|
|
|
- if original_content != normalized_content:
|
|
|
- item['block_content'] = normalized_content
|
|
|
- changes_count += len([1 for o, n in zip(original_content, normalized_content) if o != n])
|
|
|
-
|
|
|
- # 2. 标准化 table_res_list 中的HTML表格
|
|
|
- if 'table_res_list' in result:
|
|
|
- for table_item in result['table_res_list']:
|
|
|
- if 'pred_html' in table_item and table_item['pred_html']:
|
|
|
- original_html = str(table_item['pred_html'])
|
|
|
- normalized_html = normalize_markdown_table(original_html)
|
|
|
-
|
|
|
- if original_html != normalized_html:
|
|
|
- table_item['pred_html'] = normalized_html
|
|
|
- changes_count += len([1 for o, n in zip(original_html, normalized_html) if o != n])
|
|
|
-
|
|
|
- # 统计表格数量
|
|
|
- parsing_res_tables_count = 0
|
|
|
- table_res_list_count = 0
|
|
|
- if 'parsing_res_list' in result:
|
|
|
- parsing_res_tables_count = len([item for item in result['parsing_res_list']
|
|
|
- if 'block_label' in item and item['block_label'] == 'table'])
|
|
|
- if 'table_res_list' in result:
|
|
|
- table_res_list_count = len(result['table_res_list'])
|
|
|
-
|
|
|
- # 检查是否需要修复表格一致性(这里只做统计,实际修复可能需要更复杂的逻辑)
|
|
|
- table_consistency_fixed = False
|
|
|
- if parsing_res_tables_count != table_res_list_count:
|
|
|
- warnings.warn(f"⚠️ Warning: Table count mismatch - parsing_res_list has {parsing_res_tables_count} tables, "
|
|
|
- f"but table_res_list has {table_res_list_count} tables.")
|
|
|
- table_consistency_fixed = True
|
|
|
- # 这里可以添加实际的修复逻辑,例如根据需要添加或删除表格项
|
|
|
- # 但由于缺乏具体规则,暂时只做统计和警告
|
|
|
- return {
|
|
|
- "normalize_numbers": normalize_numbers,
|
|
|
- "changes_applied": changes_count > 0,
|
|
|
- "character_changes_count": changes_count,
|
|
|
- "parsing_res_tables_count": parsing_res_tables_count,
|
|
|
- "table_res_list_count": table_res_list_count,
|
|
|
- "table_consistency_fixed": table_consistency_fixed
|
|
|
- }
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- print(f"⚠️ Warning: Error during normalization: {e}")
|
|
|
- return {
|
|
|
- "normalize_numbers": normalize_numbers,
|
|
|
- "changes_applied": False,
|
|
|
- "character_changes_count": 0,
|
|
|
- "normalization_error": str(e)
|
|
|
- }
|
|
|
-
|
|
|
-def save_normalized_files(result, output_dir: str, filename: str,
|
|
|
- processing_info: Dict[str, Any], normalize_numbers: bool = True):
|
|
|
- """
|
|
|
- 保存标准化处理后的文件,包括原始版本
|
|
|
- """
|
|
|
- output_path = Path(output_dir)
|
|
|
-
|
|
|
- # 保存标准化后的版本
|
|
|
- json_output_path = str(output_path / f"{filename}.json")
|
|
|
- md_output_path = str(output_path / f"{filename}.md")
|
|
|
-
|
|
|
- result.save_to_json(json_output_path)
|
|
|
- result.save_to_markdown(md_output_path)
|
|
|
-
|
|
|
- # 如果有标准化变化,在JSON中添加处理信息
|
|
|
- if normalize_numbers and processing_info.get('changes_applied', False):
|
|
|
- try:
|
|
|
- # 读取生成的JSON文件,添加处理信息
|
|
|
- with open(json_output_path, 'r', encoding='utf-8') as f:
|
|
|
- json_data = json.load(f)
|
|
|
-
|
|
|
- json_data['processing_info'] = processing_info
|
|
|
-
|
|
|
- # 重新保存包含处理信息的JSON
|
|
|
- with open(json_output_path, 'w', encoding='utf-8') as f:
|
|
|
- json.dump(json_data, f, ensure_ascii=False, indent=2)
|
|
|
-
|
|
|
- except Exception as e:
|
|
|
- print(f"⚠️ Warning: Could not add processing info to JSON: {e}")
|
|
|
-
|
|
|
- return json_output_path, md_output_path
|
|
|
+from ppstructurev3_utils import (
|
|
|
+ get_input_files,
|
|
|
+ convert_pruned_result_to_json,
|
|
|
+ save_output_images,
|
|
|
+ save_markdown_content
|
|
|
+)
|
|
|
|
|
|
def process_images_unified(image_paths: List[str],
|
|
|
pipeline_name: str = "PP-StructureV3",
|
|
|
@@ -327,7 +90,9 @@ def process_images_unified(image_paths: List[str],
|
|
|
processing_time = time.time() - start_time
|
|
|
|
|
|
# 处理结果
|
|
|
- for result in results:
|
|
|
+ for idx, result in enumerate(results):
|
|
|
+ if idx > 0:
|
|
|
+ raise ValueError("Multiple results found for a single image")
|
|
|
try:
|
|
|
input_path = Path(result["input_path"])
|
|
|
|
|
|
@@ -337,17 +102,30 @@ def process_images_unified(image_paths: List[str],
|
|
|
else:
|
|
|
output_filename = f"{input_path.stem}"
|
|
|
|
|
|
- # 应用数字标准化
|
|
|
- processing_info = normalize_pipeline_result(result, normalize_numbers)
|
|
|
-
|
|
|
- # 保存JSON和Markdown文件(包含标准化处理)
|
|
|
- json_output_path, md_output_path = save_normalized_files(
|
|
|
- result, output_dir, output_filename, processing_info, normalize_numbers
|
|
|
+ # 转换并保存标准JSON格式
|
|
|
+ json_content = result.json['res']
|
|
|
+ json_output_path, converted_json = convert_pruned_result_to_json(
|
|
|
+ json_content,
|
|
|
+ str(input_path),
|
|
|
+ output_dir,
|
|
|
+ output_filename,
|
|
|
+ normalize_numbers=normalize_numbers
|
|
|
+ )
|
|
|
+
|
|
|
+ # 保存输出图像
|
|
|
+ img_content = result.img
|
|
|
+ saved_images = save_output_images(img_content, str(output_dir), output_filename)
|
|
|
+
|
|
|
+ # 保存Markdown内容
|
|
|
+ markdown_content = result.markdown
|
|
|
+ md_output_path = save_markdown_content(
|
|
|
+ markdown_content,
|
|
|
+ output_dir,
|
|
|
+ output_filename,
|
|
|
+ normalize_numbers=normalize_numbers,
|
|
|
+ key_text='markdown_texts',
|
|
|
+ key_images='markdown_images'
|
|
|
)
|
|
|
-
|
|
|
- # 如果有表格一致性修复,输出提示
|
|
|
- if processing_info.get('table_consistency_fixed', False):
|
|
|
- print(f"🔧 修复了表格一致性问题:{input_path.name}")
|
|
|
|
|
|
# 记录处理结果
|
|
|
all_results.append({
|
|
|
@@ -358,9 +136,9 @@ def process_images_unified(image_paths: List[str],
|
|
|
"output_json": json_output_path,
|
|
|
"output_md": md_output_path,
|
|
|
"is_pdf_page": "_page_" in input_path.name, # 标记是否为PDF页面
|
|
|
- "processing_info": processing_info
|
|
|
- })
|
|
|
-
|
|
|
+ "processing_info": converted_json.get('processing_info', {})
|
|
|
+ })
|
|
|
+
|
|
|
except Exception as e:
|
|
|
print(f"Error saving result for {result.get('input_path', 'unknown')}: {e}", file=sys.stderr)
|
|
|
traceback.print_exc()
|
|
|
@@ -533,11 +311,15 @@ if __name__ == "__main__":
|
|
|
|
|
|
# 默认配置
|
|
|
default_config = {
|
|
|
- "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
|
|
|
- "output_dir": "./OmniDocBench_PPStructureV3_Results",
|
|
|
+ # "input_file": "/home/ubuntu/zhch/data/至远彩色印刷工业有限公司/2023年度报告母公司.pdf",
|
|
|
+ "input_file": "/home/ubuntu/zhch/data/至远彩色印刷工业有限公司/PPStructureV3_Results/2023年度报告母公司/2023年度报告母公司_page_027.png",
|
|
|
+ "output_dir": "/home/ubuntu/zhch/data/至远彩色印刷工业有限公司/PPStructureV3_Results",
|
|
|
+ "collect_results": f"/home/ubuntu/zhch/data/至远彩色印刷工业有限公司/PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
|
|
|
+ # "input_dir": "../../OmniDocBench/OpenDataLab___OmniDocBench/images",
|
|
|
+ # "output_dir": "./OmniDocBench_PPStructureV3_Results",
|
|
|
+ # "collect_results": f"./OmniDocBench_PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
|
|
|
"pipeline": "./my_config/PP-StructureV3.yaml",
|
|
|
- "device": "gpu:0",
|
|
|
- "collect_results": f"./OmniDocBench_PPStructureV3_Results/processed_files_{time.strftime('%Y%m%d_%H%M%S')}.csv",
|
|
|
+ "device": "gpu:3",
|
|
|
}
|
|
|
|
|
|
# 构造参数
|