Selaa lähdekoodia

提取表格到excel

zhch158_admin 11 kuukautta sitten
vanhempi
commit
f443b484b8
1 muutettua tiedostoa jossa 111 lisäystä ja 2 poistoa
  1. 111 2
      zhch/magic_pdf_parse_main_zhch.py

+ 111 - 2
zhch/magic_pdf_parse_main_zhch.py

@@ -49,6 +49,112 @@ def json_md_dump(
         md_content,
     )
 
+# 使用Pydantic定义report数据结构
+from pydantic import BaseModel
+import pandas as pd
+import re
+from magic_pdf.config.ocr_content_type import BlockType
class Report(BaseModel):
    """One table extracted from the PDF, destined for a single Excel sheet."""
    sheet_name: str  # sanitized Excel sheet name for this table
    dataframe: pd.DataFrame  # table contents parsed from the block's HTML
    last_available_label: bool  # True when the table is the last block on its page (may continue on the next page)
    first_available_label: bool  # True when the table is the first block on its page (may be a continuation)
    # Pydantic cannot generate a schema for pandas.core.frame.DataFrame,
    # so arbitrary (non-pydantic) field types must be allowed explicitly.
    class Config:
        arbitrary_types_allowed = True
+
def save_report(
    pipe: UNIPipe,
    excel_path: str
):
    """
    Extract every table block from the parsed PDF and save it to Excel.

    Two workbooks are written:
      * ``excel_path`` — one sheet per detected table;
      * ``<excel_path>_merged.xlsx`` — adjacent tables that look like one
        table split across a page break are concatenated into one sheet.

    :param pipe: pipeline object whose ``pdf_mid_data['pdf_info']`` holds
        the per-page layout blocks produced by magic-pdf.
    :param excel_path: destination ``.xlsx`` path.
    """
    def merge_tables(prev_report: Report, next_report: Report) -> pd.DataFrame:
        """Concatenate two page-spanning fragments of the same table,
        or return None when they cannot be merged."""
        if prev_report.dataframe is None or next_report.dataframe is None:
            return None
        # Only merge when the first fragment ends its page and the second
        # fragment starts its page (i.e. the table was cut by a page break).
        if not (prev_report.last_available_label and next_report.first_available_label):
            return None
        if prev_report.dataframe.shape[1] != next_report.dataframe.shape[1]:
            # FIX: the original f-string never interpolated the column counts
            # (missing braces) and referenced an undefined name `report`.
            logger.error(
                f"列数不同,无法合并: {prev_report.sheet_name}({prev_report.dataframe.shape[1]}) "
                f"和 {next_report.sheet_name}({next_report.dataframe.shape[1]})"
            )
            return None
        # Same column count: reuse the header row of the first fragment.
        next_report.dataframe.columns = prev_report.dataframe.columns
        next_report.dataframe.reset_index(drop=True, inplace=True)
        return pd.concat([prev_report.dataframe, next_report.dataframe], axis=0, ignore_index=True)

    report_list = []
    # Walk every page, collecting one Report per table block.
    for page_info in pipe.pdf_mid_data['pdf_info']:
        paras_of_layout = page_info.get('para_blocks')
        page_idx = page_info.get('page_idx')
        if not paras_of_layout:
            continue
        # Each para_block contains at most one table (body + optional caption).
        for block_idx, para_block in enumerate(paras_of_layout):
            if para_block['type'] != BlockType.Table:
                continue
            sheet_name = None
            dataframe = None
            for block in para_block['blocks']:
                if block['type'] == BlockType.TableBody:
                    # Parse the recognized HTML table into a DataFrame.
                    dataframe = pd.read_html(block['lines'][0]['spans'][0]['html'])[0]
                elif block['type'] == BlockType.TableCaption:
                    sheet_name = block['lines'][0]['spans'][0]['content']
            if sheet_name is None:
                # No caption: search backwards for the nearest Title ending in “表”.
                # FIX: slice by block_idx instead of paras_of_layout.index(para_block) —
                # index() is O(n) and returns the first *equal* dict, which may not
                # be this block.
                for title_block in reversed(paras_of_layout[:block_idx]):
                    if title_block['type'] == BlockType.Title:
                        title = title_block['lines'][0]['spans'][0]['content'].strip()
                        if title and title[-1] == '表':
                            sheet_name = title
                            break
            if dataframe is None:
                continue
            if sheet_name is None:
                sheet_name = f"Sheet_{page_idx}.{block_idx}"
            # Strip characters Excel forbids in sheet names; Excel also caps
            # sheet names at 31 characters (xlsxwriter raises beyond that).
            sheet_name = re.sub(r'[\[\]:*?/\\]', '', sheet_name)[:31]
            report_list.append(Report(
                sheet_name=sheet_name,
                dataframe=dataframe,
                # Last/first block on the page ⇒ candidate for cross-page merging.
                last_available_label=(block_idx == len(paras_of_layout) - 1),
                first_available_label=(block_idx == 0),
            ))

    # First workbook: every table on its own sheet.
    # Context manager guarantees the writer is closed even if to_excel raises.
    with pd.ExcelWriter(excel_path, engine='xlsxwriter') as excel_writer:
        for report in report_list:
            if report.dataframe is not None:
                report.dataframe.to_excel(excel_writer, sheet_name=report.sheet_name, index=False)

    # Second pass: fold page-break fragments into their predecessor.
    merged_report_list = []
    prev_report = None
    for report in report_list:
        if prev_report is not None and prev_report.dataframe is not None:
            merged_table = merge_tables(prev_report, report)
            if merged_table is not None:
                # This fragment continues the previous table; absorb it.
                prev_report.dataframe = merged_table
                continue
            merged_report_list.append(prev_report)
        prev_report = report
    # FIX: guard against a PDF with no tables — the original appended None
    # unconditionally and then crashed on report.dataframe below.
    if prev_report is not None:
        merged_report_list.append(prev_report)

    merged_excel_path = excel_path.replace(".xlsx", "_merged.xlsx")
    with pd.ExcelWriter(merged_excel_path, engine='xlsxwriter') as excel_writer:
        for report in merged_report_list:
            report.dataframe.to_excel(excel_writer, sheet_name=report.sheet_name, index=False)
            logger.debug(f"保存报表: {report}")
 
 # 可视化
 def draw_visualization_bbox(pdf_info, pdf_bytes, local_md_dir, pdf_file_name):
@@ -130,6 +236,9 @@ def pdf_parse_main(
         content_list = pipe.pipe_mk_uni_format(image_path_parent, drop_mode='none')
         md_content = pipe.pipe_mk_markdown(image_path_parent, drop_mode='none')
 
+        # 保存报表
+        save_report(pipe, os.path.join(output_path, f'{pdf_name}.xlsx'))
+
         if is_json_md_dump:
             json_md_dump(pipe, md_writer, pdf_name, content_list, md_content, orig_model_list)
 
@@ -143,8 +252,8 @@ def pdf_parse_main(
 # 测试
 if __name__ == '__main__':
     current_script_dir = os.path.dirname(os.path.abspath(__file__))
-    demo_names = ['demo1', 'demo2', 'small_ocr']
+    # demo_names = ['demo1', 'demo2', 'small_ocr']
+    demo_names = ['600916_中国黄金_2002年报_83_94']
     for name in demo_names:
         file_path = os.path.join(current_script_dir, f'{name}.pdf')
-        # pdf_parse_main(file_path, model_json_path='./magic-pdf-0.json', output_dir='./output.demo')
         pdf_parse_main(file_path, output_dir='./output.demo')