| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285 |
- import os
- import json
- import asyncio
- import io
- import csv
- import datetime
- import httpx
- # --- LangChain Imports ---
- from langchain_openai import ChatOpenAI
- from langchain_core.prompts import ChatPromptTemplate
- from langchain_core.output_parsers import JsonOutputParser
- # --- 核心 Parser ---
class TransactionParserAgent:
    """Standardizes bank-statement files into transaction CSV data.

    Pipeline: MinerU (external multimodal service) extracts Markdown tables
    from the input file; a DeepSeek LLM (OpenAI-compatible, driven through
    LangChain) converts the table rows, batch by batch, into a fixed-schema
    JSON array which is then serialized to CSV.
    """

    def __init__(self, api_key: str, multimodal_api_url: str, base_url: str = "https://api.deepseek.com"):
        """
        :param api_key: DeepSeek API key.
        :param multimodal_api_url: MinerU file-parsing endpoint URL.
        :param base_url: OpenAI-compatible base URL for the DeepSeek API.
        """
        # Shared httpx tuning: generous read timeout for long LLM generations,
        # small connection pool since calls are sequential.
        timeout = httpx.Timeout(300.0, read=300.0, connect=60.0)
        limits = httpx.Limits(max_keepalive_connections=5, max_connections=10)
        # DeepSeek is fully OpenAI-API compatible, so ChatOpenAI is the
        # standard LangChain client for it.
        self.llm = ChatOpenAI(
            model="deepseek-chat",
            api_key=api_key,
            base_url=base_url,
            temperature=0.1,
            max_retries=3,  # LangChain built-in retry
            model_kwargs={
                "response_format": {"type": "json_object"}  # force JSON mode
            },
            # BUGFIX: the original passed only a sync `http_client`, but all
            # LLM calls here go through `chain.ainvoke`, which uses the async
            # client — so the configured timeouts never applied. Supply both.
            http_client=httpx.Client(timeout=timeout, limits=limits),
            http_async_client=httpx.AsyncClient(timeout=timeout, limits=limits),
        )
        self.multimodal_api_url = multimodal_api_url
        # Parses the model's JSON text output into Python objects.
        self.parser = JsonOutputParser()

    async def _invoke_miner_u(self, file_path: str) -> str:
        """Call MinerU and return the extracted Markdown text.

        MinerU is an independent service (not an LLM), so it is called with
        plain httpx. Returns "" on any failure (best-effort by design).
        """
        print(f"🚀 MinerU 解析中: {os.path.basename(file_path)}")
        try:
            async with httpx.AsyncClient() as client:
                with open(file_path, 'rb') as f:
                    files = {'file': (os.path.basename(file_path), f)}
                    data = {'folderId': 'text'}
                    response = await client.post(self.multimodal_api_url, files=files, data=data, timeout=120.0)
            if response.status_code == 200:
                res_json = response.json()
                # Concatenate the 'md' fragments of every converted element.
                full_md_list = []
                for element in res_json.get('convert_json', []):
                    if 'md' in element:
                        full_md_list.append(element['md'])
                return "\n\n".join(full_md_list)
            return ""
        except Exception as e:
            print(f"❌ MinerU 调用异常: {e}")
            return ""

    def _get_csv_prompt_template(self) -> ChatPromptTemplate:
        """Build the LangChain prompt template for table-row → JSON conversion.

        Template variables: `start_id` (first generated txId number) and
        `chunk_data` (Markdown header + data rows for one batch).
        """
        system_template = """
# Role
你是一个高精度的银行账单转换工具。
# Task
将输入的 Markdown 表格行转换为 JSON 数组。
# Field Rules
1. txId: 如果输入数据中有交易流水号则直接使用,如果没有,从 T{start_id:04d} 开始递增生成。
2. txDate: 交易日期,格式为YYYY-MM-DD
3. txTime: 交易时间,格式为HH:mm:ss (未知填 00:00:00)
4. txAmount: 交易金额,绝对值数字
5. txBalance: 交易后余额。浮点数,移除千分位逗号。
6. txDirection: 交易方向。必须根据以下逻辑判断只输出“收入”或“支出”:
   - 若有“借/贷”列:“借”通常为支出,“贷”通常为收入(除非是信用卡,需结合表头)。
   - 若有“收入/支出”分列:按列归类。
   - 若金额带正负号:"+"为收入,"-"为支出。
   - 如果无符号,请结合表头判断。
7. txSummary: 摘要、用途、业务类型等备注。
8. txCounterparty: 交易对手方(名称及账号,如有)。
# Constraints
- **强制输出格式**:
  1. 严格返回一个包含对象的 JSON 数组。
  2. 每个对象必须包含上述 8 个字段名作为 Key。
  3. 不要输出任何解释文字或 Markdown 代码块标签。
"""
        user_template = """# Input Data
{chunk_data}
# Output
JSON Array:
"""
        return ChatPromptTemplate.from_messages([
            ("system", system_template),
            ("user", user_template)
        ])

    async def parse_to_csv(self, file_path: str) -> str:
        """Parse a statement file end-to-end and return the CSV text ("" on failure)."""
        # 1. Fetch the full Markdown text and split it into lines.
        md_text = await self._invoke_miner_u(file_path)
        if not md_text:
            return ""
        raw_lines = md_text.splitlines()
        clean_lines = [l.strip() for l in raw_lines if l.strip()]
        if len(clean_lines) < 2:
            return ""
        # --- Dynamically locate the table header row ---
        # A line counts as the header when it contains >= 2 banking keywords
        # AND a Markdown table separator '|'. Falls back to the first line.
        table_header = ""
        header_index = 0
        header_keywords = ["余额", "金额", "账号", "日期", "借/贷", "摘要"]
        for idx, line in enumerate(clean_lines):
            hit_count = sum(1 for kw in header_keywords if kw in line)
            if hit_count >= 2 and "|" in line:
                table_header = line
                header_index = idx
                break
        if not table_header:
            table_header = clean_lines[0]
            header_index = 0
        # Collect data rows after the header, skipping separator rows
        # ('|---|' style), repeated headers, and non-table noise (page
        # numbers etc. MinerU sometimes emits after the table).
        data_rows = []
        for line in clean_lines[header_index + 1:]:
            if all(c in '|- ' for c in line):
                continue
            if line == table_header:
                continue
            if "|" not in line:
                continue
            data_rows.append(line)
        csv_header = "txId,txDate,txTime,txAmount,txDirection,txBalance,txSummary,txCounterparty,createdAt\n"
        # Accumulate CSV fragments in a list and join once at the end
        # (the original used `csv_content +=`, which is quadratic).
        csv_parts = [csv_header]
        batch_size = 15
        global_tx_counter = 1
        # Build the LCEL chain: Prompt -> LLM -> Parser
        chain = self._get_csv_prompt_template() | self.llm | self.parser
        # 2. Process the rows in batches; each batch is prefixed with the
        # header so the LLM can interpret the columns.
        for i in range(0, len(data_rows), batch_size):
            chunk = data_rows[i: i + batch_size]
            context_chunk = [table_header] + chunk
            chunk_str = "\n".join(context_chunk)
            print(f"🔄 正在转换批次 {i // batch_size + 1},包含 {len(chunk)} 条数据...")
            try:
                data_data = await chain.ainvoke({
                    "start_id": global_tx_counter,
                    "chunk_data": chunk_str
                })
                # The parser may return a list directly, or a dict wrapping
                # the list (commonly under a "transactions" key).
                if isinstance(data_data, dict):
                    batch_data = data_data.get("transactions", [data_data])
                    if isinstance(batch_data, dict):
                        batch_data = [batch_data]
                elif isinstance(data_data, list):
                    batch_data = data_data
                else:
                    batch_data = []
                if batch_data:
                    output = io.StringIO()
                    createdAtStr = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                    writer = csv.writer(output, quoting=csv.QUOTE_MINIMAL, lineterminator='\n')
                    print(f"✅ 批次转换成功,包含 {len(batch_data)} 条记录。")
                    for item in batch_data:
                        writer.writerow([
                            item.get("txId", ""),
                            item.get("txDate", ""),
                            item.get("txTime", ""),
                            item.get("txAmount", ""),
                            item.get("txDirection", ""),
                            item.get("txBalance", ""),
                            item.get("txSummary", ""),
                            item.get("txCounterparty", ""),
                            createdAtStr
                        ])
                    csv_parts.append(output.getvalue())
                    global_tx_counter += len(batch_data)
            except Exception as e:
                # Best-effort: a failed batch is logged and skipped so the
                # remaining batches still produce output.
                print(f"⚠️ 批次执行失败: {e}")
        return "".join(csv_parts)

    async def parse_and_save_to_file(self, file_path: str, output_dir: str = "output") -> str:
        """Parse a file, write the CSV to disk, and return the full output path.

        `output_dir` is resolved relative to two levels above this script.
        Raises Exception when parsing produced no content.
        """
        current_script_path = os.path.abspath(__file__)
        current_dir = os.path.dirname(current_script_path)
        file_full_name = os.path.basename(file_path)
        file_name = os.path.splitext(file_full_name)[0]  # stem without extension
        output_dir = os.path.normpath(os.path.join(current_dir, "..", "..", output_dir))
        os.makedirs(output_dir, exist_ok=True)
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        file_name = f"{file_name}_data_standard_{timestamp}.csv"
        full_path = os.path.join(output_dir, file_name)
        csv_result = await self.parse_to_csv(file_path)
        if csv_result:
            with open(full_path, "w", encoding="utf-8") as f:
                f.write(csv_result)
            return full_path
        else:
            raise Exception("数据解析失败,未生成有效内容")

    async def run_workflow_task(self, input_file_path: str) -> dict:
        """Standard workflow entry point.

        Returns {"status": "success", "file_path", "file_name", "timestamp"}
        on success, or {"status": "error", "message"} on failure.
        """
        try:
            print(f"待执行标准化的文件:{input_file_path}")
            api_results_dir = "data_files"
            saved_path = await self.parse_and_save_to_file(input_file_path, api_results_dir)
            return {
                "status": "success",
                "file_path": saved_path,
                "file_name": os.path.basename(saved_path),
                "timestamp": datetime.datetime.now().isoformat()
            }
        except Exception as e:
            return {
                "status": "error",
                "message": str(e)
            }
- # --- 运行 ---
async def main():
    """Ad-hoc entry point: runs the standardization workflow on a sample file."""
    # SECURITY: the original hard-coded an API secret here, which is now
    # committed to source history — rotate that key. It remains only as a
    # backward-compatible fallback; prefer the DEEPSEEK_API_KEY env var.
    agent = TransactionParserAgent(
        api_key=os.environ.get("DEEPSEEK_API_KEY", "sk-8634dbc2866540c4b6003bb5733f23d8"),
        multimodal_api_url="http://103.154.31.78:20012/api/file/read"
    )
    current_script_path = os.path.abspath(__file__)
    current_dir = os.path.dirname(current_script_path)
    # Simulate the workflow handing over a file to process; the path is
    # resolved relative to two levels above this script.
    input_pdf = "data_files/11111.png"
    filepath = os.path.normpath(os.path.join(current_dir, "..", "..", input_pdf))
    if not os.path.exists(filepath):
        print(f"{filepath}文件不存在")
        return
    result = await agent.run_workflow_task(filepath)
    if result["status"] == "success":
        print(f"🎯 【数据标准化】任务完成!")
        print(f"📂 标准化后文件输出位置: {result['file_path']}")
    else:
        print(f"❌ 任务失败: {result['message']}")


if __name__ == "__main__":
    asyncio.run(main())
|